Commit 9ad938ef authored by Guillaume HEUSCH
Merge branch 'ae_extractor' into 'master'

Added MultiNetPatchClassifier extractor and utils, temp fix in LightCNN9

See merge request !3
parents 60633779 6c352a8d
include README.rst buildout.cfg develop.cfg COPYING version.txt requirements.txt
recursive-include doc *.py *.rst
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Thu Oct 11 15:32:02 2018
@author: Olegs Nikisins
"""
# =============================================================================
# Import what is needed here:
from bob.bio.base.extractor import Extractor
import numpy as np
from bob.ip.pytorch_extractor.utils import reshape_flat_patches
from bob.ip.pytorch_extractor.utils import combinations
from bob.ip.pytorch_extractor.utils import transform_and_net_forward
from bob.ip.pytorch_extractor.utils import load_pretrained_model
# =============================================================================
# Main body:
class MultiNetPatchExtractor(Extractor, object):
"""
This class is designed to pass a set of patches through possibly multiple
networks and to compute a feature vector combining the outputs of all networks.
The functional work-flow is the following:
First, an array of **flattened** input patches is converted to a list
of patches with original dimensions (2D or 3D arrays).
Second, each patch is passed through an individual network, for example
an auto-encoder pre-trained for the corresponding patch type (e.g. the left eye region).
Third, outputs of all networks are concatenated into a single feature
vector.
Attributes
-----------
config_file: str
Relative name of the config file.
The path should be relative to ``config_group``,
for example: "autoencoder/net1_batl_3_layers_partial.py".
This file **must** contain at least the following definitions:
A function named ``transform``, which is a Compose transformation of
the torchvision package, to be applied to the input samples.
A ``Network`` class, defining your network architecture. Note: if your
class is named differently, import it as ``Network``, for example:
``from bob.learn.pytorch.architectures import MyNetwork as Network``
Optional: ``network_kwargs`` to be used for ``Network`` initialization.
For example, if you want to use the latent embeddings of the autoencoder
class, set the kwargs accordingly. Note: in the current extractor the
``forward()`` method of the ``Network`` is used for feature extraction.
config_group: str
Group/package name containing the configuration file. Usually all
configs should be stored in this folder/place.
For example: "bob.learn.pytorch.config".
Both ``config_file`` and ``config_group`` are used to access the
configuration module.
model_file : [str]
A list of paths to the model files to be used for network initialization.
The network structure is defined in the config file.
patches_num : [int]
A list of indices specifying which patches are selected for
processing / feature vector extraction.
patch_reshape_parameters : [int] or None
The parameters to be used for patch reshaping. The loaded patch is
vectorized. Example:
``patch_reshape_parameters = [4, 8, 8]``, then the patch of the
size (256,) will be reshaped to (4,8,8) dimensions. Only 2D and 3D
patches are supported.
Default: None.
color_input_flag : bool
If set to ``True``, the input is considered to be a color image of the
size ``(3, H, W)``. The tensor to be passed through the net will be
of the size ``(1, 3, H, W)``.
If set to ``False``, the input is considered to be a set of BW images
of the size ``(n_samples, H, W)``. The tensor to be passed through
the net will be of the size ``(n_samples, 1, H, W)``.
Default: ``False``.
urls : [str]
List of URLs to download the pretrained models from.
If the models are not available at the locations specified in the
``model_file`` list, the system will try to download them from
``urls``. The downloaded models **will be placed in the locations**
specified in the ``model_file`` list.
For example, the model for the autoencoder pre-trained on
RGB faces of size (3 channels x 128 x 128) and fine-tuned
on the BW-NIR-D data can be found here:
["https://www.idiap.ch/software/bob/data/bob/bob.ip.pytorch_extractor/master/"
"conv_ae_model_pretrain_celeba_tune_batl_full_face.pth.tar.gz"]
Default: None
archive_extension : str
Extension of the archived files to download from above ``urls``.
Default: '.tar.gz'
"""
# =========================================================================
def __init__(self, config_file,
config_group,
model_file,
patches_num,
patch_reshape_parameters = None,
color_input_flag = False,
urls = None,
archive_extension = '.tar.gz'):
"""
Init method.
"""
super(MultiNetPatchExtractor, self).__init__(config_file = config_file,
config_group = config_group,
model_file = model_file,
patches_num = patches_num,
patch_reshape_parameters = patch_reshape_parameters,
color_input_flag = color_input_flag,
urls = urls,
archive_extension = archive_extension)
self.config_file = config_file
self.config_group = config_group
self.model_file = model_file
self.patches_num = patches_num
self.patch_reshape_parameters = patch_reshape_parameters
self.color_input_flag = color_input_flag
self.urls = urls
self.archive_extension = archive_extension
# =========================================================================
def __call__(self, patches):
"""
Extract features combining outputs of multiple networks.
Parameters
-----------
patches : 2D :py:class:`numpy.ndarray`
An array containing flattened patches. The dimensions are:
``num_patches x len_of_flat_patch``
Optional: the last column of the array can also be a binary mask.
This case is also handled.
Returns
--------
features : :py:class:`numpy.ndarray`
Feature vector.
"""
# kwargs for the transform_and_net_forward function:
function_kwargs = {}
function_kwargs["config_file"] = self.config_file
function_kwargs["config_group"] = self.config_group
function_kwargs["model_file"] = self.model_file
function_kwargs["color_input_flag"] = self.color_input_flag
# convert all values in the dictionary to the list if not a list already:
function_kwargs = {k:[v] if not isinstance(v, list) else v for (k,v) in function_kwargs.items()}
# compute all possible key-value combinations:
function_kwargs = combinations(function_kwargs) # function_kwargs is now a list with kwargs
# select patches specified by indices:
patches_selected = [patches[idx] for idx in self.patches_num] # convert to list to make it iterable
# convert to original dimensions:
patches_3d = reshape_flat_patches(patches_selected, self.patch_reshape_parameters)
features_all_patches = []
# wrap model_file and urls into lists if they are None, so they can be indexed below:
if self.model_file is None:
self.model_file = [self.model_file]
if self.urls is None:
self.urls = [self.urls]
for idx, patch in enumerate(patches_3d):
# try to load the model if not available, do nothing if available:
load_pretrained_model(model_path = self.model_file[self.patches_num[idx]],
url = self.urls[self.patches_num[idx]],
archive_extension = self.archive_extension)
if len(function_kwargs) == 1: # patches are passed through the same network:
features = transform_and_net_forward(feature = patch, **function_kwargs[0])
else: # patches are passed through different networks:
features = transform_and_net_forward(feature = patch, **function_kwargs[self.patches_num[idx]])
# print ("The model we use for patch {} is:".format(str(idx)))
# print (function_kwargs[self.patches_num(idx)]["model_file"])
features_all_patches.append(features)
features = np.hstack(features_all_patches)
return features
from .CNN8 import CNN8Extractor
from .CasiaNet import CasiaNetExtractor
from .LightCNN9 import LightCNN9Extractor
from .MultiNetPatchExtractor import MultiNetPatchExtractor
# gets sphinx autodoc done right - don't remove it
def __appropriate__(*args):
@@ -21,6 +22,7 @@
__appropriate__(
CNN8Extractor,
CasiaNetExtractor,
LightCNN9Extractor,
MultiNetPatchExtractor,
)
# gets sphinx autodoc done right - don't remove it
@@ -45,3 +45,57 @@ def test_lightcnn9():
data = numpy.random.rand(128, 128).astype("float32")
output = extractor(data)
assert output.shape[0] == 256
def test_multi_net_patch_extractor():
"""
Test the MultiNetPatchExtractor extractor class.
"""
from bob.ip.pytorch_extractor import MultiNetPatchExtractor
# =========================================================================
# prepare the test data:
patch_2d = numpy.repeat(numpy.expand_dims(numpy.sin(numpy.arange(0,12.8,0.1)), axis=0), 128, axis=0)
patch = numpy.uint8((numpy.stack([patch_2d, patch_2d.transpose(), -patch_2d])+1)*255/2.)
# flatten the 3D test patch:
patch_flat = numpy.expand_dims(patch.flatten(), axis=0)
# =========================================================================
# test the extractor:
CONFIG_FILE = "test_data/net1_test_config.py" # config containing an instance of Composed Transform and a Network class to be used in feature extractor
CONFIG_GROUP = "bob.ip.pytorch_extractor"
# no pre-trained model is loaded in this test (the network keeps its initial weights);
# in a real setting this is a list of model files, one per patch type:
MODEL_FILE = None
PATCHES_NUM = [0] # patches to be used in the feature extraction process
PATCH_RESHAPE_PARAMETERS = [3, 128, 128] # reshape vectorized patches to these dimensions before passing to the Network
COLOR_INPUT_FLAG = True
image_extractor = MultiNetPatchExtractor(config_file = CONFIG_FILE,
config_group = CONFIG_GROUP,
model_file = MODEL_FILE,
patches_num = PATCHES_NUM,
patch_reshape_parameters = PATCH_RESHAPE_PARAMETERS,
color_input_flag = COLOR_INPUT_FLAG,
urls = None)
# pass through encoder only, compute latent vector:
latent_vector = image_extractor(patch_flat)
# test:
assert latent_vector.shape == (1296,)
# # for visualization/debugging only:
# import matplotlib.pyplot as plt
# import bob.io.image
#
# plt.figure()
# plt.imshow(bob.io.image.to_matplotlib(patch))
# plt.show()
#!/usr/bin/env python2
# -*- coding: utf-8 -*-
"""
@author: Olegs Nikisins
"""
#==============================================================================
# Import here:
from torchvision import transforms
#from bob.pad.face.database import CELEBAPadDatabase
from torch import nn
#==============================================================================
# Define parameters here:
"""
Transformations to be applied sequentially to the input PIL image.
Note: the variable name ``transform`` must be the same in all configuration files.
"""
transform = transforms.Compose([transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
"""
Define the network to be trained as a class, named ``Network``.
Note: Do not change the name of the below class, always import as ``Network``.
"""
from bob.learn.pytorch.architectures import ConvAutoencoder as Network
"""
kwargs to be used for ``Network`` initialization. The name must be ``network_kwargs``.
"""
network_kwargs = {}
network_kwargs['return_latent_embedding'] = True
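# Note (added for clarity): this config is consumed by ``transform_and_net_forward``
# in utils.py: ``transform`` is applied to each input patch, ``Network`` is
# instantiated with ``network_kwargs``, and the output of its ``forward()``
# method is used as the feature vector.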
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Fri Jan 25 14:29:55 2019
@author: Olegs Nikisins
"""
# =============================================================================
# Import what is needed here:
import os
import importlib
from torchvision import transforms
import numpy as np
import PIL
import torch
from torch.autograd import Variable
import itertools as it
import bob.io.base
from bob.extension.download import download_and_unzip
from bob.extension.download import download_file
# =============================================================================
def reshape_flat_patches(patches, patch_reshape_parameters = None):
"""
Reshape a set of flattened patches into original dimensions, 2D or 3D
Parameters
----------
patches : 2D :py:class:`numpy.ndarray`
An array containing flattened patches. The dimensions are:
``num_patches x len_of_flat_patch``
patch_reshape_parameters : [int] or None
The parameters to be used for patch reshaping. The loaded patch is
vectorized. Example:
``patch_reshape_parameters = [4, 8, 8]``, then the patch of the
size (256,) will be reshaped to (4,8,8) dimensions. Only 2D and 3D
patches are supported.
Default: None.
Returns
-------
patches_3d : [2D or 3D :py:class:`numpy.ndarray`]
A list of patches converted to the original dimensions.
"""
patches_3d = []
for patch in patches:
if patch_reshape_parameters is not None:
# The dimensionality of the reshaped patch:
new_shape = [len(patch) // (patch_reshape_parameters[-2] * patch_reshape_parameters[-1])] + list(patch_reshape_parameters[-2:])
patch = np.squeeze(patch.reshape(new_shape))
patches_3d.append(patch)
return patches_3d
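# Example (illustrative values): two flattened patches of length 256 with
# patch_reshape_parameters = [4, 8, 8] are restored to their 3D shape:
#
#   flat = np.random.rand(2, 256)
#   patches_3d = reshape_flat_patches(flat, [4, 8, 8])
#   # patches_3d is a list of two arrays, each of shape (4, 8, 8)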
# =============================================================================
def combinations( input_dict ):
"""
Obtain all possible key-value combinations in the input dictionary.
Parameters
----------
input_dict : dict
An input dictionary.
Returns
-------
combinations : [dict]
List of dictionaries containing the combinations.
"""
varNames = sorted(input_dict)
combinations = [dict(zip(varNames, prod)) for prod in it.product(*(input_dict[varName] for varName in varNames))]
return combinations
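# Example (illustrative values; keys are processed in sorted order):
#
#   combinations({"config_file": ["net1.py"], "model_file": ["m1.pth", "m2.pth"]})
#   # -> [{"config_file": "net1.py", "model_file": "m1.pth"},
#   #     {"config_file": "net1.py", "model_file": "m2.pth"}]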
# =============================================================================
def apply_transform_as_pil(img_array, transform):
"""
Apply a composed transformation to the input BW or RGB image / numpy array.
The input image is in Bob format. Before the transformation, the input
image is converted to a PIL image.
Parameters
----------
img_array : 2D or 3D :py:class:`numpy.ndarray`
An input image / array, stored in Bob format (channels first) for RGB images.
transform : torchvision.transforms.transforms.Compose
Composed transformation to be applied to the input image.
Returns
-------
pil_img : :py:class:`PIL.Image` or Tensor
Transformed image (a Tensor if the composed transform ends with ``ToTensor``).
"""
if isinstance(transform, transforms.Compose): # if an instance of torchvision composed transformation
if len(img_array.shape) == 3: # for color images
img_array_tr = np.swapaxes(img_array, 1, 2)
img_array_tr = np.swapaxes(img_array_tr, 0, 2)
pil_img = PIL.Image.fromarray( img_array_tr ) # convert to PIL from array of size (H x W x 3)
else: # for gray-scale images
pil_img = PIL.Image.fromarray( img_array, 'L' ) # convert to PIL from array of size (H x W)
if transform is not None:
pil_img = transform(pil_img)
return pil_img
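# Example (illustrative): a Bob-format uint8 color image of shape (3, 128, 128)
# is rearranged to (128, 128, 3), wrapped as a PIL image, and the composed
# ``transform`` (e.g. ToTensor + Normalize) returns a (3, 128, 128) torch.Tensor.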
def _init_the_network(config_module):
"""
Initialize the network, given imported configuration module.
Parameters
----------
config_module : object
Module containing Network configuration parameters.
Returns
-------
model : object
An instance of the Network class.
"""
if "network_kwargs" in dir(config_module):
network_kwargs = config_module.network_kwargs
model = config_module.Network(**network_kwargs)
else:
model = config_module.Network()
return model
def _transform(feature, config_module, color_input_flag):
"""
Apply transformation defined in the configuration module to the input data.
Currently two types of transformation are supported:
First, a Compose transformation of the torchvision package.
Second, a custom transformation function, applied to each feature
vector, for example mean-std normalization.
In both cases, define ``transform`` in the ``config_module``.
Parameters
----------
feature : :py:class:`numpy.ndarray`
ND feature array of the size (N_samples x dim1 x dim2 x ...).
config_module : object
Module containing Network configuration parameters.
color_input_flag : bool
If set to ``True``, the input is considered to be a color image of the
size ``(3, H, W)``. The tensor to be passed through the net will be
of the size ``(1, 3, H, W)``.
If set to ``False``, the input is considered to be a set of BW images
of the size ``(n_samples, H, W)``. The tensor to be passed through
the net will be of the size ``(n_samples, 1, H, W)``.
"""
if "transform" in dir(config_module):
if isinstance(config_module.transform, transforms.Compose):
feature = apply_transform_as_pil(feature, config_module.transform)
else:
feature = np.stack([config_module.transform(item) for item in feature])
if color_input_flag:
# convert feature array to Tensor of size (1, 3, H, W)
feature_tensor = torch.Tensor(feature).unsqueeze(0)
else:
# convert feature array to Tensor of size (n_samples, 1, H, W)
feature_tensor = torch.Tensor(feature).unsqueeze(1)
return feature_tensor
# =============================================================================
def transform_and_net_forward(feature,
config_file,
config_group,
model_file,
color_input_flag = False):
"""
This function performs the following steps:
1. Imports the config module.
2. Applies the transformation defined in the config file to the input data.
3. Initializes the Network class with optional kwargs.
4. Passes the transformed data through the ``forward()`` method of the
Network class.
Parameters
----------
feature : :py:class:`numpy.ndarray`
ND feature array of the size (N_samples x dim1 x dim2 x ...).
config_file: str
Relative name of the config file defining the network, training data,
and training parameters. The path should be relative to
``config_group``, for example: "autoencoder/netN.py".
config_group : str
Group/package name containing the configuration file. Usually all
configs should be stored in this folder/place.
For example: "bob.pad.face.config.pytorch".
Both ``config_file`` and ``config_group`` are used to access the
configuration module.
model_file : str
A path to the model file to be used for network initialization.
The network structure is defined in the config file.
color_input_flag : bool
If set to ``True``, the input is considered to be a color image of the
size ``(3, H, W)``. The tensor to be passed through the net will be
of the size ``(1, 3, H, W)``.
If set to ``False``, the input is considered to be a set of BW images
of the size ``(n_samples, H, W)``. The tensor to be passed through
the net will be of the size ``(n_samples, 1, H, W)``.
Returns
-------
net_output : [float]
Output of the network per input sample/frame.
"""
# =========================================================================
# Create relative module name given path:
relative_mod_name = '.' + os.path.splitext(config_file)[0].replace(os.path.sep, '.')
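# e.g. "test_data/net1_test_config.py" becomes ".test_data.net1_test_config",
# which is then imported relative to ``config_group``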
# import configuration module:
config_module = importlib.import_module(relative_mod_name, config_group)
# =========================================================================
# transform the input image or feature vector:
feature_tensor = _transform(feature, config_module, color_input_flag)
# =========================================================================
# Initialize the model
local_model = _init_the_network(config_module)
# Load the pre-trained model into the network
if model_file is not None:
model_state = torch.load(model_file, map_location=lambda storage,loc:storage)
# Initialize the state of the model:
local_model.load_state_dict(model_state)
# Model is used for evaluation only:
local_model.train(False)
# =========================================================================
# Pass the transformed feature vector through the network:
output_tensor = local_model.forward(Variable(feature_tensor))
net_output = output_tensor.data.numpy().squeeze()
net_output = net_output.flatten()
return net_output.astype(np.float64)
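# Usage sketch (added for illustration; assumes the bundled test config and the
# bob.learn.pytorch dependency are available; with model_file=None the network
# keeps its randomly initialized weights):
#
#   patch = (np.random.rand(3, 128, 128) * 255).astype("uint8")
#   net_output = transform_and_net_forward(patch,
#                                          config_file = "test_data/net1_test_config.py",
#                                          config_group = "bob.ip.pytorch_extractor",
#                                          model_file = None,
#                                          color_input_flag = True)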
def load_pretrained_model(model_path, url, archive_extension = '.tar.gz'):
"""
Downloads the model from the given ``url`` if the model specified in
``model_path`` is unavailable. The model will be saved to the location
defined by the ``model_path`` string and will keep the same filename.
Arguments
---------
model_path : str
Absolute file name pointing to the model.
url : str
URL to download the model from.
archive_extension : str
Extension of the archived file.
If ``archive_extension`` is ``None``, the model is downloaded as-is,
without extracting an archive.
Default: '.tar.gz'
"""
if model_path is not None and url is not None:
if not os.path.exists(model_path):
print ("The model is not available at " + model_path)
print ("Trying to load it from " + url)
bob.io.base.create_directories_safe(os.path.split(model_path)[0])
if archive_extension is None: # if file is not an archive, download it
download_file(url, model_path)
else: # download and extract the archive
archive_file = os.path.splitext(model_path)[0] + archive_extension
download_and_unzip(url, archive_file)
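# Usage sketch (the path and URL below are placeholders, not real locations):
#
#   load_pretrained_model(model_path = "/path/to/models/conv_ae.pth",
#                         url = "https://example.com/conv_ae.tar.gz",
#                         archive_extension = ".tar.gz")
#
# If "/path/to/models/conv_ae.pth" does not exist, the archive is downloaded as
# "/path/to/models/conv_ae.tar.gz" and unpacked next to it.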