Skip to content
Snippets Groups Projects
Commit e730e352 authored by Anjith GEORGE's avatar Anjith GEORGE
Browse files

Added extractors and unit tests for MCCNN

parent 1d21a5fc
No related branches found
No related tags found
1 merge request!18MCCNN trainer
import numpy
import torch
from torch.autograd import Variable
import torchvision.transforms as transforms
from bob.learn.pytorch.architectures import MCCNN
from bob.bio.base.extractor import Extractor
import logging
logger = logging.getLogger("bob.learn.pytorch")
class MCCNNExtractor(Extractor):
""" The class implementing the MC-CNN score computation.
Attributes
----------
network: :py:class:`torch.nn.Module`
The network architecture
transforms: :py:mod:`torchvision.transforms`
The transform from numpy.array to torch.Tensor
"""
def __init__(self, num_channels=4, transforms = transforms.Compose([transforms.ToTensor()]), model_file=None):
""" Init method
Parameters
----------
num_channels: int
The number of channels present in the input
model_file: str
The path of the trained PAD network to load
transforms: :py:mod:`torchvision.transforms`
tranform to be applied on the image
"""
Extractor.__init__(self, skip_extractor_training=True)
# model
self.transforms = transforms
self.network = MCCNN(num_channels=num_channels)
#self.network=self.network.to(device)
if model_file is None:
# do nothing (used mainly for unit testing)
logger.info("No pretrained file provided")
pass
else:
# With the new training
cp = torch.load(model_file)
if 'state_dict' in cp:
self.network.load_state_dict(cp['state_dict'])
logger.info('Loaded the pretrained PAD model')
self.network.eval()
def __call__(self, image):
""" Extract features from an image
Parameters
----------
image : 3D :py:class:`numpy.ndarray` (floats)
The multi-channel image to extract the score from. Its size must be num_channelsx128x128;
Returns
-------
output : float
The extracted feature is a scalar values ~1 for bonafide and ~0 for PAs
"""
input_image = np.rollaxis(np.rollaxis(image, 2),2) # changes to 128x128xnum_channels
input_image = self.transforms(input_image)
input_image = input_image.unsqueeze(0)
output = self.network.forward(Variable(input_image))
output = output.data.numpy().flatten()
# output is a scalar score
return output
from .LightCNN9 import LightCNN9Extractor from .LightCNN9 import LightCNN9Extractor
from .LightCNN29 import LightCNN29Extractor from .LightCNN29 import LightCNN29Extractor
from .LightCNN29v2 import LightCNN29v2Extractor from .LightCNN29v2 import LightCNN29v2Extractor
from .MCCNN import MCCNNExtractor
__all__ = [_ for _ in dir() if not _.startswith('_')] __all__ = [_ for _ in dir() if not _.startswith('_')]
...@@ -321,6 +321,15 @@ def test_extractors(): ...@@ -321,6 +321,15 @@ def test_extractors():
output = extractor(data) output = extractor(data)
assert output.shape[0] == 256 assert output.shape[0] == 256
# MCCNN
from . import MCCNNExtractor
extractor = MCCNNExtractor(num_channels=4)
# this architecture expects num_channelsx128x128 Multi channel images
data = numpy.random.rand(4, 128, 128).astype("float32")
output = extractor(data)
assert output.shape[0] == 1
def test_two_layer_mlp(): def test_two_layer_mlp():
""" """
Test the TwoLayerMLP class. Test the TwoLayerMLP class.
...@@ -339,5 +348,5 @@ def test_two_layer_mlp(): ...@@ -339,5 +348,5 @@ def test_two_layer_mlp():
model.apply_sigmoid = False model.apply_sigmoid = False
output = model(batch) output = model(batch)
assert list(output.shape) == [10, 1] assert list(output.shape) == [10, 1]
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment