Skip to content
Snippets Groups Projects
Commit 9a45a75a authored by Guillaume HEUSCH's avatar Guillaume HEUSCH
Browse files

added DeepPixBiS extactor and corresponding test

parent fdabcc91
No related branches found
No related tags found
1 merge request!40Deeppixbis extractor
Pipeline #34648 failed
import numpy as np
import torch
from torch.autograd import Variable
import torchvision.transforms as transforms
from bob.learn.pytorch.architectures import DeepPixBiS
from bob.bio.base.extractor import Extractor
import logging
logger = logging.getLogger("bob.learn.pytorch")
class DeepPixBiSExtractor(Extractor):
""" The class implementing the DeepPixBiS score computation.
Attributes
----------
network: :py:class:`torch.nn.Module`
The network architecture
transforms: :py:mod:`torchvision.transforms`
The transform from numpy.array to torch.Tensor
"""
def __init__(self, transforms = transforms.Compose([transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])]), model_file=None, scoring_method='pixel_mean'):
""" Init method
Parameters
----------
model_file: str
The path of the trained PAD network to load
transforms: :py:mod:`torchvision.transforms`
Tranform to be applied on the image
scoring_method: str
The scoring method to be used to get the final score,
available methods are ['pixel_mean','binary','combined'].
"""
Extractor.__init__(self, skip_extractor_training=True)
# model
self.transforms = transforms
self.network = DeepPixBiS(pretrained=True)
self.scoring_method = scoring_method
self.available_scoring_methods=['pixel_mean','binary','combined']
logger.debug('Scoring method is : {}'.format(self.scoring_method.upper()))
if model_file is None:
# do nothing (used mainly for unit testing)
logger.debug("No pretrained file provided")
pass
else:
# With the new training
logger.debug('Starting to load the pretrained PAD model')
try:
cp = torch.load(model_file)
except:
try:
cp= torch.load(model_file,map_location=lambda storage,loc:storage)
except:
raise ValueError('Could not load the model')
if 'state_dict' in cp:
self.network.load_state_dict(cp['state_dict'])
else: ## check this part
self.network.load_state_dict(cp)
logger.debug('Loaded the pretrained PAD model')
self.network.eval()
def __call__(self, image):
""" Extract features from an image
Parameters
----------
image : 3D :py:class:`numpy.ndarray`
The image to extract the score from. Its size must be 3x224x224;
Returns
-------
output : float
The extracted feature is a scalar values ~1 for bonafide and ~0 for PAs
"""
input_image = np.rollaxis(np.rollaxis(image, 2),2) # changes from CxHxW to HxWxC
input_image = self.transforms(input_image)
input_image = input_image.unsqueeze(0)
output = self.network.forward(Variable(input_image))
output_pixel = output[0].data.numpy().flatten()
output_binary = output[1].data.numpy().flatten()
if self.scoring_method=='pixel_mean':
score=np.mean(output_pixel)
elif self.scoring_method=='binary':
score=np.mean(output_binary)
elif self.scoring_method=='combined':
score= (np.mean(output_pixel)+np.mean(output_binary))/2.0
else:
raise ValueError('Scoring method {} is not implemented.'.format(self.scoring_method))
# output is a scalar score
return score
......@@ -535,6 +535,14 @@ def test_extractors():
output = extractor(data)
assert output.shape[0] == 1
# DeepPixBiS
from ..extractor.image import DeepPixBiSExtractor
extractor = DeepPixBiSExtractor(scoring_method='pixel_mean')
# this architecture expects color images of size 3x224x224
data = numpy.random.rand(3, 224, 224).astype("uint8")
output = extractor(data)
assert output.shape[0] == 1
# MCDeepPixBiS
from ..extractor.image import MCDeepPixBiSExtractor
extractor = MCDeepPixBiSExtractor(
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment