Commit 360ebedb authored by Anjith GEORGE's avatar Anjith GEORGE

Merge branch 'deeppixbis-extractor' into 'master'

Deeppixbis extractor

See merge request !40
parents 187b3076 887d9e54
Pipeline #36697 passed with stages
in 9 minutes and 18 seconds
import numpy as np
import torch
from torch.autograd import Variable
import torchvision.transforms as transforms
from bob.learn.pytorch.architectures import DeepPixBiS
from import Extractor
import logging
logger = logging.getLogger("bob.learn.pytorch")
class DeepPixBiSExtractor(Extractor):
""" The class implementing the DeepPixBiS score computation.
network: :py:class:`torch.nn.Module`
The network architecture
transforms: :py:mod:`torchvision.transforms`
The transform from numpy.array to torch.Tensor
def __init__(self, transforms = transforms.Compose([transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])]), model_file=None, scoring_method='pixel_mean'):
""" Init method
model_file: str
The path of the trained PAD network to load
transforms: :py:mod:`torchvision.transforms`
Tranform to be applied on the image
scoring_method: str
The scoring method to be used to get the final score,
available methods are ['pixel_mean','binary','combined'].
Extractor.__init__(self, skip_extractor_training=True)
# model
self.transforms = transforms = DeepPixBiS(pretrained=True)
self.scoring_method = scoring_method
logger.debug('Scoring method is : {}'.format(self.scoring_method.upper()))
if model_file is None:
# do nothing (used mainly for unit testing)
logger.debug("No pretrained file provided")
logger.debug('Starting to load the pretrained PAD model')
cp = torch.load(model_file)
raise ValueError('Failed to load the model file : {}'.format(model_file))
if 'state_dict' in cp:['state_dict'])
raise ValueError('Failed to load the state_dict for model file: {}'.format(model_file))
logger.debug('Loaded the pretrained PAD model')
def __call__(self, image):
""" Extract features from an image
image : 3D :py:class:`numpy.ndarray`
The image to extract the score from. Its size must be 3x224x224;
output : float
The extracted feature is a scalar values ~1 for bonafide and ~0 for PAs
input_image = np.rollaxis(np.rollaxis(image, 2),2) # changes from CxHxW to HxWxC
input_image = self.transforms(input_image)
input_image = input_image.unsqueeze(0)
output =
output_pixel = output[0].data.numpy().flatten()
output_binary = output[1].data.numpy().flatten()
if self.scoring_method=='pixel_mean':
elif self.scoring_method=='binary':
elif self.scoring_method=='combined':
score= (np.mean(output_pixel)+np.mean(output_binary))/2.0
raise ValueError('Scoring method {} is not implemented.'.format(self.scoring_method))
# output is a scalar score
return np.reshape(score,(1,-1))
...@@ -5,6 +5,7 @@ from .MCCNN import MCCNNExtractor ...@@ -5,6 +5,7 @@ from .MCCNN import MCCNNExtractor
from .MCCNNv2 import MCCNNv2Extractor from .MCCNNv2 import MCCNNv2Extractor
from .FASNet import FASNetExtractor from .FASNet import FASNetExtractor
from .MCDeepPixBiS import MCDeepPixBiSExtractor from .MCDeepPixBiS import MCDeepPixBiSExtractor
from .DeepPixBiS import DeepPixBiSExtractor
__all__ = [_ for _ in dir() if not _.startswith('_')] __all__ = [_ for _ in dir() if not _.startswith('_')]
...@@ -535,6 +535,14 @@ def test_extractors(): ...@@ -535,6 +535,14 @@ def test_extractors():
output = extractor(data) output = extractor(data)
assert output.shape[0] == 1 assert output.shape[0] == 1
# DeepPixBiS
from ..extractor.image import DeepPixBiSExtractor
extractor = DeepPixBiSExtractor(scoring_method='pixel_mean')
# this architecture expects color images of size 3x224x224
data = numpy.random.rand(3, 224, 224).astype("uint8")
output = extractor(data)
assert output.shape[0] == 1
# MCDeepPixBiS # MCDeepPixBiS
from ..extractor.image import MCDeepPixBiSExtractor from ..extractor.image import MCDeepPixBiSExtractor
extractor = MCDeepPixBiSExtractor( extractor = MCDeepPixBiSExtractor(
Markdown is supported
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment