Commit 88962aa1 authored by Anjith GEORGE's avatar Anjith GEORGE

Added a generic image extractor

parent 5878e035
Pipeline #33034 failed with stage
in 19 minutes and 56 seconds
import numpy as np
import torch
from torch.autograd import Variable
import torchvision.transforms as transforms
from import Extractor
import logging
logger = logging.getLogger("bob.learn.pytorch")
class GenericImageExtractor(Extractor):
""" The class implementing a generic image extractor
network: :py:class:`torch.nn.Module`
The network architecture
transforms: :py:mod:`torchvision.transforms`
The transform from numpy.array to torch.Tensor
def __init__(self, network, extractor_function, transforms = transforms.Compose([transforms.ToTensor()]), extractor_file=None,**kwargs):
""" Init method
`extractor_function` the function which provides the final score from the network output (TODO: May be move the network to config too)
`network` initialized architecture ; which is ready to load a pretrained model
model_file: str
The path of the trained PAD network to load
transforms: :py:mod:`torchvision.transforms`
Tranform to be applied on the image
scoring_method: str
The scoring method to be used to get the final score,
available methods are ['pixel_mean','binary','combined'].
Extractor.__init__(self, skip_extractor_training=True)
# model
self.transforms = transforms = network
def load(self, extractor_file):
"""Loads the parameters required for feature extraction from the extractor file.
This function usually is only useful in combination with the :py:meth:`train` function.
In this base class implementation, it does nothing.
extractor_file : str
The file to read the extractor from.
if extractor_file is None or '.hdf5' in extractor_file :
# do nothing (used mainly for unit testing)
logger.debug("No pretrained file provided in the config, will try loading the commandline path!")
# With the new training
logger.debug('Starting to load the pretrained PAD model')
cp = torch.load(extractor_file,map_location=lambda storage,loc:storage)
logger.debug('Loaded the model from cp.')
raise ValueError('Could not load pretrained model : {}'.format(extractor_file))
if 'state_dict' in cp:['state_dict'])
logger.debug('Loaded the model from state_dict.')
else: ## check this part
print('using the other way')
logger.debug('Loaded the model from cp directly.')
raise ValueError('Could not load the state dict : {}'.format(extractor_file))
logger.debug('Loaded the pretrained PAD model')
def __call__(self, image):
""" Extract features from an image
image : 3D :py:class:`numpy.ndarray`
The image to extract the score from. Its size must be 3x224x224;
output : float
The extracted feature is a scalar values ~1 for bonafide and ~0 for PAs
input_image = np.rollaxis(np.rollaxis(image, 2),2) # changes to 128x128xnum_channels
input_image = self.transforms(input_image)
input_image = input_image.unsqueeze(0)
output =
# output is a scalar score
return score
......@@ -5,6 +5,7 @@ from .MCCNN import MCCNNExtractor
from .MCCNNv2 import MCCNNv2Extractor
from .FASNet import FASNetExtractor
from .MCDeepPixBiS import MCDeepPixBiSExtractor
from .GenericImageExtractor import GenericImageExtractor
__all__ = [_ for _ in dir() if not _.startswith('_')]
Markdown is supported
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment