Commit 9ad938ef authored by Guillaume HEUSCH's avatar Guillaume HEUSCH

Merge branch 'ae_extractor' into 'master'

Added MultiNetPatchClassifier extractor and utils, temp fix in LightCNN9

See merge request !3
parents 60633779 6c352a8d
Pipeline #26600 passed with stages
in 8 minutes and 18 seconds
include README.rst buildout.cfg develop.cfg COPYING version.txt requirements.txt
recursive-include doc *.py *.rst
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
Created on Thu Oct 11 15:32:02 2018
@author: Olegs Nikisins
# =============================================================================
# Import what is needed here:
# NOTE(review): the original import line was truncated to "from import Extractor";
# upstream bob.ip.pytorch_extractor takes Extractor from bob.bio.base — confirm.
from bob.bio.base.extractor import Extractor

import numpy as np

from bob.ip.pytorch_extractor.utils import combinations
from bob.ip.pytorch_extractor.utils import load_pretrained_model
from bob.ip.pytorch_extractor.utils import reshape_flat_patches
from bob.ip.pytorch_extractor.utils import transform_and_net_forward
# =============================================================================
# Main body:
class MultiNetPatchExtractor(Extractor, object):
    """
    This class is designed to pass a set of patches through a possibly multiple
    networks and compute a feature vector combining outputs of all networks.

    The functional work-flow is the following:

    First, an array of **flattened** input patches is converted to a list
    of patches with original dimensions (2D or 3D arrays).

    Second, each patch is passed through an individual network, for example
    an auto-encoder pre-trained for each patch type (left eye, for example).

    Third, outputs of all networks are concatenated into a single feature
    vector.

    Attributes
    ----------

    config_file: str
        Relative name of the config file.
        The path should be relative to ``config_group``,
        for example: "autoencoder/".

        This file **must** contain at least the following definitions:

        Function namely ``transform``, which is a Compose transformation of
        torchvision package, to be applied to the input samples.

        A ``Network`` class, defining your network architecture. Note, if your
        class is named differently, import it as ``Network``, for example:
        ``from bob.learn.pytorch.architectures import MyNetwork as Network``

        Optional: ``network_kwargs`` to be used for ``Network`` initialization.
        For example, if you want to use the latent embeddings of the autoencoder
        class, set the kwargs accordingly. Note: in current extractor the
        ``forward()`` method of the ``Network`` is used for feature extraction.

    config_group: str
        Group/package name containing the configuration file. Usually all
        configs should be stored in this folder/place.
        For example: "bob.learn.pytorch.config".
        Both ``config_file`` and ``config_group`` are used to access the
        configuration module.

    model_file : [str]
        A list of paths to the model files to be used for network initialization.
        The network structure is defined in the config file.

    patches_num : [int]
        A list of indices specifying which patches will be selected for
        processing/feature vector extraction.

    patch_reshape_parameters : [int] or None
        The parameters to be used for patch reshaping. The loaded patch is
        vectorized. Example:
        ``patch_reshape_parameters = [4, 8, 8]``, then the patch of the
        size (256,) will be reshaped to (4,8,8) dimensions. Only 2D and 3D
        patches are supported.
        Default: None.

    color_input_flag : bool
        If set to ``True``, the input is considered to be a color image of the
        size ``(3, H, W)``. The tensor to be passed through the net will be
        of the size ``(1, 3, H, W)``.
        If set to ``False``, the input is considered to be a set of BW images
        of the size ``(n_samples, H, W)``. The tensor to be passed through
        the net will be of the size ``(n_samples, 1, H, W)``.
        Default: ``False``.

    urls : [str]
        List of URLs to download the pretrained models from.
        If models are not available in the locations specified in the
        ``model_file`` list, the system will try to download them from
        ``urls``. The downloaded models **will be placed to the locations**
        specified in ``model_file`` list.
        Default: None

    archive_extension : str
        Extension of the archived files to download from above ``urls``.
        Default: '.tar.gz'
    """

    # =========================================================================
    def __init__(self, config_file,
                 config_group,
                 model_file,
                 patches_num,
                 patch_reshape_parameters=None,
                 color_input_flag=False,
                 urls=None,
                 archive_extension='.tar.gz'):
        """
        Init method. See the class docstring for the meaning of the arguments.
        """
        # NOTE(review): ``config_group``, ``model_file`` and ``patches_num``
        # were missing from the signature in the reviewed revision although
        # the body references them (NameError) — restored here.
        super(MultiNetPatchExtractor, self).__init__(
            config_file=config_file,
            config_group=config_group,
            model_file=model_file,
            patches_num=patches_num,
            patch_reshape_parameters=patch_reshape_parameters,
            color_input_flag=color_input_flag,
            urls=urls,
            archive_extension=archive_extension)

        self.config_file = config_file
        self.config_group = config_group
        self.model_file = model_file
        self.patches_num = patches_num
        self.patch_reshape_parameters = patch_reshape_parameters
        self.color_input_flag = color_input_flag
        self.urls = urls
        self.archive_extension = archive_extension

    # =========================================================================
    def __call__(self, patches):
        """
        Extract features combining outputs of multiple networks.

        Parameters
        ----------
        patches : 2D :py:class:`numpy.ndarray`
            An array containing flattened patches. The dimensions are:
            ``num_patches x len_of_flat_patch``
            Optional: the last column of the array can also be a binary mask.
            This case is also handled.

        Returns
        -------
        features : :py:class:`numpy.ndarray`
            Feature vector.
        """
        # kwargs for the transform_and_net_forward function:
        function_kwargs = {}
        function_kwargs["config_file"] = self.config_file
        function_kwargs["config_group"] = self.config_group
        function_kwargs["model_file"] = self.model_file
        function_kwargs["color_input_flag"] = self.color_input_flag

        # convert all values in the dictionary to the list if not a list already:
        function_kwargs = {k: [v] if not isinstance(v, list) else v
                           for (k, v) in function_kwargs.items()}

        # compute all possible key-value combinations:
        function_kwargs = combinations(function_kwargs)  # now a list of kwargs

        # select patches specified by indices (list, to make it iterable):
        patches_selected = [patches[idx] for idx in self.patches_num]

        # convert to original dimensions:
        patches_3d = reshape_flat_patches(patches_selected,
                                          self.patch_reshape_parameters)

        features_all_patches = []

        # make sure the model_file and urls are indexable even when None:
        if self.model_file is None:
            self.model_file = [self.model_file]
        if self.urls is None:
            self.urls = [self.urls]

        for idx, patch in enumerate(patches_3d):
            # try to load the model if not available, do nothing if available:
            load_pretrained_model(
                model_path=self.model_file[self.patches_num[idx]],
                url=self.urls[self.patches_num[idx]],
                archive_extension=self.archive_extension)

            if len(function_kwargs) == 1:
                # all patches are passed through the same network:
                features = transform_and_net_forward(feature=patch,
                                                     **function_kwargs[0])
            else:
                # patches are passed through different networks:
                features = transform_and_net_forward(
                    feature=patch,
                    **function_kwargs[self.patches_num[idx]])

            # NOTE(review): this append was missing in the reviewed revision,
            # leaving ``features_all_patches`` empty and making the hstack
            # below fail/return nothing useful.
            features_all_patches.append(features)

        # concatenate outputs of all networks into a single feature vector:
        features = np.hstack(features_all_patches)

        return features
from .CNN8 import CNN8Extractor
from .CasiaNet import CasiaNetExtractor
from .LightCNN9 import LightCNN9Extractor
from .MultiNetPatchExtractor import MultiNetPatchExtractor
# gets sphinx autodoc done right - don't remove it
def __appropriate__(*args):
......@@ -21,6 +22,7 @@ __appropriate__(
# gets sphinx autodoc done right - don't remove it
......@@ -45,3 +45,57 @@ def test_lightcnn9():
data = numpy.random.rand(128, 128).astype("float32")
output = extractor(data)
assert output.shape[0] == 256
def test_multi_net_patch_extractor():
    """
    Test the MultiNetPatchExtractor extractor class.
    """
    from bob.ip.pytorch_extractor import MultiNetPatchExtractor

    # =========================================================================
    # prepare the test data: a synthetic color patch of size (3, 128, 128)
    patch_2d = numpy.repeat(
        numpy.expand_dims(numpy.sin(numpy.arange(0, 12.8, 0.1)), axis=0),
        128, axis=0)
    patch = numpy.uint8(
        (numpy.stack([patch_2d, patch_2d.transpose(), -patch_2d]) + 1) * 255 / 2.)

    # flatten the 3D test patch, shape becomes (1, 3*128*128):
    patch_flat = numpy.expand_dims(patch.flatten(), axis=0)

    # =========================================================================
    # test the extractor:

    # config containing an instance of Composed Transform and a Network class
    # to be used in feature extractor:
    CONFIG_FILE = "test_data/"
    CONFIG_GROUP = "bob.ip.pytorch_extractor"

    # NOTE(review): the MODEL_FILE and COLOR_INPUT_FLAG definitions were lost
    # in the reviewed diff, producing NameErrors below. ``None`` makes the
    # extractor skip model loading; the patch is color, so the flag is True.
    # Confirm both against the original test.
    MODEL_FILE = None
    COLOR_INPUT_FLAG = True

    PATCHES_NUM = [0]  # patches to be used in the feature extraction process
    # reshape vectorized patches to these dimensions before the Network:
    PATCH_RESHAPE_PARAMETERS = [3, 128, 128]

    image_extractor = MultiNetPatchExtractor(
        config_file=CONFIG_FILE,
        config_group=CONFIG_GROUP,
        model_file=MODEL_FILE,
        patches_num=PATCHES_NUM,
        patch_reshape_parameters=PATCH_RESHAPE_PARAMETERS,
        color_input_flag=COLOR_INPUT_FLAG,
        urls=None)

    # pass through encoder only, compute latent vector:
    latent_vector = image_extractor(patch_flat)

    # test:
    assert latent_vector.shape == (1296,)

    # # for visualization/debugging only:
    # import matplotlib.pyplot as plt
    # plt.figure()
#!/usr/bin/env python2
# -*- coding: utf-8 -*-
"""
@author: Olegs Nikisins
"""
# =============================================================================
# Import here:

from torchvision import transforms

#from bob.pad.face.database import CELEBAPadDatabase

from torch import nn

# =============================================================================
# Define parameters here:

"""
Transformations to be applied sequentially to the input PIL image.
Note: the variable name ``transform`` must be the same in all configuration files.
"""
# NOTE(review): the closing ``])`` of Compose was missing in the reviewed
# revision (SyntaxError) — restored here.
transform = transforms.Compose([transforms.ToTensor(),
                                transforms.Normalize((0.5, 0.5, 0.5),
                                                     (0.5, 0.5, 0.5))
                                ])

"""
Define the network to be trained as a class, named ``Network``.
Note: Do not change the name of the below class, always import as ``Network``.
"""
from bob.learn.pytorch.architectures import ConvAutoencoder as Network

"""
kwargs to be used for ``Network`` initialization. The name must be ``network_kwargs``.
"""
network_kwargs = {}
network_kwargs['return_latent_embedding'] = True
This diff is collapsed.
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment