Skip to content
Snippets Groups Projects
Commit a5f31f81 authored by Guillaume HEUSCH's avatar Guillaume HEUSCH
Browse files

added the CNN8 architecture + feature extraction from Xiaojiang

parent 9aa49288
Branches
Tags
No related merge requests found
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
CNN8_CONFIG = [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M']
def make_conv_layers(cfg, input_c = 3):
layers = []
in_channels = input_c
for v in cfg:
if v == 'M':
layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
else:
conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)
layers += [conv2d, nn.ReLU()]
in_channels = v
return nn.Sequential(*layers)
class CNN8(nn.Module):
def __init__(self, num_cls, drop_rate=0.5):
super(CNN8, self).__init__()
self.num_classes = num_cls
self.drop_rate = float(drop_rate)
self.conv = make_conv_layers(CNN8_CONFIG)
self.avgpool = nn.AvgPool2d(8)
self.classifier = nn.Linear(512, self.num_classes)
def forward(self, x):
x = self.conv(x)
x = self.avgpool(x)
x = x.view(x.size(0), -1)
x = F.dropout(x, p = self.drop_rate, training=self.training)
return x
from bob.bio.base.extractor import Extractor
class CNN8Extractor(Extractor):
def __init__(self, model_file, num_classes=10575):
Extractor.__init__(self, skip_extractor_training=True)
# model
self.network = CNN8(num_classes)
cp = torch.load(model_file)
if 'state_dict' in cp:
self.network.load_state_dict(cp['state_dict'])
self.network.eval()
# image pre-processing
self.to_tensor = transforms.ToTensor()
self.norm = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
def __call__(self, image):
"""__call__(image) -> feature
Extract features
**Parameters:**
image : 3D :py:class:`numpy.ndarray` (floats)
The image to extract the features from.
**Returns:**
feature : 2D :py:class:`numpy.ndarray` (floats)
The extracted features
"""
input_image = numpy.rollaxis(numpy.rollaxis(image, 2),2)
input_image = self.to_tensor(input_image)
input_image = self.norm(input_image)
input_image = input_image.unsqueeze(0)
features = self.network.forward(Variable(input_image))
feat = feat.data.cpu().numpy().flatten()
features = features.data.numpy().flatten()
print features.shape
return features
from .DRGANLight import DRGANLight from .DRGANLight import DRGANLight
from .DRGANOriginal import DRGANOriginal from .DRGANOriginal import DRGANOriginal
from .CNN8 import CNN8Extractor
# gets sphinx autodoc done right - don't remove it # gets sphinx autodoc done right - don't remove it
...@@ -22,6 +23,7 @@ def __appropriate__(*args): ...@@ -22,6 +23,7 @@ def __appropriate__(*args):
__appropriate__( __appropriate__(
DRGANLight, DRGANLight,
DRGANOriginal, DRGANOriginal,
CNN8Extractor,
) )
# gets sphinx autodoc done right - don't remove it # gets sphinx autodoc done right - don't remove it
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment