Skip to content
Snippets Groups Projects
Commit 23a4b535 authored by Anjith GEORGE's avatar Anjith GEORGE Committed by Anjith GEORGE
Browse files

Added DeepPixBis architecture for PAD

parent 2c7b52d3
No related branches found
No related tags found
1 merge request!24Deep mspad
Pipeline #28113 passed
import torch
from torch import nn
from torchvision import models
class DeepPixBiS(nn.Module):
""" The class defining Deep Pixelwise Binary Supervision for Face Presentation
Attack Detection:
Reference: Anjith George and Sébastien Marcel. "Deep Pixel-wise Binary Supervision for
Face Presentation Attack Detection." In 2019 International Conference on Biometrics (ICB).IEEE, 2019.
Attributes
----------
pretrained: bool
If set to `True` uses the pretrained DenseNet model as the base. If set to `False`, the network
will be trained from scratch.
default: True
"""
def __init__(self, pretrained=True):
""" Init function
Parameters
----------
pretrained: bool
If set to `True` uses the pretrained densenet model as the base. Else, it uses the default network
default: True
"""
super(DeepPixBiS, self).__init__()
dense = models.densenet161(pretrained=pretrained)
features = list(dense.features.children())
self.enc = nn.Sequential(*features[0:8])
self.dec=nn.Conv2d(384, 1, kernel_size=1, padding=0)
self.linear=nn.Linear(14*14,1)
def forward(self, x):
""" Propagate data through the network
Parameters
----------
img: :py:class:`torch.Tensor`
The data to forward through the network. Expects RGB image of size 3x224x224
Returns
-------
dec: :py:class:`torch.Tensor`
Binary map of size 1x14x14
op: :py:class:`torch.Tensor`
Final binary score.
"""
enc = self.enc(x)
dec=self.dec(enc)
dec=nn.Sigmoid()(dec)
dec_flat=dec.view(-1,14*14)
op=self.linear(dec_flat)
op=nn.Sigmoid()(op)
return dec,op
......@@ -7,6 +7,7 @@ from .MCCNN import MCCNN
from .MCCNNv2 import MCCNNv2
from .FASNet import FASNet
from .DeepMSPAD import DeepMSPAD
from .DeepPixBiS import DeepPixBiS
from .DCGAN import DCGAN_generator
from .DCGAN import DCGAN_discriminator
......
......@@ -92,6 +92,15 @@ def test_architectures():
output = net.forward(t)
assert output.shape == torch.Size([1, 1])
#DeepPixBiS
a = numpy.random.rand(1, 3, 224, 224).astype("float32")
t = torch.from_numpy(a)
from ..architectures import DeepPixBiS
net = DeepPixBiS(pretrained=True)
output = net.forward(t)
assert output[0].shape == torch.Size([1, 1, 14, 14])
assert output[1].shape == torch.Size([1, 1])
# DCGAN
d = numpy.random.rand(1, 3, 64, 64).astype("float32")
t = torch.from_numpy(d)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment