Commit 97d1c70a authored by Guillaume HEUSCH

Merge branch 'DeepMSPAD' into 'master'

Deep mspad

See merge request !24
parents 5335c905 23a4b535
Pipeline #28137 passed with stages in 32 minutes and 52 seconds
include: 'https://gitlab.idiap.ch/bob/bob.devtools/raw/master/bob/devtools/data/gitlab-ci/single-package.yaml'
import torch
from torch import nn
from torchvision import models

import numpy as np


class DeepMSPAD(nn.Module):
    """ Deep multispectral PAD algorithm

    The initialization uses the `Cross modality pre-training` idea from the following paper:

    Wang L, Xiong Y, Wang Z, Qiao Y, Lin D, Tang X, Van Gool L. Temporal segment networks:
    Towards good practices for deep action recognition. In European conference on computer
    vision 2016 Oct 8 (pp. 20-36). Springer, Cham.

    Attributes
    ----------
    pretrained: bool
        If set to `True`, loads the pretrained VGG16 model.
    relu: :py:class:`torch.nn.Module`
        ReLU activation
    enc: :py:class:`torch.nn.Module`
        Layers used for feature extraction (the VGG16 convolutional part)
    linear1: :py:class:`torch.nn.Module`
        Fully connected layer
    linear2: :py:class:`torch.nn.Module`
        Fully connected layer
    dropout: :py:class:`torch.nn.Module`
        Dropout layer
    sigmoid: :py:class:`torch.nn.Module`
        Sigmoid activation
    """

    def __init__(self, pretrained=True, num_channels=4):
        """ Init method

        Parameters
        ----------
        pretrained: bool
            If set to `True`, loads the pretrained VGG16 model.
        num_channels: int
            Number of channels in the input
        """
        super(DeepMSPAD, self).__init__()

        vgg = models.vgg16(pretrained=pretrained)
        features = list(vgg.features.children())

        # temporary layer to extract the pretrained weights of the first convolution
        temp_layer = features[0]

        # Implements `Cross modality pre-training`:
        # the new first layer is initialized with the mean (over the RGB channels)
        # of the pretrained weights, replicated for each input channel
        bias_values = temp_layer.bias.data.detach().numpy()
        mean_weight = np.mean(temp_layer.weight.data.detach().numpy(), axis=1)  # for 64 filters

        new_weight = np.zeros((64, num_channels, 3, 3))
        for i in range(num_channels):
            new_weight[:, i, :, :] = mean_weight

        # initialize a new first layer with the required number of channels `num_channels`
        features[0] = nn.Conv2d(num_channels, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        features[0].weight.data = torch.Tensor(new_weight)
        features[0].bias.data = torch.Tensor(bias_values)

        self.enc = nn.Sequential(*features)

        self.linear1 = nn.Linear(25088, 256)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(p=0.5)
        self.linear2 = nn.Linear(256, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """ Propagate data through the network

        Parameters
        ----------
        x: :py:class:`torch.Tensor`
            The data to forward through the network

        Returns
        -------
        x: :py:class:`torch.Tensor`
            The output of the last layer of the network
        """
        enc = self.enc(x)
        x = enc.view(-1, 25088)
        x = self.linear1(x)
        x = self.relu(x)
        x = self.dropout(x)
        x = self.linear2(x)
        x = self.sigmoid(x)
        return x
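A minimal usage sketch for `DeepMSPAD` (not part of this merge request), assuming the class is importable as in the test further below; `pretrained=False` is used here only to avoid downloading the VGG16 weights, and the 4-channel 224x224 batch is a dummy example:

# Illustrative sketch only: dummy forward pass through DeepMSPAD.
# Assumes `DeepMSPAD` is importable from the architectures module above.
import torch

net = DeepMSPAD(pretrained=False, num_channels=4)  # random init, 4 input channels
batch = torch.rand(2, 4, 224, 224)                 # dummy batch of two 4-channel 224x224 images
scores = net(batch)                                # sigmoid outputs in [0, 1]
assert scores.shape == torch.Size([2, 1])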
import torch
from torch import nn
from torchvision import models


class DeepPixBiS(nn.Module):
    """ The class defining Deep Pixel-wise Binary Supervision for Face Presentation
    Attack Detection.

    Reference: Anjith George and Sébastien Marcel. "Deep Pixel-wise Binary Supervision for
    Face Presentation Attack Detection." In 2019 International Conference on Biometrics (ICB). IEEE, 2019.

    Attributes
    ----------
    pretrained: bool
        If set to `True`, uses the pretrained DenseNet model as the base. If set to `False`, the network
        will be trained from scratch.
        default: True
    """

    def __init__(self, pretrained=True):
        """ Init function

        Parameters
        ----------
        pretrained: bool
            If set to `True`, uses the pretrained DenseNet model as the base. Else, the network
            is initialized with random weights.
            default: True
        """
        super(DeepPixBiS, self).__init__()

        dense = models.densenet161(pretrained=pretrained)
        features = list(dense.features.children())

        self.enc = nn.Sequential(*features[0:8])
        self.dec = nn.Conv2d(384, 1, kernel_size=1, padding=0)
        self.linear = nn.Linear(14 * 14, 1)

    def forward(self, x):
        """ Propagate data through the network

        Parameters
        ----------
        x: :py:class:`torch.Tensor`
            The data to forward through the network. Expects an RGB image of size 3x224x224.

        Returns
        -------
        dec: :py:class:`torch.Tensor`
            Binary map of size 1x14x14
        op: :py:class:`torch.Tensor`
            Final binary score.
        """
        enc = self.enc(x)

        dec = self.dec(enc)
        dec = nn.Sigmoid()(dec)

        dec_flat = dec.view(-1, 14 * 14)
        op = self.linear(dec_flat)
        op = nn.Sigmoid()(op)

        return dec, op
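Below is a minimal, illustrative loss sketch for `DeepPixBiS` (not part of this merge request): the network returns a 14x14 pixel-wise map and a scalar score, and a binary cross-entropy term on each can be combined. The equal 0.5/0.5 weighting and the label convention (1 = bona fide, 0 = attack) are assumptions made here for illustration.

# Illustrative sketch only: combining BCE on the pixel-wise map and on the score.
# Assumes `DeepPixBiS` is importable; the weighting and label convention are assumed.
import torch
from torch import nn

net = DeepPixBiS(pretrained=False)   # random init, avoids downloading DenseNet weights
img = torch.rand(2, 3, 224, 224)     # dummy RGB batch
pixel_map, score = net(img)          # shapes: (2, 1, 14, 14) and (2, 1)

labels = torch.ones(2, 1)            # assumed convention: 1 = bona fide, 0 = attack
bce = nn.BCELoss()
loss_pixel = bce(pixel_map, labels.view(-1, 1, 1, 1).expand_as(pixel_map))
loss_score = bce(score, labels)
loss = 0.5 * loss_pixel + 0.5 * loss_score  # equal weighting chosen only for illustration
loss.backward()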
@@ -6,6 +6,8 @@ from .LightCNN import LightCNN29v2
from .MCCNN import MCCNN
from .MCCNNv2 import MCCNNv2
from .FASNet import FASNet
from .DeepMSPAD import DeepMSPAD
from .DeepPixBiS import DeepPixBiS
from .DCGAN import DCGAN_generator
from .DCGAN import DCGAN_discriminator
@@ -84,6 +84,23 @@ def test_architectures():
    output = net.forward(t)
    assert output.shape == torch.Size([1, 1])

    # DeepMSPAD
    a = numpy.random.rand(1, 8, 224, 224).astype("float32")
    t = torch.from_numpy(a)
    from ..architectures import DeepMSPAD
    net = DeepMSPAD(pretrained=False, num_channels=8)
    output = net.forward(t)
    assert output.shape == torch.Size([1, 1])

    # DeepPixBiS
    a = numpy.random.rand(1, 3, 224, 224).astype("float32")
    t = torch.from_numpy(a)
    from ..architectures import DeepPixBiS
    net = DeepPixBiS(pretrained=True)
    output = net.forward(t)
    assert output[0].shape == torch.Size([1, 1, 14, 14])
    assert output[1].shape == torch.Size([1, 1])

    # DCGAN
    d = numpy.random.rand(1, 3, 64, 64).astype("float32")
    t = torch.from_numpy(d)