Skip to content
Snippets Groups Projects
Commit 8a158686 authored by Anjith GEORGE's avatar Anjith GEORGE Committed by Anjith GEORGE
Browse files

WIP, MCCNN extractor

parent 9554f709
No related branches found
No related tags found
No related merge requests found
Pipeline #26203 failed
import numpy as np
import torch
from torch.autograd import Variable
import torchvision.transforms as transforms
from bob.learn.pytorch.architectures import MCCNN
from bob.bio.base.extractor import Extractor
#TODO: Clean up
class MCCNNExtractor(Extractor):
""" The class implementing the MC-CNN score computation.
Attributes
----------
network: :py:class:`torch.nn.Module`
The network architecture
transforms: :py:mod:`torchvision.transforms`
The transform from numpy.array to torch.Tensor
"""
def __init__(self, considered_modalities=['C','D','I','T'], model_file=None):
""" Init method
Parameters
----------
considered_modalities: list
The list of modalities used C,D,I,T represents color, depth , infrared and thermal respectively
pretrained_lightCNN_modelpath: str
Path to the Pretrained LightCNN model
NB: There are two model files here; one is the pretrained Light CNN model which is used as the base network; then there is another model file
specifically trained for PAD (model_file) which contains the new adapted layers and the fully connected layers.
model_file: str
The path of the trained PAD network to load
"""
Extractor.__init__(self, skip_extractor_training=True)
# model
self.network = MCCNN(considered_modalities=considered_modalities)#.net
#self.network=self.network.to(device)
if model_file is None:
# do nothing (used mainly for unit testing)
print("No pretrained file")
pass
else:
# Old approach
# state_dict = torch.load(model_file,map_location=lambda storage,loc:storage)
# self.network.load_state_dict(state_dict)
# With the new training
cp = torch.load(model_file)
if 'state_dict' in cp:
self.network.load_state_dict(cp['state_dict'])
print('Loaded the pretrained PAD model')
self.network.net.eval()
# image pre-processing
self.transforms= transforms.Compose([transforms.ToTensor()])
def __call__(self, image):
""" Extract features from an image
Parameters
----------
image : 3D :py:class:`numpy.ndarray` (floats)
The multi-channel image to extract the score from. Its size must be 4x128x128;
The channels should be ordered in C D I T order (color, depth, infrared and thermal respectively)
Returns
-------
output : float
The extracted feature is a scalar values ~0 for bonafide and ~1 for PAs
"""
input_image = np.rollaxis(np.rollaxis(image, 2),2)
input_image = self.transforms(input_image)
input_image = input_image.unsqueeze(0)
#print("input_image",input_image.shape)
output = self.network.forward(Variable(input_image))
output = output.data.numpy().flatten()
# output is a scalar score
return output
from .CNN8 import CNN8Extractor from .CNN8 import CNN8Extractor
from .CasiaNet import CasiaNetExtractor from .CasiaNet import CasiaNetExtractor
from .MCCNN import MCCNNExtractor
# gets sphinx autodoc done right - don't remove it # gets sphinx autodoc done right - don't remove it
def __appropriate__(*args): def __appropriate__(*args):
...@@ -19,6 +20,7 @@ def __appropriate__(*args): ...@@ -19,6 +20,7 @@ def __appropriate__(*args):
__appropriate__( __appropriate__(
CNN8Extractor, CNN8Extractor,
CasiaNetExtractor, CasiaNetExtractor,
MCCNNExtractor,
) )
# gets sphinx autodoc done right - don't remove it # gets sphinx autodoc done right - don't remove it
......
...@@ -32,3 +32,17 @@ def test_casianet(): ...@@ -32,3 +32,17 @@ def test_casianet():
data = numpy.random.rand(3, 128, 128).astype("float32") data = numpy.random.rand(3, 128, 128).astype("float32")
output = extractor(data) output = extractor(data)
assert output.shape[0] == 320 assert output.shape[0] == 320
def test_mccnn():
""" test for the MCCNN architecture
this architecture takes 4x128x128 images as input
output a single score
"""
from . import MCCNNExtractor
extractor = MCCNNExtractor()
# this architecture expects 4x128x128 Multi channel images
data = numpy.random.rand(4, 128, 128).astype("float32")
output = extractor(data)
assert output.shape[0] == 1
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment