diff --git a/bob/paper/nir_patch_pooling/config/patch_pooling_lr.py b/bob/paper/nir_patch_pooling/config/patch_pooling_lr.py index 17f63564bb7146e5955681c4f6e7f0edee53dcd8..192e9d29c69732d2a8c0add2b13504e662010f70 100644 --- a/bob/paper/nir_patch_pooling/config/patch_pooling_lr.py +++ b/bob/paper/nir_patch_pooling/config/patch_pooling_lr.py @@ -1,4 +1,3 @@ -#!/usr/bin/env python # -*- coding: utf-8 -*- """ @@ -6,7 +5,7 @@ Configuration file to run PatchPooling + LR classifier for Face PAD toward detection of mask attacks in NIR. """ -#---------------------------------------------------------- +#------------------------------------------------------------------------------ sub_directory = "pooling_lr" @@ -17,6 +16,7 @@ from bob.pad.face.preprocessor import FaceCropAlign from bob.bio.video.preprocessor import Wrapper from bob.bio.video.utils import FrameSelector + # parameters and constants FACE_SIZE = 128 RGB_OUTPUT_FLAG = False @@ -37,9 +37,10 @@ _image_preprocessor = FaceCropAlign(face_size=FACE_SIZE, _frame_selector = FrameSelector(selection_style = "all") -preprocessor = Wrapper(preprocessor = _image_preprocessor, frame_selector = _frame_selector) +preprocessor = Wrapper(preprocessor = _image_preprocessor, + frame_selector = _frame_selector) -#---------------------------------------------------------- +#------------------------------------------------------------------------------ # define extractor: @@ -48,16 +49,19 @@ from bob.bio.video.extractor import Wrapper from bob.extension import rc import os -_model_dir = rc.get("LIGHTCNN9_MODEL_DIRECTORY") +_model_directory = rc["lightcnn9.model.directory"] _model_name = "LightCNN_9Layers_checkpoint.pth.tar" -_model_file = os.path.join(_model_dir, _model_name) +_model_file = os.path.join(_model_directory, _model_name) + if not os.path.exists(_model_file): - print("Error: Could not find the LightCNN-9 model at [{}].\nPlease follow the download instructions from README".format(_model_dir)) + print("Error: Could not find the LightCNN-9 model [{}].\nPlease follow \ + the download instructions from README".format(_model_directory)) exit(0) -extractor = Wrapper(PatchPoolingCNN(model_file=_model_file), frame_selector = _frame_selector) +extractor = Wrapper(PatchPoolingCNN(model_file=_model_file), + frame_selector = _frame_selector) -#---------------------------------------------------------- +#------------------------------------------------------------------------------ # define algorithm @@ -67,12 +71,6 @@ C = 1.0 algorithm = LogRegr(C=C, frame_level_scores_flag=True) -#---------------------------------------------------------- - - - - - - +#------------------------------------------------------------------------------ diff --git a/bob/paper/nir_patch_pooling/extractor/patch_pooling_cnn.py b/bob/paper/nir_patch_pooling/extractor/patch_pooling_cnn.py index 5941c877c73d2fba8b661d8e241f5bfefb925b72..e98cb5a17177cd79f3c61f769a06c146f1d01e3d 100644 --- a/bob/paper/nir_patch_pooling/extractor/patch_pooling_cnn.py +++ b/bob/paper/nir_patch_pooling/extractor/patch_pooling_cnn.py @@ -1,10 +1,8 @@ -#!/usr/bin/env python2 # -*- coding: utf-8 -*- """ -Implementation of PCNN feature extractor for LightCNN-9. +Implementation of Patch Pooling CNN feature extractor with LightCNN-9 backbone @author: Ketan Kotwal - """ # Imports @@ -21,14 +19,13 @@ import logging logger = logging.getLogger(__name__) logger.setLevel(logging.INFO) -#---------------------------------------------------------- +#------------------------------------------------------------------------------ class PatchPoolingCNN(Extractor): """ - The class implements the feature extraction of LightCNN9 embeddings. - It has some implementation differences from a similar extractor from - bob.learn.pytorch. + The class implements extraction of patch pooled features from the final + convolutional layer of LightCNN9 (MFM5 layer). """ def __init__(self, model_file=None, num_classes=79077): @@ -39,30 +36,36 @@ class PatchPoolingCNN(Extractor): # load the model into network. cp = torch.load(model_file, map_location="cpu") - # checked if pre-trained model was saved using nn.DataParallel ... - saved_with_nnDataParallel = False + # checked if pre-trained model was saved using nn.DataParallel + saved_with_data_parallel = False for k, v in cp["state_dict"].items(): if("module" in k): - saved_with_nnDataParallel = True + saved_with_data_parallel = True break # if DataParallel format, remove module term - if(saved_with_nnDataParallel): + if(saved_with_data_parallel): if("state_dict" in cp): + from collections import OrderedDict new_state_dict = OrderedDict() + for k, v in cp["state_dict"].items(): name = k[7:] new_state_dict[name] = v + self.network.load_state_dict(new_state_dict) else: + self.network.load_state_dict(cp["state_dict"]) + self.network.eval() # image pre-processing - self.data_transform = transforms.Compose([transforms.Resize(size=128), transforms.ToTensor()]) + self.data_transform = transforms.Compose([transforms.Resize(size=128), + transforms.ToTensor()]) -#---------------------------------------------------------- +#------------------------------------------------------------------------------ def __call__(self, image): @@ -76,7 +79,7 @@ class PatchPoolingCNN(Extractor): Returns ------- feature : :py:class:`numpy.ndarray` (floats) - The extracted features as a 1d array of size 320 + The extracted features as a 1d array of size 256 """ @@ -84,45 +87,49 @@ class PatchPoolingCNN(Extractor): pil_image = Image.fromarray(image.astype(np.uint8)) input_image = self.data_transform(pil_image) input_image = input_image.unsqueeze(0) - - # to be compliant with the loaded model, where weight and biases are torch.FloatTensor input_image = input_image.float() + # obtain the features (to be pooled) from forward pass of network _ , features = self.network.forward(Variable(input_image)) + + # pool features through patch-level processing features = self.conv_to_patch(features) features = features.data.numpy().flatten() return features.astype(np.float64) -#---------------------------------------------------------- +#------------------------------------------------------------------------------ def conv_to_patch(self, features): - logger.debug("Shape of input features: {} {}".format(features.shape, features.squeeze().shape)) - # parameters for the patch conversion - stride = 4 # orig:4 #feat.shape[2]/4 - idx = 0 # purely for debugging - num_patch = features.shape[2]/stride - feat_patch = torch.zeros(1, stride*stride*features.shape[1]) + stride = 4 # features.shape[2]/4 + + # for debugging + # idx = 0 + # num_patch = features.shape[2]/stride + + pooled_features = torch.zeros(1, stride*stride*features.shape[1]) # obtain patches by tesselation of feature maps + # pool linearized version of individual patches for i in range(0, features.shape[2], stride): for j in range(0, features.shape[3], stride): - feat_tmp = features[:, :, i:i+stride, j:j+stride] - feat_tmp = feat_tmp.contiguous().view(feat_tmp.size(0), -1) - feat_patch += feat_tmp - idx += 1 + feat_temp = features[:, :, i:i+stride, j:j+stride] + feat_temp = feat_temp.contiguous().view(feat_temp.size(0), -1) + pooled_features += feat_temp + # idx += 1 - # normalize the patch vector - feat_patch = feat_patch/stride/stride - logger.debug("Feat patch shape: {}".format(feat_patch.shape)) + # normalize the vector of pooled features + pooled_features = pooled_features/stride/stride + + return pooled_features - return feat_patch +#------------------------------------------------------------------------------ -#---------------------------------------------------------- +#------------------------------------------------------------------------------ -# class LightCNN9Patch: it inherits the LightCNN-9 class, and returns the last -# conv layer features; instead of embeddings. +# class LightCNN9Patch: it inherits the LightCNN-9 class from bob, +# and returns the last conv layer features and embeddings. class LightCNN9Patch(LightCNN9): @@ -131,7 +138,7 @@ class LightCNN9Patch(LightCNN9): # do not change the init super(LightCNN9Patch, self).__init__() -#---------------------------------------------------------- +#------------------------------------------------------------------------------ def forward(self, x): @@ -147,14 +154,7 @@ class LightCNN9Patch(LightCNN9): out = self.fc2(x) return out, conv_out -#---------------------------------------------------------- - - - - - - - +#------------------------------------------------------------------------------