Skip to content
Snippets Groups Projects
Commit 805dfa9b authored by Ketan Kotwal's avatar Ketan Kotwal
Browse files

code cleanup: extractor and its config

parent 27e2f462
Branches
No related tags found
No related merge requests found
#!/usr/bin/env python
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
""" """
...@@ -6,7 +5,7 @@ Configuration file to run PatchPooling + LR classifier for Face PAD ...@@ -6,7 +5,7 @@ Configuration file to run PatchPooling + LR classifier for Face PAD
toward detection of mask attacks in NIR. toward detection of mask attacks in NIR.
""" """
#---------------------------------------------------------- #------------------------------------------------------------------------------
sub_directory = "pooling_lr" sub_directory = "pooling_lr"
...@@ -17,6 +16,7 @@ from bob.pad.face.preprocessor import FaceCropAlign ...@@ -17,6 +16,7 @@ from bob.pad.face.preprocessor import FaceCropAlign
from bob.bio.video.preprocessor import Wrapper from bob.bio.video.preprocessor import Wrapper
from bob.bio.video.utils import FrameSelector from bob.bio.video.utils import FrameSelector
# parameters and constants # parameters and constants
FACE_SIZE = 128 FACE_SIZE = 128
RGB_OUTPUT_FLAG = False RGB_OUTPUT_FLAG = False
...@@ -37,9 +37,10 @@ _image_preprocessor = FaceCropAlign(face_size=FACE_SIZE, ...@@ -37,9 +37,10 @@ _image_preprocessor = FaceCropAlign(face_size=FACE_SIZE,
_frame_selector = FrameSelector(selection_style = "all") _frame_selector = FrameSelector(selection_style = "all")
preprocessor = Wrapper(preprocessor = _image_preprocessor, frame_selector = _frame_selector) preprocessor = Wrapper(preprocessor = _image_preprocessor,
frame_selector = _frame_selector)
#---------------------------------------------------------- #------------------------------------------------------------------------------
# define extractor: # define extractor:
...@@ -48,16 +49,19 @@ from bob.bio.video.extractor import Wrapper ...@@ -48,16 +49,19 @@ from bob.bio.video.extractor import Wrapper
from bob.extension import rc from bob.extension import rc
import os import os
_model_dir = rc.get("LIGHTCNN9_MODEL_DIRECTORY") _model_directory = rc["lightcnn9.model.directory"]
_model_name = "LightCNN_9Layers_checkpoint.pth.tar" _model_name = "LightCNN_9Layers_checkpoint.pth.tar"
_model_file = os.path.join(_model_dir, _model_name) _model_file = os.path.join(_model_directory, _model_name)
if not os.path.exists(_model_file): if not os.path.exists(_model_file):
print("Error: Could not find the LightCNN-9 model at [{}].\nPlease follow the download instructions from README".format(_model_dir)) print("Error: Could not find the LightCNN-9 model [{}].\nPlease follow \
the download instructions from README".format(_model_directory))
exit(0) exit(0)
extractor = Wrapper(PatchPoolingCNN(model_file=_model_file), frame_selector = _frame_selector) extractor = Wrapper(PatchPoolingCNN(model_file=_model_file),
frame_selector = _frame_selector)
#---------------------------------------------------------- #------------------------------------------------------------------------------
# define algorithm # define algorithm
...@@ -67,12 +71,6 @@ C = 1.0 ...@@ -67,12 +71,6 @@ C = 1.0
algorithm = LogRegr(C=C, frame_level_scores_flag=True) algorithm = LogRegr(C=C, frame_level_scores_flag=True)
#---------------------------------------------------------- #------------------------------------------------------------------------------
#!/usr/bin/env python2
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
""" """
Implementation of PCNN feature extractor for LightCNN-9. Implementation of Patch Pooling CNN feature extractor with LightCNN-9 backbone
@author: Ketan Kotwal @author: Ketan Kotwal
""" """
# Imports # Imports
...@@ -21,14 +19,13 @@ import logging ...@@ -21,14 +19,13 @@ import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO) logger.setLevel(logging.INFO)
#---------------------------------------------------------- #------------------------------------------------------------------------------
class PatchPoolingCNN(Extractor): class PatchPoolingCNN(Extractor):
""" """
The class implements the feature extraction of LightCNN9 embeddings. The class implements extraction of patch pooled features from the final
It has some implementation differences from a similar extractor from convolutional layer of LightCNN9 (MFM5 layer).
bob.learn.pytorch.
""" """
def __init__(self, model_file=None, num_classes=79077): def __init__(self, model_file=None, num_classes=79077):
...@@ -39,30 +36,36 @@ class PatchPoolingCNN(Extractor): ...@@ -39,30 +36,36 @@ class PatchPoolingCNN(Extractor):
# load the model into network. # load the model into network.
cp = torch.load(model_file, map_location="cpu") cp = torch.load(model_file, map_location="cpu")
# checked if pre-trained model was saved using nn.DataParallel ... # checked if pre-trained model was saved using nn.DataParallel
saved_with_nnDataParallel = False saved_with_data_parallel = False
for k, v in cp["state_dict"].items(): for k, v in cp["state_dict"].items():
if("module" in k): if("module" in k):
saved_with_nnDataParallel = True saved_with_data_parallel = True
break break
# if DataParallel format, remove module term # if DataParallel format, remove module term
if(saved_with_nnDataParallel): if(saved_with_data_parallel):
if("state_dict" in cp): if("state_dict" in cp):
from collections import OrderedDict from collections import OrderedDict
new_state_dict = OrderedDict() new_state_dict = OrderedDict()
for k, v in cp["state_dict"].items(): for k, v in cp["state_dict"].items():
name = k[7:] name = k[7:]
new_state_dict[name] = v new_state_dict[name] = v
self.network.load_state_dict(new_state_dict) self.network.load_state_dict(new_state_dict)
else: else:
self.network.load_state_dict(cp["state_dict"]) self.network.load_state_dict(cp["state_dict"])
self.network.eval() self.network.eval()
# image pre-processing # image pre-processing
self.data_transform = transforms.Compose([transforms.Resize(size=128), transforms.ToTensor()]) self.data_transform = transforms.Compose([transforms.Resize(size=128),
transforms.ToTensor()])
#---------------------------------------------------------- #------------------------------------------------------------------------------
def __call__(self, image): def __call__(self, image):
...@@ -76,7 +79,7 @@ class PatchPoolingCNN(Extractor): ...@@ -76,7 +79,7 @@ class PatchPoolingCNN(Extractor):
Returns Returns
------- -------
feature : :py:class:`numpy.ndarray` (floats) feature : :py:class:`numpy.ndarray` (floats)
The extracted features as a 1d array of size 320 The extracted features as a 1d array of size 256
""" """
...@@ -84,45 +87,49 @@ class PatchPoolingCNN(Extractor): ...@@ -84,45 +87,49 @@ class PatchPoolingCNN(Extractor):
pil_image = Image.fromarray(image.astype(np.uint8)) pil_image = Image.fromarray(image.astype(np.uint8))
input_image = self.data_transform(pil_image) input_image = self.data_transform(pil_image)
input_image = input_image.unsqueeze(0) input_image = input_image.unsqueeze(0)
# to be compliant with the loaded model, where weight and biases are torch.FloatTensor
input_image = input_image.float() input_image = input_image.float()
# obtain the features (to be pooled) from forward pass of network
_ , features = self.network.forward(Variable(input_image)) _ , features = self.network.forward(Variable(input_image))
# pool features through patch-level processing
features = self.conv_to_patch(features) features = self.conv_to_patch(features)
features = features.data.numpy().flatten() features = features.data.numpy().flatten()
return features.astype(np.float64) return features.astype(np.float64)
#---------------------------------------------------------- #------------------------------------------------------------------------------
def conv_to_patch(self, features): def conv_to_patch(self, features):
logger.debug("Shape of input features: {} {}".format(features.shape, features.squeeze().shape))
# parameters for the patch conversion # parameters for the patch conversion
stride = 4 # orig:4 #feat.shape[2]/4 stride = 4 # features.shape[2]/4
idx = 0 # purely for debugging
num_patch = features.shape[2]/stride # for debugging
feat_patch = torch.zeros(1, stride*stride*features.shape[1]) # idx = 0
# num_patch = features.shape[2]/stride
pooled_features = torch.zeros(1, stride*stride*features.shape[1])
# obtain patches by tesselation of feature maps # obtain patches by tesselation of feature maps
# pool linearized version of individual patches
for i in range(0, features.shape[2], stride): for i in range(0, features.shape[2], stride):
for j in range(0, features.shape[3], stride): for j in range(0, features.shape[3], stride):
feat_tmp = features[:, :, i:i+stride, j:j+stride] feat_temp = features[:, :, i:i+stride, j:j+stride]
feat_tmp = feat_tmp.contiguous().view(feat_tmp.size(0), -1) feat_temp = feat_temp.contiguous().view(feat_temp.size(0), -1)
feat_patch += feat_tmp pooled_features += feat_temp
idx += 1 # idx += 1
# normalize the patch vector # normalize the vector of pooled features
feat_patch = feat_patch/stride/stride pooled_features = pooled_features/stride/stride
logger.debug("Feat patch shape: {}".format(feat_patch.shape))
return pooled_features
return feat_patch #------------------------------------------------------------------------------
#---------------------------------------------------------- #------------------------------------------------------------------------------
# class LightCNN9Patch: it inherits the LightCNN-9 class, and returns the last # class LightCNN9Patch: it inherits the LightCNN-9 class from bob,
# conv layer features; instead of embeddings. # and returns the last conv layer features and embeddings.
class LightCNN9Patch(LightCNN9): class LightCNN9Patch(LightCNN9):
...@@ -131,7 +138,7 @@ class LightCNN9Patch(LightCNN9): ...@@ -131,7 +138,7 @@ class LightCNN9Patch(LightCNN9):
# do not change the init # do not change the init
super(LightCNN9Patch, self).__init__() super(LightCNN9Patch, self).__init__()
#---------------------------------------------------------- #------------------------------------------------------------------------------
def forward(self, x): def forward(self, x):
...@@ -147,14 +154,7 @@ class LightCNN9Patch(LightCNN9): ...@@ -147,14 +154,7 @@ class LightCNN9Patch(LightCNN9):
out = self.fc2(x) out = self.fc2(x)
return out, conv_out return out, conv_out
#---------------------------------------------------------- #------------------------------------------------------------------------------
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment