Skip to content
Snippets Groups Projects
Commit b4c2ffa6 authored by Guillaume HEUSCH's avatar Guillaume HEUSCH
Browse files

added lightCNN9 extractor

parent 9554f709
No related branches found
No related tags found
No related merge requests found
Pipeline #26310 failed
import numpy
import torch
from torch.autograd import Variable
import torchvision.transforms as transforms
from bob.learn.pytorch.architectures import LightCNN9
from bob.bio.base.extractor import Extractor
class LightCNN9Extractor(Extractor):
    """ The class implementing the feature extraction of LightCNN9 embeddings.

    Attributes
    ----------
    network: :py:class:`torch.nn.Module`
        The network architecture
    to_tensor: :py:mod:`torchvision.transforms`
        The transform from numpy.array to torch.Tensor
    norm: :py:mod:`torchvision.transforms`
        The transform to normalize the input

    """

    def __init__(self, model_file=None, num_classes=79077):
        """ Init method

        Parameters
        ----------
        model_file: str
            The path of the trained network to load. When ``None``, no
            weights are loaded (used mainly for unit testing).
        num_classes: int
            The number of classes.

        """
        Extractor.__init__(self, skip_extractor_training=True)

        # model
        self.network = LightCNN9(num_classes)
        if model_file is not None:
            # always map to CPU, so that the checkpoint loads without a GPU
            cp = torch.load(model_file, map_location='cpu')
            self.network.load_state_dict(self._clean_state_dict(cp))
        self.network.eval()

        # image pre-processing
        self.to_tensor = transforms.ToTensor()
        self.norm = transforms.Normalize((0.5,), (0.5,))

    @staticmethod
    def _clean_state_dict(checkpoint):
        """ Get loadable weights from a checkpoint

        Pre-trained models may have been saved using nn.DataParallel, which
        prefixes every key with 'module.'. The prefix is removed only from
        the keys that actually carry it, so checkpoints saved without
        nn.DataParallel keep their keys intact.

        Parameters
        ----------
        checkpoint: dict
            Either a dictionary with a 'state_dict' entry, or the state
            dictionary itself.

        Returns
        -------
        :py:class:`collections.OrderedDict`
            The weights, with keys free of the 'module.' prefix.

        """
        from collections import OrderedDict

        if 'state_dict' in checkpoint:
            state_dict = checkpoint['state_dict']
        else:
            # no wrapper: consider the checkpoint as a bare state dict, rather
            # than silently keeping randomly initialized weights
            state_dict = checkpoint

        prefix = 'module.'
        cleaned = OrderedDict()
        for key, value in state_dict.items():
            name = key[len(prefix):] if key.startswith(prefix) else key
            cleaned[name] = value
        return cleaned

    def __call__(self, image):
        """ Extract features from an image

        Parameters
        ----------
        image : :py:class:`numpy.ndarray` (floats)
            The grayscale image to extract the features from, either as a
            2D array (HxW) or as a 3D array with a single leading channel
            (1xHxW).

        Returns
        -------
        feature : 1D :py:class:`numpy.ndarray` (floats)
            The extracted features, flattened to a 1d array

        Raises
        ------
        ValueError
            If the image is neither HxW nor 1xHxW.

        """
        image = numpy.asarray(image)

        # drop the leading channel axis of a 1xHxW image
        if image.ndim == 3 and image.shape[0] == 1:
            image = image[0]
        if image.ndim != 2:
            raise ValueError(
                "expected a grayscale image of shape HxW or 1xHxW, got shape %s"
                % (image.shape,))

        # torchvision.transforms expect a numpy array of size HxWxC
        input_image = numpy.expand_dims(image, axis=2)
        input_image = self.to_tensor(input_image)
        input_image = self.norm(input_image)

        # add the batch dimension
        input_image = input_image.unsqueeze(0)

        # to be compliant with the loaded model, where weight and biases are torch.FloatTensor
        input_image = input_image.float()

        # the network returns a pair, the features being the second element
        _, features = self.network(Variable(input_image))
        return features.data.numpy().flatten()
from .CNN8 import CNN8Extractor
from .CasiaNet import CasiaNetExtractor
from .LightCNN9 import LightCNN9Extractor
# gets sphinx autodoc done right - don't remove it
def __appropriate__(*args):
...@@ -19,6 +20,7 @@ def __appropriate__(*args):
__appropriate__(
CNN8Extractor,
CasiaNetExtractor,
LightCNN9Extractor,
)
# gets sphinx autodoc done right - don't remove it
......
...@@ -32,3 +32,16 @@ def test_casianet():
data = numpy.random.rand(3, 128, 128).astype("float32")
output = extractor(data)
assert output.shape[0] == 320
def test_lightcnn9():
    """ test for the LightCNN9 architecture

    the extractor is built without a model file and is fed a random
    128x128 single-channel image
    output an embedding of dimension 256
    """
    from . import LightCNN9Extractor
    extractor = LightCNN9Extractor()
    # the extractor appends the channel axis itself (HxW -> HxWxC),
    # hence a 2D image is provided here
    data = numpy.random.rand(128, 128).astype("float32")
    output = extractor(data)
    assert output.shape[0] == 256
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment