Skip to content
Snippets Groups Projects
Commit d92971c6 authored by Anjith GEORGE's avatar Anjith GEORGE
Browse files

Added FASNet architecture

parent 404fa784
No related branches found
No related tags found
1 merge request!21Added FASNet architecture
Checking pipeline status
import torch
from torch import nn
from torchvision import models
class FASNet(nn.Module):
"""PyTorch Reimplementation of Lucena, Oeslle, et al. "Transfer learning using
convolutional neural networks for face anti-spoofing."
International Conference Image Analysis and Recognition. Springer, Cham, 2017.
eferenced from keras implementation: https://github.com/OeslleLucena/FASNet
Attributes:
pretrained: bool
if set `True` loads the pretrained vgg16 model.
vgg: :py:class:`torch.nn.Module`
The VGG16 model
relu: :py:class:`torch.nn.Module`
ReLU activation
enc: :py:class:`torch.nn.Module`
Uses the layers for feature extraction
linear1: :py:class:`torch.nn.Module`
Fully connected layer
linear2: :py:class:`torch.nn.Module`
Fully connected layer
dropout: :py:class:`torch.nn.Module`
Dropout layer
sigmoid: :py:class:`torch.nn.Module`
Sigmoid activation
"""
def __init__(self, pretrained=True):
""" Init method
Parameters
----------
pretrained: bool
if set `True` loads the pretrained vgg16 model.
"""
super(FASNet, self).__init__()
vgg = models.vgg16(pretrained=pretrained)
features = list(vgg.features.children())
self.enc = nn.Sequential(*features)
self.linear1=nn.Linear(25088,256)
self.relu=nn.ReLU()
self.dropout= nn.Dropout(p=0.5)
self.linear2=nn.Linear(256,1)
self.sigmoid= nn.Sigmoid()
def forward(self, x):
""" Propagate data through the network
Parameters
----------
x: :py:class:`torch.Tensor`
The data to forward through the network
Returns
-------
x: :py:class:`torch.Tensor`
The last layer of the network
"""
enc = self.enc(x)
x=enc.view(-1,25088)
x=self.linear1(x)
x=self.relu(x)
x=self.dropout(x)
x=self.linear2(x)
x=self.sigmoid(x)
return x
...@@ -5,6 +5,7 @@ from .LightCNN import LightCNN29 ...@@ -5,6 +5,7 @@ from .LightCNN import LightCNN29
from .LightCNN import LightCNN29v2 from .LightCNN import LightCNN29v2
from .MCCNN import MCCNN from .MCCNN import MCCNN
from .MCCNNv2 import MCCNNv2 from .MCCNNv2 import MCCNNv2
from .FASNet import FASNet
from .DCGAN import DCGAN_generator from .DCGAN import DCGAN_generator
from .DCGAN import DCGAN_discriminator from .DCGAN import DCGAN_discriminator
......
...@@ -76,6 +76,14 @@ def test_architectures(): ...@@ -76,6 +76,14 @@ def test_architectures():
output = net.forward(t) output = net.forward(t)
assert output.shape == torch.Size([1, 1]) assert output.shape == torch.Size([1, 1])
#FASNet
a = numpy.random.rand(1, 3, 224, 224).astype("float32")
t = torch.from_numpy(a)
from ..architectures import FASNet
net = FASNet(pretrained=False)
output = net.forward(t)
assert output.shape == torch.Size([1, 1])
# DCGAN # DCGAN
d = numpy.random.rand(1, 3, 64, 64).astype("float32") d = numpy.random.rand(1, 3, 64, 64).astype("float32")
t = torch.from_numpy(d) t = torch.from_numpy(d)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment