Commit 5761f2c0 authored by Guillaume HEUSCH's avatar Guillaume HEUSCH

[extractor] added LightCNN based extractors, and corresponding unit tests

parent 922a147a
Pipeline #26559 failed with stage
in 6 minutes and 43 seconds
import numpy
import torch
from torch.autograd import Variable
import torchvision.transforms as transforms
from bob.learn.pytorch.architectures import LightCNN29
from bob.bio.base.extractor import Extractor
class LightCNN29Extractor(Extractor):
    """The class implementing the feature extraction of LightCNN29 embeddings.

    Attributes
    ----------
    network: :py:class:`torch.nn.Module`
      The network architecture
    to_tensor: :py:mod:`torchvision.transforms`
      The transform from numpy.array to torch.Tensor
    norm: :py:mod:`torchvision.transforms`
      The transform to normalize the input

    """

    def __init__(self, model_file=None, num_classes=79077):
        """Init method

        Parameters
        ----------
        model_file: str
          The path of the trained network to load. If None, the network is
          left randomly initialized (used mainly for unit testing).
        num_classes: int
          The number of classes (output units of the classification layer).

        """
        Extractor.__init__(self, skip_extractor_training=True)

        # model
        self.network = LightCNN29(num_classes=num_classes)

        if model_file is not None:
            cp = torch.load(model_file, map_location='cpu')
            state_dict = cp['state_dict']

            # a model saved through nn.DataParallel has every key prefixed
            # with 'module.'; strip that prefix so the keys match our
            # (non-parallel) network. Only prefixed keys are rewritten.
            if any(k.startswith('module.') for k in state_dict):
                from collections import OrderedDict
                state_dict = OrderedDict(
                    (k[7:] if k.startswith('module.') else k, v)
                    for k, v in state_dict.items())

            self.network.load_state_dict(state_dict)

        # inference mode (disables dropout, uses running BatchNorm stats)
        self.network.eval()

        # image pre-processing: numpy HxWxC -> tensor, then map [0,1] -> [-1,1]
        self.to_tensor = transforms.ToTensor()
        self.norm = transforms.Normalize((0.5,), (0.5,))

    def __call__(self, image):
        """Extract features from an image

        Parameters
        ----------
        image : 2D :py:class:`numpy.ndarray` (floats)
          The grayscale image to extract the features from. Its size must be 128x128

        Returns
        -------
        feature : :py:class:`numpy.ndarray` (floats)
          The extracted features as a 1d array of size 256

        """
        # torchvision.transforms expect a numpy array of size HxWxC
        input_image = numpy.expand_dims(image, axis=2)
        input_image = self.to_tensor(input_image)
        input_image = self.norm(input_image)
        input_image = input_image.unsqueeze(0)

        # to be compliant with the loaded model, where weights and biases
        # are torch.FloatTensor
        input_image = input_image.float()

        # the network returns (class scores, embedding); keep the embedding.
        # Call the module itself (not .forward) so registered hooks run.
        _, features = self.network(Variable(input_image))
        features = features.data.numpy().flatten()
        return features
