Commit a8bdca1b authored by Guillaume HEUSCH

Merge branch 'no_configs' into 'master'

Remove the need for config in extractor and algo

Closes #4

See merge request !5
parents c2290e77 c66d0297
Pipeline #27207 passed
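
In short, the constructors no longer receive a ``config_file`` / ``config_group`` pair pointing at a configuration module; the caller now builds the ``transform`` and the ``network`` instance and passes them in directly. A minimal sketch of the new extractor API, assembled from the signatures in this diff (the patch indices are illustrative):

    from torchvision import transforms
    from bob.learn.pytorch.architectures import ConvAutoencoder
    from bob.ip.pytorch_extractor import MultiNetPatchExtractor

    # the transform and the network are now built by the caller...
    transform = transforms.Compose([transforms.ToTensor(),
                                    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
    network = ConvAutoencoder(return_latent_embedding = True)

    # ...and passed in directly, replacing the former config_file / config_group arguments:
    extractor = MultiNetPatchExtractor(transform = transform,
                                       network = network,
                                       model_file = None,
                                       patches_num = [0],  # illustrative patch index
                                       patch_reshape_parameters = [3, 128, 128],
                                       color_input_flag = True,
                                       urls = None)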
@@ -23,6 +23,8 @@ from bob.pad.base.utils import convert_list_of_frame_cont_to_array
 import bob.io.base
+import torch
+
 # =============================================================================
 # Main body :

@@ -34,36 +36,15 @@ class MLPAlgorithm(Algorithm):
     Attributes
     -----------
-    config_file: str
-        Relative name of the config file.
-        The path should be relative to ``config_group``,
-        for example: "test_data/mlp_algo_test_config.py".
-        This file **must** contain at least the following definitions:
-        Function namely ``transform``, which takes numpy.ndarray as input,
-        and returns a transformed Tensor. The dimensionality of the output
-        tensor must match the format expected by the MLP.
-        A ``Network`` class, defining your network architecture. Note, if your
-        class is named differently, import it as ``Network``, for example:
-        ``from bob.learn.pytorch.architectures import MyNetwork as Network``
-        Optional: ``network_kwargs`` to be used for ``Network`` initialization.
-        For example, if you don't want to use a sigmoid in the output of the
-        MLP, set the kwargs accodingly. Note: in current algorithm the
-        ``forward()`` method of the ``Network`` is used for feature extraction.
-    config_group: str
-        Group/package name containing the configuration file. Usually all
-        configs should be stored in this folder/place.
-        For example: "bob.ip.pytorch_extractor".
-        Both ``config_file`` and ``config_group`` are used to access the
-        configuration module.
+    network : object
+        An instance of an MLP Network to be used for score computation.
+        Note: in the current algorithm the ``forward()`` method of the
+        Network is used for score computation. For example, if you don't
+        want to use a sigmoid in the output of the MLP, initialize the
+        network accordingly.
     model_file : str
-        A paths to the model file to be used for network initialization.
-        The network structure is defined in the config file.
+        A path to the model file to be used for ``network`` initialization.
     url : str
         URL to download the pretrained models from.

@@ -91,16 +72,14 @@ class MLPAlgorithm(Algorithm):
     """

     def __init__(self,
-                 config_file,
-                 config_group,
+                 network,
                  model_file = None,
                  url = None,
                  archive_extension = '.tar.gz',
                  frame_level_scores_flag = True,
                  mean_std_norm_flag = True):

-        super(MLPAlgorithm, self).__init__(config_file = config_file,
-                                           config_group = config_group,
+        super(MLPAlgorithm, self).__init__(network = network,
                                            model_file = model_file,
                                            url = url,
                                            archive_extension = archive_extension,

@@ -109,8 +88,8 @@ class MLPAlgorithm(Algorithm):
                                            performs_projection=True,
                                            requires_projector_training=True)

-        self.config_file = config_file
-        self.config_group = config_group
+        self.transform = lambda x: torch.Tensor(x).unsqueeze(0)
+        self.network = network
         self.model_file = model_file
         self.url = url
         self.archive_extension = archive_extension

@@ -221,7 +200,7 @@ class MLPAlgorithm(Algorithm):
         ----------
         feature : FrameContainer or :py:class:`numpy.ndarray`
             Two types of inputs are accepted.
-            A Frame Container conteining the features of an individual frmaes,
+            A Frame Container containing the features of individual frames,
             see ``bob.bio.video.utils.FrameContainer``.
             Or a ND feature array of the size (n_samples x n_features).

@@ -250,15 +229,11 @@ class MLPAlgorithm(Algorithm):
                                                       self.features_mean,
                                                       self.features_std)

-        # kwargs for the transform_and_net_forward function:
-        function_kwargs = {}
-        function_kwargs["config_file"] = self.config_file
-        function_kwargs["config_group"] = self.config_group
-        function_kwargs["model_file"] = self.model_file
-        function_kwargs["color_input_flag"] = False
-
-        scores = transform_and_net_forward(feature = feature,
-                                           **function_kwargs)
+        scores = transform_and_net_forward(feature = feature,
+                                           transform = self.transform,
+                                           network = self.network,
+                                           model_file = self.model_file,
+                                           color_input_flag = False)

         return scores
...
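
For reference, a hedged usage sketch of the refactored class, mirroring the test further down in this diff; the feature dimensionality (1296) comes from the tests, and the random input is purely illustrative:

    import numpy as np
    from bob.learn.pytorch.architectures import TwoLayerMLP
    from bob.ip.pytorch_extractor import MLPAlgorithm

    # an instantiated network replaces the former config_file / config_group pair:
    network = TwoLayerMLP(in_features = 1296,
                          n_hidden_relu = 10,
                          apply_sigmoid = False)

    algorithm = MLPAlgorithm(network = network,
                             model_file = None,  # or a path to pre-trained weights
                             url = None)

    features = np.random.rand(10, 1296)  # illustrative (n_samples x n_features) array
    scores = algorithm.project(features)  # per-frame scores when frame_level_scores_flag=True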
 #!/usr/bin/env python3
 # -*- coding: utf-8 -*-
 """
-Created on Thu Oct 11 15:32:02 2018
 @author: Olegs Nikisins
 """

@@ -13,7 +11,6 @@ from bob.bio.base.extractor import Extractor
 import numpy as np
 from bob.ip.pytorch_extractor.utils import reshape_flat_patches
-from bob.ip.pytorch_extractor.utils import combinations
 from bob.ip.pytorch_extractor.utils import transform_and_net_forward
 from bob.ip.pytorch_extractor.utils import load_pretrained_model

@@ -37,35 +34,21 @@ class MultiNetPatchExtractor(Extractor, object):
     Attributes
     -----------
-    config_file: str
-        Relative name of the config file.
-        The path should be relative to ``config_group``,
-        for example: "test_data/multi_net_patch_extractor_test_config.py".
-        This file **must** contain at least the following definitions:
-        Function namely ``transform``, which is a Compose transformation of
-        torchvision package, to be applied to the input samples.
-        A ``Network`` class, defining your network architecture. Note, if your
-        class is named differently, import it as ``Network``, for example:
-        ``from bob.learn.pytorch.architectures import MyNetwork as Network``
-        Optional: ``network_kwargs`` to be used for ``Network`` initialization.
-        For example, if you want to use the latent embeddings of the autoencoder
-        class, set the kwargs accodingly. Note: in current extractor the
-        ``forward()`` method of the ``Network`` is used for feature extraction.
-    config_group: str
-        Group/package name containing the configuration file. Usually all
-        configs should be stored in this folder/place.
-        For example: "bob.ip.pytorch_extractor".
-        Both ``config_file`` and ``config_group`` are used to access the
-        configuration module.
+    transform : object
+        Function namely ``transform``, which is a Compose transformation of
+        the torchvision package, to be applied to the input samples.
+    network : object
+        An instance of the Network to be used for feature extraction.
+        Note: in the current extractor the ``forward()`` method of the
+        Network is used for feature extraction. For example, if you want
+        to use the latent embeddings of the autoencoder class, initialize
+        the network accordingly.

     model_file : [str]
-        A list of paths to the model files to be used for network initialization.
-        The network structure is defined in the config file.
+        A list of paths to the model files to be used for ``network``
+        initialization.

     patches_num : [int]
         A list of indices specifying which patches will be selected for

@@ -95,7 +78,7 @@ class MultiNetPatchExtractor(Extractor, object):
         ``urls``. The downloaded models **will be placed to the locations**
         specified in ``model_file`` list.

-        For example, the pretrained model for the autoencoder pre-trained on
+        For example, a model for an autoencoder pre-trained on
         RGB faces of the size (3(channels) x 128 x 128) and fine-tuned
         on the BW-NIR-D data can be found here:
         ["https://www.idiap.ch/software/bob/data/bob/bob.ip.pytorch_extractor/master/"

@@ -110,8 +93,9 @@ class MultiNetPatchExtractor(Extractor, object):
     """

     # =========================================================================
-    def __init__(self, config_file,
-                 config_group,
+    def __init__(self,
+                 transform,
+                 network,
                  model_file,
                  patches_num,
                  patch_reshape_parameters = None,

@@ -122,8 +106,8 @@ class MultiNetPatchExtractor(Extractor, object):
        Init method.
        """

-        super(MultiNetPatchExtractor, self).__init__(config_file = config_file,
-                                                     config_group = config_group,
+        super(MultiNetPatchExtractor, self).__init__(transform = transform,
+                                                     network = network,
                                                      model_file = model_file,
                                                      patches_num = patches_num,
                                                      patch_reshape_parameters = patch_reshape_parameters,

@@ -131,8 +115,8 @@ class MultiNetPatchExtractor(Extractor, object):
                                                      urls = urls,
                                                      archive_extension = archive_extension)

-        self.config_file = config_file
-        self.config_group = config_group
+        self.transform = transform
+        self.network = network
         self.model_file = model_file
         self.patches_num = patches_num
         self.patch_reshape_parameters = patch_reshape_parameters

@@ -151,8 +135,6 @@ class MultiNetPatchExtractor(Extractor, object):
         patches : 2D :py:class:`numpy.ndarray`
             An array containing flattened patches. The dimensions are:
             ``num_patches x len_of_flat_patch``
-            Optional: the last column of the array can also be a binary mask.
-            This case is also handled.

         Returns
         --------

@@ -160,19 +142,6 @@ class MultiNetPatchExtractor(Extractor, object):
             Feature vector.
         """

-        # kwargs for the transform_and_net_forward function:
-        function_kwargs = {}
-        function_kwargs["config_file"] = self.config_file
-        function_kwargs["config_group"] = self.config_group
-        function_kwargs["model_file"] = self.model_file
-        function_kwargs["color_input_flag"] = self.color_input_flag
-
-        # convert all values in the dictionary to the list if not a list already:
-        function_kwargs = {k:[v] if not isinstance(v, list) else v for (k,v) in function_kwargs.items()}
-
-        # compute all possible key-value combinations:
-        function_kwargs = combinations(function_kwargs) # function_kwargs is now a list with kwargs
-
         # select patches specified by indices:
         patches_selected = [patches[idx] for idx in self.patches_num] # convert to list to make it iterable

@@ -181,11 +150,11 @@ class MultiNetPatchExtractor(Extractor, object):
         features_all_patches = []

-        # make sure the model_file and urls are not None:
+        # make sure the model_file and urls are not None, but lists:
         if self.model_file is None:
-            self.model_file = [self.model_file]
+            self.model_file = [None] * len(self.patches_num)
         if self.urls is None:
-            self.urls = [self.urls]
+            self.urls = [None] * len(self.patches_num)

         for idx, patch in enumerate(patches_3d):

@@ -194,16 +163,24 @@ class MultiNetPatchExtractor(Extractor, object):
                                   url = self.urls[self.patches_num[idx]],
                                   archive_extension = self.archive_extension)

-            if len(function_kwargs) == 1: # patches are passed through the same network:
-                features = transform_and_net_forward(feature = patch, **function_kwargs[0])
+            if len(self.model_file) == 1: # patches are passed through the same network:
+                features = transform_and_net_forward(feature = patch,
+                                                     transform = self.transform,
+                                                     network = self.network,
+                                                     model_file = self.model_file[0],
+                                                     color_input_flag = self.color_input_flag)
             else: # patches are passed through different networks:
-                features = transform_and_net_forward(feature = patch, **function_kwargs[self.patches_num[idx]])
+                features = transform_and_net_forward(feature = patch,
+                                                     transform = self.transform,
+                                                     network = self.network,
+                                                     model_file = self.model_file[idx],
+                                                     color_input_flag = self.color_input_flag)

-            # print ("The model we use for patch {} is:".format(str(idx)))
-            # print (function_kwargs[self.patches_num(idx)]["model_file"])
+            # print ("The model we use for patch {} is:".format(str(idx)))
+            # print (self.model_file[idx])

             features_all_patches.append(features)

@@ -211,4 +188,3 @@ class MultiNetPatchExtractor(Extractor, object):
         return features
...
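
Note the new branching on ``len(self.model_file)`` above: a single-entry list reuses one set of weights for every patch, while a list with one path per patch loads patch-specific weights into the same ``network`` instance. A sketch of the two configurations, assuming hypothetical weight files on disk:

    from torchvision import transforms
    from bob.learn.pytorch.architectures import ConvAutoencoder
    from bob.ip.pytorch_extractor import MultiNetPatchExtractor

    transform = transforms.Compose([transforms.ToTensor(),
                                    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
    network = ConvAutoencoder(return_latent_embedding = True)

    # single entry: every patch goes through the same weights
    shared = MultiNetPatchExtractor(transform = transform, network = network,
                                    model_file = ["/path/to/model.pth"],      # hypothetical path
                                    patches_num = [0],                        # illustrative index
                                    patch_reshape_parameters = [3, 128, 128],
                                    color_input_flag = True, urls = None)

    # one entry per patch: patch-specific weights are loaded into the same network
    per_patch = MultiNetPatchExtractor(transform = transform, network = network,
                                       model_file = ["/path/to/model_0.pth",  # hypothetical paths
                                                     "/path/to/model_1.pth",
                                                     "/path/to/model_2.pth"],
                                       patches_num = [0, 1, 2],               # illustrative indices
                                       patch_reshape_parameters = [3, 128, 128],
                                       color_input_flag = True, urls = None)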
@@ -49,6 +49,8 @@ def test_multi_net_patch_extractor():
     """

     from bob.ip.pytorch_extractor import MultiNetPatchExtractor
+    from bob.learn.pytorch.architectures import ConvAutoencoder
+    from torchvision import transforms

     # =========================================================================
     # prepare the test data:

@@ -62,9 +64,13 @@ def test_multi_net_patch_extractor():
     # =========================================================================
     # test the extractor:

-    CONFIG_FILE = "test_data/multi_net_patch_extractor_test_config.py" # config containing an instance of Composed Transform and a Network class to be used in feature extractor
-    CONFIG_GROUP = "bob.ip.pytorch_extractor"
-    # use specific/unique model for each patch. Models pre-trained on CelebA and fine-tuned (3 layers) on BATL:
+    # transform to be applied to input patches:
+    TRANSFORM = transforms.Compose([transforms.ToTensor(),
+                                    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
+                                   ])
+
+    # use latent embeddings in the feature extractor:
+    NETWORK = ConvAutoencoder(return_latent_embedding = True)

     MODEL_FILE = None

@@ -72,16 +78,16 @@ def test_multi_net_patch_extractor():
     PATCH_RESHAPE_PARAMETERS = [3, 128, 128] # reshape vectorized patches to these dimensions before passing to the Network
     COLOR_INPUT_FLAG = True

-    image_extractor = MultiNetPatchExtractor(config_file = CONFIG_FILE,
-                                             config_group = CONFIG_GROUP,
-                                             model_file = MODEL_FILE,
-                                             patches_num = PATCHES_NUM,
-                                             patch_reshape_parameters = PATCH_RESHAPE_PARAMETERS,
-                                             color_input_flag = COLOR_INPUT_FLAG,
-                                             urls = None)
+    extractor = MultiNetPatchExtractor(transform = TRANSFORM,
+                                       network = NETWORK,
+                                       model_file = MODEL_FILE,
+                                       patches_num = PATCHES_NUM,
+                                       patch_reshape_parameters = PATCH_RESHAPE_PARAMETERS,
+                                       color_input_flag = COLOR_INPUT_FLAG,
+                                       urls = None)

     # pass through encoder only, compute latent vector:
-    latent_vector = image_extractor(patch_flat)
+    latent_vector = extractor(patch_flat)

     # test:
     assert latent_vector.shape == (1296,)

@@ -101,6 +107,8 @@ def test_mlp_algorithm():
     """

     from bob.ip.pytorch_extractor import MLPAlgorithm
+    from bob.learn.pytorch.architectures import TwoLayerMLP

     # =========================================================================
     # prepare the test data / feature vector:

@@ -110,16 +118,19 @@ def test_mlp_algorithm():
     # =========================================================================
     # test the extractor:

-    CONFIG_FILE = "test_data/mlp_algo_test_config.py" # config containing an instance of Composed Transform and a Network class to be used in feature extractor
-    CONFIG_GROUP = "bob.ip.pytorch_extractor"
+    # don't use the sigmoid in the output:
+    NETWORK = TwoLayerMLP(in_features = 1296,
+                          n_hidden_relu = 10,
+                          apply_sigmoid = False)

     MODEL_FILE = None

-    algorithm = MLPAlgorithm(config_file = CONFIG_FILE,
-                             config_group = CONFIG_GROUP,
+    algorithm = MLPAlgorithm(network = NETWORK,
                              model_file = MODEL_FILE,
                              url = None,
                              archive_extension = '.tar.gz',
-                             frame_level_scores_flag = True)
+                             frame_level_scores_flag = True,
+                             mean_std_norm_flag = True)

     # project the features, compute the scores:
     score = algorithm.project(features)
...
deleted file: test_data/mlp_algo_test_config.py (in "bob.ip.pytorch_extractor")
-#!/usr/bin/env python2
-# -*- coding: utf-8 -*-
-"""
-@author: Olegs Nikisins
-"""
-#==============================================================================
-# Import here:
-
-import torch
-
-#==============================================================================
-# Define parameters here:
-
-"""
-Transformations to be applied to the input 1D numpy arrays (feature vectors).
-Only conversion to Tensor and unsqueezing is needed to match the input of the
-TwoLayerMLP network.
-"""
-
-def transform(x):
-    """
-    Convert the input to a Tensor and unsqueeze it to match the input of the
-    TwoLayerMLP network.
-
-    Arguments
-    ---------
-    x : numpy array
-        1D numpy array / feature vector.
-
-    Returns
-    -------
-    x_transform : Tensor
-        Torch tensor, transformed ``x`` to be used as MLP input.
-    """
-
-    return torch.Tensor(x).unsqueeze(0)
-
-"""
-Define the network to be trained as a class, named ``Network``.
-Note: Do not change the name of the below class, always import as ``Network``.
-"""
-
-from bob.learn.pytorch.architectures import TwoLayerMLP as Network
-
-"""
-kwargs to be used for ``Network`` initialization. The name must be ``network_kwargs``.
-"""
-
-network_kwargs = {}
-network_kwargs['in_features'] = 1296
-network_kwargs['n_hidden_relu'] = 10
-network_kwargs['apply_sigmoid'] = False # don't use sigmoid to make the scores more even
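
The ``transform`` this deleted config provided is now hard-wired in ``MLPAlgorithm.__init__`` (see the first file in this diff), and the ``Network`` class with its ``network_kwargs`` becomes an instance built directly by the caller. The equivalent two lines are:

    import torch
    from bob.learn.pytorch.architectures import TwoLayerMLP

    transform = lambda x: torch.Tensor(x).unsqueeze(0)  # set internally by MLPAlgorithm.__init__
    network = TwoLayerMLP(in_features = 1296, n_hidden_relu = 10, apply_sigmoid = False)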
deleted file: test_data/multi_net_patch_extractor_test_config.py (in "bob.ip.pytorch_extractor")
-#!/usr/bin/env python2
-# -*- coding: utf-8 -*-
-"""
-@author: Olegs Nikisins
-"""
-#==============================================================================
-# Import here:
-
-from torchvision import transforms
-
-#==============================================================================
-# Define parameters here:
-
-"""
-Transformations to be applied sequentially to the input PIL image.
-Note: the variable name ``transform`` must be the same in all configuration files.
-"""
-
-transform = transforms.Compose([transforms.ToTensor(),
-                                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
-                               ])
-
-"""
-Define the network to be trained as a class, named ``Network``.
-Note: Do not change the name of the below class, always import as ``Network``.
-"""
-
-from bob.learn.pytorch.architectures import ConvAutoencoder as Network
-
-"""
-kwargs to be used for ``Network`` initialization. The name must be ``network_kwargs``.
-"""
-
-network_kwargs = {}
-network_kwargs['return_latent_embedding'] = True
 #!/usr/bin/env python3
 # -*- coding: utf-8 -*-
 """
-Created on Fri Jan 25 14:29:55 2019
 @author: Olegs Nikisins
 """

@@ -10,8 +8,6 @@
 # Import what is needed here:
 import os
-import importlib
 from torchvision import transforms
 import numpy as np

@@ -143,52 +139,25 @@ def apply_transform_as_pil(img_array, transform):
     return pil_img

-def _init_the_network(config_module):
-    """
-    Initialize the network, given imported configuration module.
-
-    Parameters
-    ----------
-    config_module : object
-        Module containing Network configuration parameters.
-
-    Returns
-    -------
-    model : object
-        An instance of the Network class.
-    """
-
-    if "network_kwargs" in dir(config_module):
-        network_kwargs = config_module.network_kwargs
-        model = config_module.Network(**network_kwargs)
-    else:
-        model = config_module.Network()
-
-    return model
-
-def _transform(feature, config_module, color_input_flag):
+# =============================================================================
+def _transform(feature, color_input_flag, transform = None):
     """
-    Apply transformation defined in the configuration module to the input data.
+    Apply a transformation to the input data.

     Currently two types of the transformation are supported:

     First, Compose transformation of the torchvision package.

     Second, custom transformation functions, to be applied to each feature
-    vector. For example, mean-std normalization.
-
-    In both cases define transform() method in the ``config_module``.
+    vector. For example, mean-std normalization, or a simple conversion of
+    the input numpy vector to a PyTorch tensor.

     Parameters
     ----------
     feature : :py:class:`numpy.ndarray`
         ND feature array of the size (N_samples x dim1 x dim2 x ...).

-    config_module : object
-        Module containing Network configuration parameters.
-
     color_input_flag : bool
         If set to ``True``, the input is considered to be a color image of the
         size ``(3, H, W)``. The tensor to be passed through the net will be

@@ -196,17 +165,21 @@ def _transform(feature, config_module, color_input_flag):
         If set to ``False``, the input is considered to be a set of BW images
         of the size ``(n_samples, H, W)``. The tensor to be passed through
         the net will be of the size ``(n_samples, 1, H, W)``.

+    transform : object
+        Function namely ``transform`` to be applied to the input samples.
+        Default: None.
     """

-    if "transform" in dir(config_module):
-        if isinstance(config_module.transform, transforms.Compose):
-            feature = apply_transform_as_pil(feature, config_module.transform)
+    if transform is not None:
+        if isinstance(transform, transforms.Compose):
+            feature = apply_transform_as_pil(feature, transform)
         else:
-            feature = np.stack([config_module.transform(item) for item in feature])
+            feature = np.stack([transform(item) for item in feature])

     if color_input_flag:
         # convert feature array to Tensor of size (1, 3, H, W)
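
The refactored ``_transform`` keeps both transform flavours, now dispatched on the ``transform`` argument instead of a config module: a torchvision ``Compose`` is applied via PIL, while any other callable is applied per sample and stacked. A small sketch of the non-``Compose`` branch, with illustrative shapes:

    import numpy as np
    import torch

    def vector_transform(x):
        # per-sample callable (the non-Compose flavour), e.g. the MLP case
        return torch.Tensor(x).unsqueeze(0)

    feature = np.random.rand(5, 1296)  # illustrative (n_samples x n_features) input

    # what _transform does when ``transform`` is not a transforms.Compose:
    stacked = np.stack([vector_transform(item) for item in feature])
    print(stacked.shape)  # (5, 1, 1296)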
@@ -221,20 +194,16 @@
 # =============================================================================
 def transform_and_net_forward(feature,
-                              config_file,
-                              config_group,
+                              transform,
+                              network,
                               model_file,
                               color_input_flag = False):
     """
     This function performs the following steps:

-    1. Import config module.
-
-    2. Applies transformation defined in the config file to the input data.
-
-    3. Initializes the Network class with optional kwargs.
-
-    4. Pass the transformed data through ``forward()`` method of the Network
-       class.
+    1. Applies the transformation defined by the ``transform`` argument.
+
+    2. Passes the transformed data through the ``forward()`` method of the
+       Network class.

     Parameters

@@ -243,17 +212,15 @@ def transform_and_net_forward(feature,
     feature : :py:class:`numpy.ndarray`
         ND feature array of the size (N_samples x dim1 x dim2 x ...).

-    config_file: str
-        Relative name of the config file defining the network, training data,
-        and training parameters. The path should be relative to
-        ``config_group``, for example: "autoencoder/netN.py".
-
-    config_group : str
-        Group/package name containing the configuration file. Usually all
-        configs should be stored in this folder/place.
-        For example: "bob.pad.face.config.pytorch".
-        Both ``config_file`` and ``config_group`` are used to access the
-        configuration module.
+    transform : object
+        Function namely ``transform`` to be applied to the input samples.
+
+    network : object
+        An instance of the Network to be used for feature extraction.
+        Note: in the current extractor the ``forward()`` method of the
+        Network is used for feature extraction. For example, if you want
+        to use the latent embeddings of the autoencoder class, initialize
+        the network accordingly.

     model_file : str
         A path to the model file to be used for network initialization.

@@ -274,38 +241,25 @@ def transform_and_net_forward(feature,
         Output of the network per input sample/frame.
     """

-    # =========================================================================
-    # Create relative module name given path:
-    relative_mod_name = '.' + os.path.splitext(config_file)[0].replace(os.path.sep, '.')
-
-    # import configuration module:
-    config_module = importlib.import_module(relative_mod_name, config_group)
-
     # =========================================================================
     # transform the input image or feature vector:
-    feature_tensor = _transform(feature, config_module, color_input_flag)
+    feature_tensor = _transform(feature, color_input_flag, transform)

     # =========================================================================
-    # Initialize the model
-    local_model = _init_the_network(config_module)
-
     # Load the pre-trained model into the network
     if model_file is not None:

         model_state = torch.load(model_file, map_location=lambda storage,loc:storage)

         # Initialize the state of the model:
-        local_model.load_state_dict(model_state)
+        network.load_state_dict(model_state)

         # Model is used for evaluation only:
-        local_model.train(False)
+        network.train(False)

     # =========================================================================
     # Pass the transformed feature vector through the network:
-    output_tensor = local_model.forward(Variable(feature_tensor))
+    output_tensor = network.forward(Variable(feature_tensor))

     net_output = output_tensor.data.numpy().squeeze()

@@ -314,6 +268,7 @@ def transform_and_net_forward(feature,
     return net_output.astype(np.float)
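
A hedged sketch of calling the refactored helper directly, mirroring how ``MLPAlgorithm.project`` uses it; the network is untrained and the input random, for illustration only:

    import numpy as np
    import torch
    from bob.learn.pytorch.architectures import TwoLayerMLP
    from bob.ip.pytorch_extractor.utils import transform_and_net_forward

    network = TwoLayerMLP(in_features = 1296, n_hidden_relu = 10, apply_sigmoid = False)

    feature = np.random.rand(3, 1296)  # illustrative (n_samples x n_features) input

    net_output = transform_and_net_forward(feature = feature,
                                           transform = lambda x: torch.Tensor(x).unsqueeze(0),
                                           network = network,  # used as-is; no config module is imported
                                           model_file = None,  # skip loading pre-trained weights
                                           color_input_flag = False)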
+# =============================================================================
 def load_pretrained_model(model_path, url, archive_extension = '.tar.gz'):
     """
     Loads the model from the given ``url``, if the model specified in the

@@ -322,6 +277,7 @@ def load_pretrained_model(model_path, url, archive_extension = '.tar.gz'):
     Arguments
     ---------
     model_path : str
         Absolute file name pointing to the model.
...
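
Finally, a sketch of the (otherwise unchanged) download helper as it is invoked inside the extractor above; the local path and URL here are hypothetical placeholders:

    from bob.ip.pytorch_extractor.utils import load_pretrained_model

    # per the docstring, fetches and unpacks the archive behind ``url`` when the
    # model file is not already present at ``model_path``:
    load_pretrained_model(model_path = "/tmp/models/model_0.pth",             # hypothetical path
                          url = "https://example.com/models/model_0.tar.gz",  # hypothetical URL
                          archive_extension = '.tar.gz')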