Commit 8a158686 authored by Anjith GEORGE, committed by Anjith GEORGE

WIP, MCCNN extractor

Parent 9554f709
Branch mccnn_rebase
import numpy as np

import torch
from torch.autograd import Variable

import torchvision.transforms as transforms

from bob.learn.pytorch.architectures import MCCNN

from bob.bio.base.extractor import Extractor


class MCCNNExtractor(Extractor):
    """The class implementing the MC-CNN score computation.

    Attributes
    ----------
    network: :py:class:`torch.nn.Module`
        The network architecture.
    transforms: :py:mod:`torchvision.transforms`
        The transform from :py:class:`numpy.ndarray` to :py:class:`torch.Tensor`.
    """

    def __init__(self, considered_modalities=['C', 'D', 'I', 'T'], model_file=None):
        """Init method

        Parameters
        ----------
        considered_modalities: list
            The list of modalities to use; 'C', 'D', 'I' and 'T' stand for color,
            depth, infrared and thermal respectively.
        model_file: str
            Path of the trained PAD network to load. Note that two model files are
            involved: the pretrained LightCNN model, which is used as the base
            network, and this PAD-specific model, which contains the adapted
            layers and the fully connected layers.
        """
        Extractor.__init__(self, skip_extractor_training=True)

        # model
        self.network = MCCNN(considered_modalities=considered_modalities)

        if model_file is None:
            # do nothing (used mainly for unit testing)
            print("No pretrained file provided")
        else:
            # load the weights of the network trained for PAD
            cp = torch.load(model_file)

            if 'state_dict' in cp:
                self.network.load_state_dict(cp['state_dict'])

            print('Loaded the pretrained PAD model')

        # scoring is done in evaluation mode
        self.network.eval()

        # image pre-processing
        self.transforms = transforms.Compose([transforms.ToTensor()])

    def __call__(self, image):
        """Extract features from an image

        Parameters
        ----------
        image : 3D :py:class:`numpy.ndarray` (floats)
            The multi-channel image to extract the score from. Its size must be
            4x128x128, with the channels ordered as C, D, I, T (color, depth,
            infrared and thermal respectively).

        Returns
        -------
        output : :py:class:`numpy.ndarray`
            A 1D array holding a single scalar score: ~0 for bonafide and ~1 for
            presentation attacks.
        """
        # CxHxW -> HxWxC, so that ToTensor() converts it back to a CxHxW tensor
        input_image = np.rollaxis(np.rollaxis(image, 2), 2)
        input_image = self.transforms(input_image)
        input_image = input_image.unsqueeze(0)

        output = self.network.forward(Variable(input_image))
        output = output.data.numpy().flatten()

        # output is a scalar score
        return output
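
A minimal usage sketch of the extractor above, for illustration only: the import path and the random input are assumptions (they are not part of this commit), and in practice model_file should point to a trained PAD model.

import numpy as np

# import path assumed; adjust to where MCCNNExtractor is exposed in the package
from bob.learn.pytorch.extractor.image import MCCNNExtractor

# model_file=None skips loading trained PAD weights (as done in the unit tests)
extractor = MCCNNExtractor(considered_modalities=['C', 'D', 'I', 'T'], model_file=None)

# 4x128x128 multi-channel image, channels ordered C, D, I, T
mc_image = np.random.rand(4, 128, 128).astype('float32')

score = extractor(mc_image)  # 1D array with a single score: ~0 bonafide, ~1 attack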
from .CNN8 import CNN8Extractor
from .CasiaNet import CasiaNetExtractor
from .MCCNN import MCCNNExtractor
# gets sphinx autodoc done right - don't remove it
def __appropriate__(*args):
@@ -19,6 +20,7 @@ def __appropriate__(*args):
__appropriate__(
    CNN8Extractor,
    CasiaNetExtractor,
    MCCNNExtractor,
)
# gets sphinx autodoc done right - don't remove it
@@ -32,3 +32,17 @@ def test_casianet():
    data = numpy.random.rand(3, 128, 128).astype("float32")
    output = extractor(data)

    assert output.shape[0] == 320


def test_mccnn():
    """Test for the MCCNN architecture.

    This architecture takes 4x128x128 multi-channel images as input and
    outputs a single score.
    """
    from . import MCCNNExtractor

    extractor = MCCNNExtractor()

    # this architecture expects 4x128x128 multi-channel images
    data = numpy.random.rand(4, 128, 128).astype("float32")
    output = extractor(data)

    assert output.shape[0] == 1