Commit f2da377a authored by Anjith GEORGE

Added extended version of MCCNN

parent 31bbafd2
1 merge request: !17 Mccn extended
Pipeline #26548 failed
#!/usr/bin/env python
# encoding: utf-8

import os

import numpy as np
import pkg_resources

import torch
import torch.nn as nn
import torch.nn.functional as F

import bob.extension.download
import bob.io.base

from .utils import MaxFeatureMap
from .utils import group
from .utils import resblock


class MCCNN(nn.Module):
    """The class defining the MCCNN.

    This class implements the MCCNN for multi-channel PAD.

    Attributes
    ----------
    num_channels: int
        The number of channels present in the input.
    lcnn_layers: list
        Names of the LightCNN layers replicated for each channel.
    """

    def __init__(self, block=resblock, layers=[1, 2, 3, 4], num_channels=4):
        """Init function

        Parameters
        ----------
        block: :py:class:`torch.nn.Module`
            The residual block used to build each LightCNN branch.
        layers: list
            Number of residual blocks in each of the four block groups.
        num_channels: int
            The number of channels present in the input.
        """
        super(MCCNN, self).__init__()

        self.num_channels = num_channels

        self.lcnn_layers = ['conv1', 'block1', 'group1', 'block2', 'group2',
                            'block3', 'group3', 'block4', 'group4', 'fc']

        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=True)
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=True)
        self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=True)
        self.pool4 = nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=True)

        # newly added FC layers
        self.linear1fc = nn.Linear(256 * num_channels, 10)
        self.linear2fc = nn.Linear(10, 1)

        # add modules: one LightCNN branch per input channel
        module_dict = {}

        for i in range(self.num_channels):

            m_dict = {}

            m_dict['conv1'] = MaxFeatureMap(1, 48, 5, 1, 2)
            m_dict['block1'] = self._make_layer(block, layers[0], 48, 48)
            m_dict['group1'] = group(48, 96, 3, 1, 1)
            m_dict['block2'] = self._make_layer(block, layers[1], 96, 96)
            m_dict['group2'] = group(96, 192, 3, 1, 1)
            m_dict['block3'] = self._make_layer(block, layers[2], 192, 192)
            m_dict['group3'] = group(192, 128, 3, 1, 1)
            m_dict['block4'] = self._make_layer(block, layers[3], 128, 128)
            m_dict['group4'] = group(128, 128, 3, 1, 1)
            m_dict['fc'] = MaxFeatureMap(8 * 8 * 128, 256, type=0)

            # channel 0 acts as the anchor
            for layer in self.lcnn_layers:
                layer_name = "ch_{}_".format(i) + layer
                module_dict[layer_name] = m_dict[layer]

        self.layer_dict = nn.ModuleDict(module_dict)

        # check for the pretrained LightCNN model; download it if it is not present
        light_cnn_model_file = os.path.join(MCCNN.get_mccnnpath(),
                                            "LightCNN_29Layers_checkpoint.pth.tar")

        url = 'https://www.idiap.ch/software/bob/data/bob/bob.learn.pytorch/master/LightCNN_29Layers_checkpoint.pth.tar'

        print("light_cnn_model_file", light_cnn_model_file)

        if not os.path.exists(light_cnn_model_file):

            bob.io.base.create_directories_safe(os.path.split(light_cnn_model_file)[0])

            print('Downloading the LightCNN model')

            bob.extension.download.download_file(url, light_cnn_model_file)

            print('Downloaded LightCNN model to {}'.format(light_cnn_model_file))

        # loading the pretrained model for ch_0
        self.load_state_dict(self.get_model_state_dict(light_cnn_model_file), strict=False)

        # copy over the weights to all the other channel branches
        for layer in self.lcnn_layers:
            for i in range(1, self.num_channels):  # all channels except the 0th one
                self.layer_dict["ch_{}_".format(i) + layer].load_state_dict(
                    self.layer_dict["ch_0_" + layer].state_dict())

    def _make_layer(self, block, num_blocks, in_channels, out_channels):
        """Stacks ``num_blocks`` instances of ``block`` into a sequential layer.

        Parameters
        ----------
        block: :py:class:`torch.nn.Module`
            The residual block to stack.
        num_blocks: int
            Number of blocks to stack.
        in_channels: int
            Number of input feature maps.
        out_channels: int
            Number of output feature maps.
        """
        layers = []
        for i in range(0, num_blocks):
            layers.append(block(in_channels, out_channels))
        return nn.Sequential(*layers)

    def forward(self, img):
        """Propagate data through the network

        Parameters
        ----------
        img: :py:class:`torch.Tensor`
            The data to forward through the network. Image of size num_channels x 128 x 128.

        Returns
        -------
        output: :py:class:`torch.Tensor`
            The score.
        """
        embeddings = []

        for i in range(self.num_channels):

            x = img[:, i, :, :].unsqueeze(1)  # the image for the specific channel

            x = self.layer_dict["ch_{}_".format(i) + "conv1"](x)
            x = self.pool1(x)
            x = self.layer_dict["ch_{}_".format(i) + "block1"](x)
            x = self.layer_dict["ch_{}_".format(i) + "group1"](x)
            x = self.pool2(x)
            x = self.layer_dict["ch_{}_".format(i) + "block2"](x)
            x = self.layer_dict["ch_{}_".format(i) + "group2"](x)
            x = self.pool3(x)
            x = self.layer_dict["ch_{}_".format(i) + "block3"](x)
            x = self.layer_dict["ch_{}_".format(i) + "group3"](x)
            x = self.layer_dict["ch_{}_".format(i) + "block4"](x)
            x = self.layer_dict["ch_{}_".format(i) + "group4"](x)
            x = self.pool4(x)
            x = x.view(x.size(0), -1)
            fc = self.layer_dict["ch_{}_".format(i) + "fc"](x)

            fc = F.dropout(fc, training=self.training)

            embeddings.append(fc)

        merged = torch.cat(embeddings, 1)

        output = self.linear1fc(merged)
        output = nn.Sigmoid()(output)
        output = self.linear2fc(output)

        if self.training:
            output = nn.Sigmoid()(output)

        return output

    @staticmethod
    def get_mccnnpath():
        return pkg_resources.resource_filename('bob.learn.pytorch', 'models')

    def get_model_state_dict(self, pretrained_model_path):
        """Loads the pretrained LightCNN model and renames its keys for the ch_0 branch.

        Parameters
        ----------
        pretrained_model_path: str
            Absolute path to the LightCNN model file.

        Returns
        -------
        new_state_dict: dict
            Dictionary with the LightCNN weights keyed for this model.
        """
        checkpoint = torch.load(pretrained_model_path,
                                map_location=lambda storage, loc: storage)
        start_epoch = checkpoint['epoch']
        state_dict = checkpoint['state_dict']

        # create a new OrderedDict whose keys do not contain the `module.` prefix
        from collections import OrderedDict
        new_state_dict = OrderedDict()
        for k, v in state_dict.items():
            name = 'layer_dict.ch_0_' + k[7:]  # drop `module.`, prepend `layer_dict.ch_0_`
            new_state_dict[name] = v

        # params to be loaded into the model
        return new_state_dict
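
For reference, a minimal usage sketch of the class added above (not part of the commit; it assumes bob.learn.pytorch is installed so the LightCNN checkpoint can be fetched on first construction, and mirrors the new unit test further down):

import numpy
import torch
from bob.learn.pytorch.architectures import MCCNN

# build one LightCNN branch per channel; downloads the LightCNN weights if needed
net = MCCNN(num_channels=4)
net.eval()  # inference mode: dropout disabled, no extra sigmoid on the output

# a batch of one 4-channel 128x128 image
batch = numpy.random.rand(1, 4, 128, 128).astype("float32")
score = net(torch.from_numpy(batch))
print(score.shape)  # torch.Size([1, 1])
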
@@ -3,6 +3,7 @@ from .CASIANet import CASIANet
from .LightCNN import LightCNN9
from .LightCNN import LightCNN29
from .LightCNN import LightCNN29v2
from .MCCNN import MCCNN
from .DCGAN import DCGAN_generator
from .DCGAN import DCGAN_discriminator
......
@@ -60,6 +60,14 @@ def test_architectures():

    assert output.shape == torch.Size([1, 79077])
    assert emdedding.shape == torch.Size([1, 256])

    # MCCNN
    a = numpy.random.rand(1, 4, 128, 128).astype("float32")
    t = torch.from_numpy(a)

    from ..architectures import MCCNN
    net = MCCNN(num_channels=4)

    output = net.forward(t)
    assert output.shape == torch.Size([1, 1])

    # DCGAN
    d = numpy.random.rand(1, 3, 64, 64).astype("float32")
    t = torch.from_numpy(d)
......