import numpy
import torch
from torch.autograd import Variable
import torchvision.transforms as transforms
from bob.learn.pytorch.architectures import LightCNN29v2
from bob.bio.base.extractor import Extractor
class LightCNN29v2Extractor(Extractor):
    """The class implementing the feature extraction of LightCNN29v2 embeddings.

    Attributes
    ----------
    network: :py:class:`torch.nn.Module`
      The network architecture
    to_tensor: :py:mod:`torchvision.transforms`
      The transform from numpy.array to torch.Tensor
    norm: :py:mod:`torchvision.transforms`
      The transform to normalize the input

    """

    def __init__(self, model_file=None, num_classes=79077):
        """Init method

        Parameters
        ----------
        model_file: str
          The path of the trained network to load. If None, the network is
          left randomly initialized (used mainly for unit testing).
        num_classes: int
          The number of classes (output units of the classification layer).

        """
        Extractor.__init__(self, skip_extractor_training=True)

        # model
        self.network = LightCNN29v2(num_classes=num_classes)

        if model_file is not None:
            cp = torch.load(model_file, map_location='cpu')
            state_dict = cp['state_dict']

            # a model saved through nn.DataParallel has every key prefixed
            # with 'module.'; strip that prefix so the keys match our
            # (non-parallel) network. Only prefixed keys are rewritten.
            if any(k.startswith('module.') for k in state_dict):
                from collections import OrderedDict
                state_dict = OrderedDict(
                    (k[7:] if k.startswith('module.') else k, v)
                    for k, v in state_dict.items())

            self.network.load_state_dict(state_dict)

        # inference mode (disables dropout, uses running BatchNorm stats)
        self.network.eval()

        # image pre-processing: numpy HxWxC -> tensor, then map [0,1] -> [-1,1]
        self.to_tensor = transforms.ToTensor()
        self.norm = transforms.Normalize((0.5,), (0.5,))

    def __call__(self, image):
        """Extract features from an image

        Parameters
        ----------
        image : 2D :py:class:`numpy.ndarray` (floats)
          The grayscale image to extract the features from. Its size must be 128x128

        Returns
        -------
        feature : :py:class:`numpy.ndarray` (floats)
          The extracted features as a 1d array of size 256

        """
        # torchvision.transforms expect a numpy array of size HxWxC
        input_image = numpy.expand_dims(image, axis=2)
        input_image = self.to_tensor(input_image)
        input_image = self.norm(input_image)
        input_image = input_image.unsqueeze(0)

        # to be compliant with the loaded model, where weights and biases
        # are torch.FloatTensor
        input_image = input_image.float()

        # the network returns (class scores, embedding); keep the embedding.
        # Call the module itself (not .forward) so registered hooks run.
        _, features = self.network(Variable(input_image))
        features = features.data.numpy().flatten()
        return features
......@@ -46,13 +46,13 @@ class LightCNN9Extractor(Extractor):
cp = torch.load(model_file, map_location='cpu')
# checked if pre-trained model was saved using nn.DataParallel ...
saved_with_nnDataParallel = False:
saved_with_nnDataParallel = False
for k, v in cp['state_dict'].items():
if 'module' in k:
saved_with_nnDataParallel = True
break
# it it was, you have to rename the keys of state_dict ... (i.e. remove 'module.')
# if it was, you have to rename the keys of state_dict ... (i.e. remove 'module.')
if saved_with_nnDataParallel:
if 'state_dict' in cp:
from collections import OrderedDict
......
# Re-export the LightCNN extractor family at package level.
from .LightCNN9 import LightCNN9Extractor
from .LightCNN29 import LightCNN29Extractor
from .LightCNN29v2 import LightCNN29v2Extractor
# Declare the public API: every name defined above that is not private.
__all__ = [_ for _ in dir() if not _.startswith('_')]
......@@ -244,21 +244,41 @@ def test_conv_autoencoder():
"""
Test the ConvAutoencoder class.
"""
from bob.learn.pytorch.architectures import ConvAutoencoder
batch = torch.randn(1, 3, 64, 64)
model = ConvAutoencoder()
output = model(batch)
assert batch.shape == output.shape
model_embeddings = ConvAutoencoder(return_latent_embedding = True)
embedding = model_embeddings(batch)
assert list(embedding.shape) == [1, 16, 5, 5]
def test_extractors():
    """
    Test the LightCNN extractor classes.

    Each architecture expects a 128x128 grayscale image and must return a
    1d embedding of size 256. The three stanzas of the original were
    byte-identical copy-paste; they are folded into a single loop.
    """
    from bob.learn.pytorch.extractor.image import LightCNN9Extractor
    from bob.learn.pytorch.extractor.image import LightCNN29Extractor
    from bob.learn.pytorch.extractor.image import LightCNN29v2Extractor

    for extractor_class in (LightCNN9Extractor,
                            LightCNN29Extractor,
                            LightCNN29v2Extractor):
        extractor = extractor_class()
        # this architecture expects 128x128 grayscale images
        data = numpy.random.rand(128, 128).astype("float32")
        output = extractor(data)
        assert output.shape[0] == 256
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment