Skip to content
Snippets Groups Projects
Commit 01a78753 authored by Anjith GEORGE's avatar Anjith GEORGE
Browse files

Added version 2 of MCCNN with shared layers

parent aaf659c1
Branches
Tags
1 merge request!20Mccn trainer
Checking pipeline status
#!/usr/bin/env python
# encoding: utf-8
import torch
import torch.nn as nn
import torch.nn.functional as F
import os
import numpy as np
import pkg_resources
import bob.extension.download
import bob.io.base
from .utils import MaxFeatureMap
from .utils import group
from .utils import resblock
import logging
logger = logging.getLogger("bob.learn.pytorch")
class MCCNNv2(nn.Module):
    """Multi-channel CNN, version 2.

    The difference from MCCNN is that layers which are *not* adapted are
    shared across all channels instead of being replicated per channel,
    which avoids duplicating identical (frozen) layers.

    Attributes
    ----------
    num_channels: int
        The number of channels present in the input.
    lcnn_layers: list
        The adaptable layers present in the base LightCNN model.
    shared_layers: list
        Layer names shared by all channels (stored once, under the
        ``ch_0_`` prefix).
    domain_specific_layers: list
        Layer names replicated per channel and adapted during training.
    layer_dict: :py:class:`torch.nn.ModuleDict`
        Pytorch class containing the modules as a dictionary.
    """

    def __init__(self, block=resblock, layers=None, num_channels=4,
                 adapted_layers='conv1-block1-group1-ffc', verbosity_level=2):
        """Init function

        Parameters
        ----------
        block: :py:class:`torch.nn.Module`
            Base residual block replicated inside each ``block*`` layer.
        layers: list
            Number of ``block`` copies in each of the four block groups.
            Defaults to ``[1, 2, 3, 4]`` (LightCNN-29 configuration).
        num_channels: int
            The number of channels present in the input.
        adapted_layers: str
            The layers to be adapted in training, separated by '-'.
            Example: 'conv1-block1-group1-ffc'; 'ffc' denotes the final
            fully connected layers, which are adapted in all cases.
        verbosity_level: int
            Verbosity level.

        Raises
        ------
        ValueError
            If ``adapted_layers`` does not contain 'ffc', or contains a
            layer name not present in the base LightCNN model.
        """
        super(MCCNNv2, self).__init__()

        # avoid the mutable-default-argument pitfall for ``layers``
        if layers is None:
            layers = [1, 2, 3, 4]

        self.num_channels = num_channels
        self.lcnn_layers = ['conv1', 'block1', 'group1', 'block2', 'group2',
                            'block3', 'group3', 'block4', 'group4', 'fc']

        layers_present = self.lcnn_layers + ['ffc']

        # select the layers in the network to adapt; validate explicitly
        # (``assert`` would silently disappear under ``python -O``)
        adapted_layers_list = adapted_layers.split('-')
        if 'ffc' not in adapted_layers_list:
            raise ValueError("'ffc' must always be present in adapted_layers")
        unknown = set(adapted_layers_list) - set(layers_present)
        if unknown:
            raise ValueError(
                "Unknown layer name(s) in adapted_layers: {}".format(sorted(unknown)))

        self.shared_layers = list(set(layers_present) - set(adapted_layers_list))
        self.domain_specific_layers = list(set(adapted_layers_list) - set(['ffc']))

        logger.setLevel(verbosity_level)

        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=True)
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=True)
        self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=True)
        self.pool4 = nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=True)

        # newly added FC layers
        self.linear1fc = nn.Linear(256 * num_channels, 10)
        self.linear2fc = nn.Linear(10, 1)

        # add modules: one fresh copy per channel for domain-specific layers
        module_dict = {}
        for i in range(self.num_channels):
            branch = self._make_branch(block, layers)
            for layer in self.domain_specific_layers:
                module_dict["ch_{}_".format(i) + layer] = branch[layer]

        # shared layers keep the 'ch_0_' prefix so that loading weights from
        # the pretrained LightCNN model stays straightforward
        branch = self._make_branch(block, layers)
        for layer in self.shared_layers:
            module_dict["ch_0_" + layer] = branch[layer]

        self.layer_dict = nn.ModuleDict(module_dict)

        # check for the pretrained model; download it when absent
        light_cnn_model_file = os.path.join(
            MCCNNv2.get_mccnnv2path(), "LightCNN_29Layers_checkpoint.pth.tar")
        url = 'http://www.idiap.ch/software/bob/data/bob/bob.learn.pytorch/master/LightCNN_29Layers_checkpoint.pth.tar'

        logger.info("Light_cnn_model_file path: {}".format(light_cnn_model_file))

        if not os.path.exists(light_cnn_model_file):
            bob.io.base.create_directories_safe(os.path.split(light_cnn_model_file)[0])
            logger.info('Downloading the LightCNN model')
            bob.extension.download.download_file(url, light_cnn_model_file)
            logger.info('Downloaded LightCNN model to location: {}'.format(light_cnn_model_file))

        # load the pretrained weights into the 'ch_0_' prefixed modules
        self.load_state_dict(self.get_model_state_dict(light_cnn_model_file), strict=False)

        # copy the pretrained weights over to every other channel's
        # domain-specific layers (channel 0 already holds them)
        for layer in self.domain_specific_layers:
            source_state = self.layer_dict["ch_0_" + layer].state_dict()
            for i in range(1, self.num_channels):
                self.layer_dict["ch_{}_".format(i) + layer].load_state_dict(source_state)

    def _make_branch(self, block, layers):
        """Build a fresh set of LightCNN modules for one channel branch.

        Parameters
        ----------
        block: :py:class:`torch.nn.Module`
            Base residual block to replicate inside the ``block*`` layers.
        layers: list
            Number of block copies in each of the four block groups.

        Returns
        -------
        dict
            Mapping from layer name (as in ``lcnn_layers``) to a newly
            constructed, independent module.
        """
        return {
            'conv1': MaxFeatureMap(1, 48, 5, 1, 2),
            'block1': self._make_layer(block, layers[0], 48, 48),
            'group1': group(48, 96, 3, 1, 1),
            'block2': self._make_layer(block, layers[1], 96, 96),
            'group2': group(96, 192, 3, 1, 1),
            'block3': self._make_layer(block, layers[2], 192, 192),
            'group3': group(192, 128, 3, 1, 1),
            'block4': self._make_layer(block, layers[3], 128, 128),
            'group4': group(128, 128, 3, 1, 1),
            'fc': MaxFeatureMap(8 * 8 * 128, 256, type=0),
        }

    def _make_layer(self, block, num_blocks, in_channels, out_channels):
        """Stack ``num_blocks`` copies of the same base module.

        Parameters
        ----------
        block: :py:class:`torch.nn.Module`
            The base block to replicate.
        num_blocks: int
            Number of copies of the block to be made.
        in_channels: int
            Number of input channels for a block.
        out_channels: int
            Number of output channels for a block.

        Returns
        -------
        :py:class:`torch.nn.Sequential`
            The stacked blocks.
        """
        return nn.Sequential(
            *[block(in_channels, out_channels) for _ in range(num_blocks)])

    def forward(self, img):
        """Propagate data through the network.

        Parameters
        ----------
        img: :py:class:`torch.Tensor`
            The data to forward through the network.
            Image of size num_channels x 128 x 128.

        Returns
        -------
        output: :py:class:`torch.Tensor`
            score
        """
        # pooling is applied after these indices of ``lcnn_layers``
        pools = {0: self.pool1, 2: self.pool2, 4: self.pool3, 8: self.pool4}

        embeddings = []
        for ch in range(self.num_channels):
            x = img[:, ch, :, :].unsqueeze(1)  # the image for this channel

            # convolutional part: domain-specific layers use this channel's
            # own copy; shared layers are routed to the 'ch_0_' module
            for idx in range(9):
                name = self.lcnn_layers[idx]
                src = ch if name in self.domain_specific_layers else 0
                x = self.layer_dict["ch_{}_".format(src) + name](x)
                if idx in pools:
                    x = pools[idx](x)

            x = x.view(x.size(0), -1)

            fc_name = self.lcnn_layers[9]
            src = ch if fc_name in self.domain_specific_layers else 0
            fc = self.layer_dict["ch_{}_".format(src) + fc_name](x)
            fc = F.dropout(fc, training=self.training)
            embeddings.append(fc)

        merged = torch.cat(embeddings, 1)

        output = self.linear1fc(merged)
        output = nn.Sigmoid()(output)
        output = self.linear2fc(output)

        # during training a sigmoid is applied; at eval the raw score is
        # returned (behavior kept as in the original implementation)
        if self.training:
            output = nn.Sigmoid()(output)

        return output

    @staticmethod
    def get_mccnnv2path():
        """Return the directory where the pretrained model is stored."""
        # pkg_resources is already imported at module level
        return pkg_resources.resource_filename('bob.learn.pytorch', 'models')

    def get_model_state_dict(self, pretrained_model_path):
        """Load the pretrained LightCNN weights, remapped for this network.

        Parameters
        ----------
        pretrained_model_path: str
            Absolute path to the LightCNN checkpoint file.

        Returns
        -------
        collections.OrderedDict
            LightCNN weights keyed so that they load into the 'ch_0_'
            prefixed modules of ``layer_dict``.
        """
        from collections import OrderedDict

        # map_location forces CPU loading so the file opens without a GPU
        checkpoint = torch.load(pretrained_model_path,
                                map_location=lambda storage, loc: storage)
        state_dict = checkpoint['state_dict']

        # strip the leading 'module.' (DataParallel) prefix and remap each
        # key onto the shared 'ch_0_' modules inside layer_dict
        new_state_dict = OrderedDict()
        for k, v in state_dict.items():
            new_state_dict['layer_dict.ch_0_' + k[7:]] = v
        return new_state_dict
@@ -4,6 +4,7 @@
 from .LightCNN import LightCNN9
 from .LightCNN import LightCNN29
 from .LightCNN import LightCNN29v2
 from .MCCNN import MCCNN
+from .MCCNNv2 import MCCNNv2
 from .DCGAN import DCGAN_generator
 from .DCGAN import DCGAN_discriminator
@@ -68,6 +68,14 @@ def test_architectures():
     output = net.forward(t)
     assert output.shape == torch.Size([1, 1])

+    # MCCNNv2
+    a = numpy.random.rand(1, 4, 128, 128).astype("float32")
+    t = torch.from_numpy(a)
+    from ..architectures import MCCNNv2
+    net = MCCNNv2(num_channels=4)
+    output = net.forward(t)
+    assert output.shape == torch.Size([1, 1])
+
     # DCGAN
     d = numpy.random.rand(1, 3, 64, 64).astype("float32")
     t = torch.from_numpy(d)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment