class MCDeepPixBiS(nn.Module):
    """Multi-Channel Deep Pixel-wise Binary Supervision network for face PAD.

    This extends the following paper to multi-channel / multi-spectral images
    with cross-modal pretraining.

    Reference: Anjith George and Sébastien Marcel. "Deep Pixel-wise Binary
    Supervision for Face Presentation Attack Detection." In 2019 International
    Conference on Biometrics (ICB). IEEE, 2019.

    The initialization uses the `cross modality pre-training` idea from:

    Wang L, Xiong Y, Wang Z, Qiao Y, Lin D, Tang X, Van Gool L. Temporal
    segment networks: Towards good practices for deep action recognition.
    In European Conference on Computer Vision 2016 Oct 8 (pp. 20-36).
    Springer, Cham.

    Attributes
    ----------
    enc: :py:class:`torch.nn.Sequential`
        The first eight feature layers of DenseNet-161, with the input
        convolution replaced by one accepting ``num_channels`` channels.
    dec: :py:class:`torch.nn.Conv2d`
        1x1 convolution mapping the 384 encoder channels to a single map.
    linear: :py:class:`torch.nn.Linear`
        Maps the flattened 14x14 map to one score.
    """

    def __init__(self, pretrained=True, num_channels=4):
        """ Init function

        Parameters
        ----------
        pretrained: bool
            Forwarded to ``torchvision.models.densenet161``. If `True` the
            pretrained DenseNet weights are used as the base; if `False` the
            default (random) initialization is used.
            default: True
        num_channels: int
            Number of channels in the input.

        Raises
        ------
        ValueError
            If ``num_channels`` is smaller than 1.
        """
        super(MCDeepPixBiS, self).__init__()

        if num_channels < 1:
            raise ValueError(
                'num_channels must be >= 1, got {}'.format(num_channels))

        dense = models.densenet161(pretrained=pretrained)
        features = list(dense.features.children())

        # Cross-modality initialization: swap the RGB input convolution for a
        # num_channels one whose filters are the channel-average of the
        # original filters.
        features[0] = self._expand_input_conv(features[0], num_channels)

        self.enc = nn.Sequential(*features[0:8])
        self.dec = nn.Conv2d(384, 1, kernel_size=1, padding=0)
        self.linear = nn.Linear(14 * 14, 1)

    @staticmethod
    def _expand_input_conv(conv, num_channels):
        """Build a ``num_channels``-input copy of ``conv`` with averaged filters.

        Parameters
        ----------
        conv: :py:class:`torch.nn.Conv2d`
            The convolution whose filters are averaged over its input channels.
        num_channels: int
            Number of input channels of the returned layer.

        Returns
        -------
        :py:class:`torch.nn.Conv2d`
            A layer with the same kernel size, stride, padding and number of
            filters as ``conv``; every input channel holds the mean (over the
            input channels of ``conv``) of the corresponding original filter.
        """
        has_bias = conv.bias is not None
        new_conv = nn.Conv2d(num_channels, conv.out_channels,
                             kernel_size=conv.kernel_size,
                             stride=conv.stride,
                             padding=conv.padding,
                             bias=has_bias)

        # Copy in place under no_grad so the parameters stay trainable leaves.
        with torch.no_grad():
            mean_weight = conv.weight.detach().mean(dim=1, keepdim=True)
            new_conv.weight.copy_(mean_weight.expand_as(new_conv.weight))
            if has_bias:
                new_conv.bias.copy_(conv.bias.detach())

        return new_conv

    def forward(self, x):
        """ Propagate data through the network

        Parameters
        ----------
        x: :py:class:`torch.Tensor`
            The data to forward through the network. Expects a batch of
            multi-channel images of size num_channelsx224x224.

        Returns
        -------
        dec: :py:class:`torch.Tensor`
            Binary map of size 1x14x14 (per sample)
        op: :py:class:`torch.Tensor`
            Final binary score (one value per sample).
        """
        enc = self.enc(x)

        dec = torch.sigmoid(self.dec(enc))

        # Flatten per sample. Keeping the batch dimension explicit makes an
        # unexpected map size fail in the linear layer instead of silently
        # re-partitioning the batch.
        dec_flat = dec.view(dec.size(0), -1)

        op = torch.sigmoid(self.linear(dec_flat))

        return dec, op
class MCDeepPixBiSExtractor(Extractor):
    """ The class implementing the score computation for MCDeepPixBiS architecture.

    Attributes
    ----------
    network: :py:class:`torch.nn.Module`
        The network architecture
    transforms: :py:mod:`torchvision.transforms`
        The transform from numpy.array to torch.Tensor
    num_channels: int
        Number of channels of the images fed to the network.
    scoring_method: str
        The scoring method to be used to get the final score,
        available methods are ['pixel_mean','binary','combined'].
    available_scoring_methods: list
        The names accepted for ``scoring_method``.

    """

    def __init__(self,
                 transforms=transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean=[0.5], std=[0.225])]),
                 model_file=None,
                 num_channels=8,
                 scoring_method='pixel_mean'):
        """ Init method

        Parameters
        ----------
        transforms: :py:mod:`torchvision.transforms`
            Transform to be applied on the image.
            NOTE(review): the default normalizes with a single mean/std value;
            how that is applied across the channels depends on the installed
            torchvision version -- confirm it matches training.
        model_file: str
            The path of the trained PAD network to load. If `None`, the
            network keeps its initial weights (used mainly for unit testing).
        num_channels: int
            Number of channels in the input image.
        scoring_method: str
            The scoring method to be used to get the final score,
            available methods are ['pixel_mean','binary','combined'].

        Raises
        ------
        ValueError
            If ``scoring_method`` is not one of the available methods, if the
            model file cannot be loaded, or if it has no 'state_dict' entry.

        """

        Extractor.__init__(self, skip_extractor_training=True)

        self.available_scoring_methods = ['pixel_mean', 'binary', 'combined']

        # Fail at construction time rather than on the first scored image.
        if scoring_method not in self.available_scoring_methods:
            raise ValueError('Scoring method {} is not implemented.'.format(scoring_method))

        # model
        self.transforms = transforms
        self.scoring_method = scoring_method
        self.num_channels = num_channels
        self.network = MCDeepPixBiS(pretrained=False, num_channels=self.num_channels)

        logger.debug('Scoring method is : {}'.format(self.scoring_method.upper()))

        if model_file is None:
            # do nothing (used mainly for unit testing)
            logger.debug("No pretrained file provided")
        else:

            # With the new training
            logger.debug('Starting to load the pretrained PAD model')

            try:
                # Scoring runs on the CPU, so map the stored tensors there;
                # this also allows loading checkpoints saved from a GPU.
                cp = torch.load(model_file, map_location='cpu')
            except Exception as e:
                raise ValueError('Failed to load the model file : {}'.format(model_file)) from e

            if 'state_dict' in cp:
                self.network.load_state_dict(cp['state_dict'])
            else:
                raise ValueError('Failed to load the state_dict for model file: {}'.format(model_file))

            logger.debug('Loaded the pretrained PAD model')

        # The extractor is only used for inference.
        self.network.eval()

    def __call__(self, image):
        """ Extract features from an image

        Parameters
        ----------
        image : 3D :py:class:`numpy.ndarray`
            The image to extract the score from. Its size must be num_channelsx224x224;

        Returns
        -------
        output : :py:class:`numpy.ndarray`
            Array of shape (1, 1) holding the score, a scalar value ~1 for
            bonafide and ~0 for PAs

        Raises
        ------
        ValueError
            If ``self.scoring_method`` is not an implemented method.

        """

        # channels-first -> channels-last: changes to 224x224xnum_channels
        input_image = np.transpose(image, (1, 2, 0))
        input_image = self.transforms(input_image)
        input_image = input_image.unsqueeze(0)

        # No gradients are needed for scoring.
        with torch.no_grad():
            output = self.network(input_image)

        output_pixel = output[0].detach().numpy().flatten()
        output_binary = output[1].detach().numpy().flatten()

        if self.scoring_method == 'pixel_mean':
            score = np.mean(output_pixel)
        elif self.scoring_method == 'binary':
            score = np.mean(output_binary)
        elif self.scoring_method == 'combined':
            score = (np.mean(output_pixel) + np.mean(output_binary)) / 2.0
        else:
            raise ValueError('Scoring method {} is not implemented.'.format(self.scoring_method))

        # output is a scalar score

        return np.reshape(score, (1, -1))
MCDeepPixBiS + net = MCDeepPixBiS(pretrained=False,num_channels=8) + output = net.forward(t) + assert output[0].shape == torch.Size([1, 1, 14, 14]) + assert output[1].shape == torch.Size([1, 1]) + # DCGAN d = numpy.random.rand(1, 3, 64, 64).astype("float32") t = torch.from_numpy(d) @@ -437,3 +446,11 @@ def test_extractors(): data = numpy.random.rand(3, 224, 224).astype("uint8") output = extractor(data) assert output.shape[0] == 1 + + # MCDeepPixBiS + from ..extractor.image import MCDeepPixBiSExtractor + extractor = MCDeepPixBiSExtractor(num_channels=8, scoring_method='pixel_mean') + # this architecture expects multi-channel images of size num_channelsx224x224 + data = numpy.random.rand(8, 224, 224).astype("uint8") + output = extractor(data) + assert output.shape[0] == 1 \ No newline at end of file