Commit 03ca71be authored by Guillaume HEUSCH's avatar Guillaume HEUSCH

Merge branch 'mc_deep_pixbis' into 'master'

Mc deep pixbis

See merge request !30
parents f55be918 6d142c3b
Pipeline #29622 canceled with stage
in 8 minutes and 50 seconds
import torch
from torch import nn
from torchvision import models
import numpy as np
class MCDeepPixBiS(nn.Module):
    """Multi-Channel Deep Pixel-wise Binary Supervision network for face
    Presentation Attack Detection.

    This extends the following paper to multi-channel / multi-spectral images
    with cross modal pretraining.

    Reference: Anjith George and Sébastien Marcel. "Deep Pixel-wise Binary
    Supervision for Face Presentation Attack Detection." In 2019 International
    Conference on Biometrics (ICB). IEEE, 2019.

    The initialization uses the `Cross modality pre-training` idea from the
    following paper:

    Wang L, Xiong Y, Wang Z, Qiao Y, Lin D, Tang X, Van Gool L. Temporal
    segment networks: Towards good practices for deep action recognition. In
    European conference on computer vision 2016 Oct 8 (pp. 20-36). Springer,
    Cham.

    Attributes
    ----------
    enc: :py:class:`torch.nn.Sequential`
        The first eight feature layers of DenseNet-161, with the first
        convolution replaced by one accepting `num_channels` input channels.
    dec: :py:class:`torch.nn.Conv2d`
        1x1 convolution mapping the 384 encoder channels to a one-channel map.
    linear: :py:class:`torch.nn.Linear`
        Fully connected layer mapping the flattened 14x14 map to one score.
    """

    def __init__(self, pretrained=True, num_channels=4):
        """ Init function

        Parameters
        ----------
        pretrained: bool
            If set to `True` uses the pretrained DenseNet model as the base.
            If set to `False`, the network will be trained from scratch.
            default: True
        num_channels: int
            Number of channels in the input.
        """
        super(MCDeepPixBiS, self).__init__()

        dense = models.densenet161(pretrained=pretrained)
        features = list(dense.features.children())

        stem = features[0]

        # Cross-modality initialisation: average the kernels of the original
        # first convolution over its input channels and give every new input
        # channel a copy of that average.
        stem_weight = stem.weight.data.detach().cpu().numpy()
        mean_weight = np.mean(stem_weight, axis=1, keepdims=True)
        new_weight = np.repeat(mean_weight, num_channels, axis=1)

        # The geometry is read from the layer being replaced rather than
        # hard-coded, so the new layer always matches the base network.
        # No bias in this architecture.
        new_stem = nn.Conv2d(
            num_channels,
            stem.out_channels,
            kernel_size=stem.kernel_size,
            stride=stem.stride,
            padding=stem.padding,
            bias=False,
        )
        new_stem.weight.data = torch.Tensor(new_weight)
        features[0] = new_stem

        self.enc = nn.Sequential(*features[0:8])
        self.dec = nn.Conv2d(384, 1, kernel_size=1, padding=0)
        self.linear = nn.Linear(14 * 14, 1)

    def forward(self, x):
        """ Propagate data through the network

        Parameters
        ----------
        x: :py:class:`torch.Tensor`
            The data to forward through the network. Expects a batch of
            multi-channel images of size num_channelsx224x224.

        Returns
        -------
        dec: :py:class:`torch.Tensor`
            Binary map of size 1x14x14 (per sample), values in [0, 1].
        op: :py:class:`torch.Tensor`
            Final binary score (one per sample), values in [0, 1].

        Raises
        ------
        RuntimeError
            If the pixel-wise map of a sample does not hold exactly 14*14
            values, i.e. the input does not have the expected spatial size.
        """
        enc = self.enc(x)

        dec = torch.sigmoid(self.dec(enc))

        # Flatten per sample. Keeping the batch dimension explicit makes an
        # unexpected spatial size fail here instead of silently folding the
        # extra positions into the batch axis.
        dec_flat = dec.view(dec.size(0), 14 * 14)

        op = torch.sigmoid(self.linear(dec_flat))

        return dec, op
...@@ -8,6 +8,7 @@ from .MCCNNv2 import MCCNNv2 ...@@ -8,6 +8,7 @@ from .MCCNNv2 import MCCNNv2
from .FASNet import FASNet from .FASNet import FASNet
from .DeepMSPAD import DeepMSPAD from .DeepMSPAD import DeepMSPAD
from .DeepPixBiS import DeepPixBiS from .DeepPixBiS import DeepPixBiS
from .MCDeepPixBiS import MCDeepPixBiS
from .DCGAN import DCGAN_generator from .DCGAN import DCGAN_generator
from .DCGAN import DCGAN_discriminator from .DCGAN import DCGAN_discriminator
......
import numpy as np
import torch
from torch.autograd import Variable
import torchvision.transforms as transforms
from bob.learn.pytorch.architectures import MCDeepPixBiS
from bob.bio.base.extractor import Extractor
import logging
logger = logging.getLogger("bob.learn.pytorch")
class MCDeepPixBiSExtractor(Extractor):
    """ The class implementing the score computation for MCDeepPixBiS architecture.

    Attributes
    ----------
    network: :py:class:`torch.nn.Module`
        The network architecture
    transforms: :py:mod:`torchvision.transforms`
        The transform from numpy.array to torch.Tensor
    num_channels: int
        Number of channels the network is built for.
    scoring_method: str
        The scoring method to be used to get the final score,
        available methods are ['pixel_mean','binary','combined'].
    available_scoring_methods: list of str
        The names of the supported scoring methods.
    """

    def __init__(
        self,
        transforms=transforms.Compose(
            [transforms.ToTensor(), transforms.Normalize(mean=[0.5], std=[0.225])]
        ),
        model_file=None,
        num_channels=8,
        scoring_method='pixel_mean',
    ):
        """ Init method

        Parameters
        ----------
        transforms: :py:mod:`torchvision.transforms`
            Transform to be applied on the image
        model_file: str
            The path of the trained PAD network to load. If `None`, the
            network keeps its initial weights (used mainly for unit testing).
        num_channels: int
            Number of channels in the input images.
        scoring_method: str
            The scoring method to be used to get the final score,
            available methods are ['pixel_mean','binary','combined'].

        Raises
        ------
        ValueError
            If `model_file` cannot be loaded, or if the loaded checkpoint
            has no 'state_dict' entry.
        """
        Extractor.__init__(self, skip_extractor_training=True)

        # model
        self.transforms = transforms
        self.scoring_method = scoring_method
        self.num_channels = num_channels
        self.network = MCDeepPixBiS(pretrained=False, num_channels=self.num_channels)
        self.available_scoring_methods = ['pixel_mean', 'binary', 'combined']

        logger.debug('Scoring method is : {}'.format(self.scoring_method.upper()))

        if model_file is None:
            # do nothing (used mainly for unit testing)
            logger.debug("No pretrained file provided")
        else:
            logger.debug('Starting to load the pretrained PAD model')

            try:
                # Scoring runs on the CPU (outputs are converted with
                # `.numpy()`), so map the stored tensors to the CPU; this
                # lets checkpoints saved on a GPU load on CPU-only hosts.
                cp = torch.load(model_file, map_location='cpu')
            except Exception:
                # Only ordinary errors are translated; KeyboardInterrupt and
                # SystemExit propagate untouched.
                raise ValueError('Failed to load the model file : {}'.format(model_file))

            if 'state_dict' not in cp:
                raise ValueError('Failed to load the state_dict for model file: {}'.format(model_file))

            self.network.load_state_dict(cp['state_dict'])
            logger.debug('Loaded the pretrained PAD model')

        # Always score in evaluation mode.
        self.network.eval()

    def __call__(self, image):
        """ Extract features from an image

        Parameters
        ----------
        image : 3D :py:class:`numpy.ndarray`
            The image to extract the score from. Its size must be
            num_channelsx224x224;

        Returns
        -------
        output : :py:class:`numpy.ndarray`
            Array of shape (1, 1) holding the score, ~1 for bonafide and
            ~0 for PAs.

        Raises
        ------
        ValueError
            If `scoring_method` is not one of the implemented methods.
        """
        # Reject an unknown scoring method before paying for a forward pass.
        if self.scoring_method not in ('pixel_mean', 'binary', 'combined'):
            raise ValueError('Scoring method {} is not implemented.'.format(self.scoring_method))

        # Move the channel axis last (224x224xnum_channels), which is the
        # layout handed to the transform.
        input_image = np.transpose(image, (1, 2, 0))
        input_image = self.transforms(input_image)
        input_image = input_image.unsqueeze(0)  # add the batch dimension

        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            output = self.network.forward(input_image)

        output_pixel = output[0].data.numpy().flatten()
        output_binary = output[1].data.numpy().flatten()

        if self.scoring_method == 'pixel_mean':
            score = np.mean(output_pixel)
        elif self.scoring_method == 'binary':
            score = np.mean(output_binary)
        else:  # 'combined'
            score = (np.mean(output_pixel) + np.mean(output_binary)) / 2.0

        # output is a scalar score, returned as a (1, 1) array
        return np.reshape(score, (1, -1))
...@@ -4,6 +4,7 @@ from .LightCNN29v2 import LightCNN29v2Extractor ...@@ -4,6 +4,7 @@ from .LightCNN29v2 import LightCNN29v2Extractor
from .MCCNN import MCCNNExtractor from .MCCNN import MCCNNExtractor
from .MCCNNv2 import MCCNNv2Extractor from .MCCNNv2 import MCCNNv2Extractor
from .FASNet import FASNetExtractor from .FASNet import FASNetExtractor
from .MCDeepPixBiS import MCDeepPixBiSExtractor
__all__ = [_ for _ in dir() if not _.startswith('_')] __all__ = [_ for _ in dir() if not _.startswith('_')]
...@@ -95,6 +95,15 @@ def test_architectures(): ...@@ -95,6 +95,15 @@ def test_architectures():
assert output[0].shape == torch.Size([1, 1, 14, 14]) assert output[0].shape == torch.Size([1, 1, 14, 14])
assert output[1].shape == torch.Size([1, 1]) assert output[1].shape == torch.Size([1, 1])
#MCDeepPixBiS
a = numpy.random.rand(1, 8, 224, 224).astype("float32")
t = torch.from_numpy(a)
from ..architectures import MCDeepPixBiS
net = MCDeepPixBiS(pretrained=False,num_channels=8)
output = net.forward(t)
assert output[0].shape == torch.Size([1, 1, 14, 14])
assert output[1].shape == torch.Size([1, 1])
# DCGAN # DCGAN
d = numpy.random.rand(1, 3, 64, 64).astype("float32") d = numpy.random.rand(1, 3, 64, 64).astype("float32")
t = torch.from_numpy(d) t = torch.from_numpy(d)
...@@ -437,3 +446,11 @@ def test_extractors(): ...@@ -437,3 +446,11 @@ def test_extractors():
data = numpy.random.rand(3, 224, 224).astype("uint8") data = numpy.random.rand(3, 224, 224).astype("uint8")
output = extractor(data) output = extractor(data)
assert output.shape[0] == 1 assert output.shape[0] == 1
# MCDeepPixBiS
from ..extractor.image import MCDeepPixBiSExtractor
extractor = MCDeepPixBiSExtractor(num_channels=8, scoring_method='pixel_mean')
# this architecture expects multi-channel images of size num_channelsx224x224
data = numpy.random.rand(8, 224, 224).astype("uint8")
output = extractor(data)
assert output.shape[0] == 1
\ No newline at end of file
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment