Skip to content
Snippets Groups Projects
Commit 360ebedb authored by Anjith GEORGE's avatar Anjith GEORGE
Browse files

Merge branch 'deeppixbis-extractor' into 'master'

Deeppixbis extractor

See merge request !40
parents 187b3076 887d9e54
Branches
Tags
1 merge request!40Deeppixbis extractor
Pipeline #36697 passed
import numpy as np
import torch
from torch.autograd import Variable
import torchvision.transforms as transforms
from bob.learn.pytorch.architectures import DeepPixBiS
from bob.bio.base.extractor import Extractor
import logging
logger = logging.getLogger("bob.learn.pytorch")
class DeepPixBiSExtractor(Extractor):
""" The class implementing the DeepPixBiS score computation.
Attributes
----------
network: :py:class:`torch.nn.Module`
The network architecture
transforms: :py:mod:`torchvision.transforms`
The transform from numpy.array to torch.Tensor
"""
def __init__(self, transforms = transforms.Compose([transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])]), model_file=None, scoring_method='pixel_mean'):
""" Init method
Parameters
----------
model_file: str
The path of the trained PAD network to load
transforms: :py:mod:`torchvision.transforms`
Tranform to be applied on the image
scoring_method: str
The scoring method to be used to get the final score,
available methods are ['pixel_mean','binary','combined'].
"""
Extractor.__init__(self, skip_extractor_training=True)
# model
self.transforms = transforms
self.network = DeepPixBiS(pretrained=True)
self.scoring_method = scoring_method
self.available_scoring_methods=['pixel_mean','binary','combined']
logger.debug('Scoring method is : {}'.format(self.scoring_method.upper()))
if model_file is None:
# do nothing (used mainly for unit testing)
logger.debug("No pretrained file provided")
pass
else:
logger.debug('Starting to load the pretrained PAD model')
try:
cp = torch.load(model_file)
except:
raise ValueError('Failed to load the model file : {}'.format(model_file))
if 'state_dict' in cp:
self.network.load_state_dict(cp['state_dict'])
else:
raise ValueError('Failed to load the state_dict for model file: {}'.format(model_file))
logger.debug('Loaded the pretrained PAD model')
self.network.eval()
def __call__(self, image):
""" Extract features from an image
Parameters
----------
image : 3D :py:class:`numpy.ndarray`
The image to extract the score from. Its size must be 3x224x224;
Returns
-------
output : float
The extracted feature is a scalar values ~1 for bonafide and ~0 for PAs
"""
input_image = np.rollaxis(np.rollaxis(image, 2),2) # changes from CxHxW to HxWxC
input_image = self.transforms(input_image)
input_image = input_image.unsqueeze(0)
output = self.network.forward(Variable(input_image))
output_pixel = output[0].data.numpy().flatten()
output_binary = output[1].data.numpy().flatten()
if self.scoring_method=='pixel_mean':
score=np.mean(output_pixel)
elif self.scoring_method=='binary':
score=np.mean(output_binary)
elif self.scoring_method=='combined':
score= (np.mean(output_pixel)+np.mean(output_binary))/2.0
else:
raise ValueError('Scoring method {} is not implemented.'.format(self.scoring_method))
# output is a scalar score
return np.reshape(score,(1,-1))
......@@ -5,6 +5,7 @@ from .MCCNN import MCCNNExtractor
from .MCCNNv2 import MCCNNv2Extractor
from .FASNet import FASNetExtractor
from .MCDeepPixBiS import MCDeepPixBiSExtractor
from .DeepPixBiS import DeepPixBiSExtractor
__all__ = [_ for _ in dir() if not _.startswith('_')]
......@@ -535,6 +535,14 @@ def test_extractors():
output = extractor(data)
assert output.shape[0] == 1
# DeepPixBiS
from ..extractor.image import DeepPixBiSExtractor
extractor = DeepPixBiSExtractor(scoring_method='pixel_mean')
# this architecture expects color images of size 3x224x224
data = numpy.random.rand(3, 224, 224).astype("uint8")
output = extractor(data)
assert output.shape[0] == 1
# MCDeepPixBiS
from ..extractor.image import MCDeepPixBiSExtractor
extractor = MCDeepPixBiSExtractor(
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment