Commit 0312870b authored by Anjith GEORGE's avatar Anjith GEORGE Committed by Anjith GEORGE

Cross modality pre-training for MS PAD

parent 61de6575
import torch
from torch import nn
from torchvision import models
import numpy as np
class DeepMSPAD(nn.Module):
""" Deep multispectral PAD algorithm
The initialization uses `Cross modality pre-training` idea from the following paper:
Wang L, Xiong Y, Wang Z, Qiao Y, Lin D, Tang X, Van Gool L. Temporal segment networks:
Towards good practices for deep action recognition. InEuropean conference on computer
vision 2016 Oct 8 (pp. 20-36). Springer, Cham.
Attributes:
pretrained: bool
if set `True` loads the pretrained vgg16 model.
......@@ -44,19 +52,40 @@ class DeepMSPAD(nn.Module):
features = list(vgg.features.children())
features[0]=nn.Conv2d(num_channels, 64, kernel_size=(3, 3), stride=(1, 1), padding =(1, 1))
# temp layer to extract weights
temp_layer = features[0]
# Implements ``Cross modality pre-training``
# Mean of weight and bias for all filters
bias_values = temp_layer.bias.data.detach().numpy()
mean_weight = np.mean(temp_layer.weight.data.detach().numpy(),axis=1) # for 64 filters
new_weight = np.zeros((64,num_channels,3,3))
for i in range(num_channels):
new_weight[:,i,:,:]=mean_weight
# initialize new layer with required number of channels `num_channels`
features[0] = nn.Conv2d(num_channels, 64, kernel_size=(3, 3), stride=(1, 1), padding =(1, 1))
features[0].weight.data = torch.Tensor(new_weight)
features[0].bias.data = torch.Tensor(bias_values) #check
self.enc = nn.Sequential(*features)
self.linear1=nn.Linear(25088,256)
self.linear1 = nn.Linear(25088,256)
self.relu=nn.ReLU()
self.relu = nn.ReLU()
self.dropout= nn.Dropout(p=0.5)
self.dropout = nn.Dropout(p=0.5)
self.linear2=nn.Linear(256,1)
self.linear2 = nn.Linear(256,1)
self.sigmoid= nn.Sigmoid()
self.sigmoid = nn.Sigmoid()
def forward(self, x):
......@@ -76,16 +105,16 @@ class DeepMSPAD(nn.Module):
enc = self.enc(x)
x=enc.view(-1,25088)
x = enc.view(-1,25088)
x=self.linear1(x)
x = self.linear1(x)
x=self.relu(x)
x = self.relu(x)
x=self.dropout(x)
x = self.dropout(x)
x=self.linear2(x)
x = self.linear2(x)
x=self.sigmoid(x)
x = self.sigmoid(x)
return x
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment