From 8a158686576d437c5ae3abca63082c4c1333aeae Mon Sep 17 00:00:00 2001 From: ageorge <anjith.george@idiap.ch> Date: Mon, 21 Jan 2019 18:13:36 +0100 Subject: [PATCH] WIP, MCCNN extractor --- bob/ip/pytorch_extractor/MCCNN.py | 102 +++++++++++++++++++++++++++ bob/ip/pytorch_extractor/__init__.py | 2 + bob/ip/pytorch_extractor/test.py | 14 ++++ 3 files changed, 118 insertions(+) create mode 100644 bob/ip/pytorch_extractor/MCCNN.py diff --git a/bob/ip/pytorch_extractor/MCCNN.py b/bob/ip/pytorch_extractor/MCCNN.py new file mode 100644 index 0000000..aa9bc5e --- /dev/null +++ b/bob/ip/pytorch_extractor/MCCNN.py @@ -0,0 +1,102 @@ +import numpy as np + +import torch +from torch.autograd import Variable + +import torchvision.transforms as transforms + +from bob.learn.pytorch.architectures import MCCNN +from bob.bio.base.extractor import Extractor + + +#TODO: Clean up + + +class MCCNNExtractor(Extractor): + """ The class implementing the MC-CNN score computation. + + Attributes + ---------- + network: :py:class:`torch.nn.Module` + The network architecture + transforms: :py:mod:`torchvision.transforms` + The transform from numpy.array to torch.Tensor + + """ + + def __init__(self, considered_modalities=['C','D','I','T'], model_file=None): + """ Init method + + Parameters + ---------- + considered_modalities: list + The list of modalities used C,D,I,T represents color, depth , infrared and thermal respectively + pretrained_lightCNN_modelpath: str + Path to the Pretrained LightCNN model + NB: There are two model files here; one is the pretrained Light CNN model which is used as the base network; then there is another model file + specifically trained for PAD (model_file) which contains the new adapted layers and the fully connected layers. + model_file: str + The path of the trained PAD network to load + + """ + + Extractor.__init__(self, skip_extractor_training=True) + + # model + self.network = MCCNN(considered_modalities=considered_modalities)#.net + + #self.network=self.network.to(device) + + if model_file is None: + # do nothing (used mainly for unit testing) + + print("No pretrained file") + pass + else: + + # Old approach + + # state_dict = torch.load(model_file,map_location=lambda storage,loc:storage) + + # self.network.load_state_dict(state_dict) + + # With the new training + cp = torch.load(model_file) + if 'state_dict' in cp: + self.network.load_state_dict(cp['state_dict']) + + print('Loaded the pretrained PAD model') + + self.network.net.eval() + + # image pre-processing + self.transforms= transforms.Compose([transforms.ToTensor()]) + + def __call__(self, image): + """ Extract features from an image + + Parameters + ---------- + image : 3D :py:class:`numpy.ndarray` (floats) + The multi-channel image to extract the score from. Its size must be 4x128x128; + The channels should be ordered in C D I T order (color, depth, infrared and thermal respectively) + + Returns + ------- + output : float + The extracted feature is a scalar values ~0 for bonafide and ~1 for PAs + + """ + + input_image = np.rollaxis(np.rollaxis(image, 2),2) + input_image = self.transforms(input_image) + input_image = input_image.unsqueeze(0) + + #print("input_image",input_image.shape) + + output = self.network.forward(Variable(input_image)) + output = output.data.numpy().flatten() + + # output is a scalar score + + return output diff --git a/bob/ip/pytorch_extractor/__init__.py b/bob/ip/pytorch_extractor/__init__.py index d3b6b9c..3bbb43c 100755 --- a/bob/ip/pytorch_extractor/__init__.py +++ b/bob/ip/pytorch_extractor/__init__.py @@ -1,5 +1,6 @@ from .CNN8 import CNN8Extractor from .CasiaNet import CasiaNetExtractor +from .MCCNN import MCCNNExtractor # gets sphinx autodoc done right - don't remove it def __appropriate__(*args): @@ -19,6 +20,7 @@ def __appropriate__(*args): __appropriate__( CNN8Extractor, CasiaNetExtractor, + MCCNNExtractor, ) # gets sphinx autodoc done right - don't remove it diff --git a/bob/ip/pytorch_extractor/test.py b/bob/ip/pytorch_extractor/test.py index 6411890..95a6823 100644 --- a/bob/ip/pytorch_extractor/test.py +++ b/bob/ip/pytorch_extractor/test.py @@ -32,3 +32,17 @@ def test_casianet(): data = numpy.random.rand(3, 128, 128).astype("float32") output = extractor(data) assert output.shape[0] == 320 + + +def test_mccnn(): + """ test for the MCCNN architecture + + this architecture takes 4x128x128 images as input + output a single score + """ + from . import MCCNNExtractor + extractor = MCCNNExtractor() + # this architecture expects 4x128x128 Multi channel images + data = numpy.random.rand(4, 128, 128).astype("float32") + output = extractor(data) + assert output.shape[0] == 1 -- GitLab