diff --git a/bob/ip/pytorch_extractor/MultiNetPatchExtractor.py b/bob/ip/pytorch_extractor/MultiNetPatchExtractor.py
index 707ad8c919f76659771d11fa3d5c5f483eff12c2..e72516badc9cb6096aad5aa5c8b8c0bcbb0fe382 100644
--- a/bob/ip/pytorch_extractor/MultiNetPatchExtractor.py
+++ b/bob/ip/pytorch_extractor/MultiNetPatchExtractor.py
@@ -41,6 +41,10 @@ class MultiNetPatchExtractor(Extractor, object):
The path should be relative to ``config_group``,
for example: "autoencoder/net1_batl_3_layers_partial.py".
+
+        The config file is expected to define the network architecture and
+        everything else needed by the ``transform_and_net_forward`` function
+        to run a forward pass of the network on the input patches.
config_group: str
Group/package name containing the configuration file. Usually all
configs should be stored in this folder/place.
@@ -48,21 +52,10 @@ class MultiNetPatchExtractor(Extractor, object):
Both ``config_file`` and ``config_group`` are used to access the
configuration module.
- model_file : str
- A path to the model file to be used for network initialization.
+ model_file : [str]
+ A list of paths to the model files to be used for network initialization.
The network structure is defined in the config file.
- function_kwargs : dict
-
- UPDATE THIS!!!!
-
-
- Key-word arguments for the function defined by ``function_name``.
- Note, that you can also specify one of the values in the dictionary
- as a list containing multiple elements. Then, ``function_kwargs`` will
- be different, for each patch you apply function, defined by
- ``function_name``, to. See the ``__call__`` for more details.
-
patches_num : [int]
A list of inices specifying which patches will be selected for
processing/feature vector extraction.
@@ -74,15 +67,24 @@ class MultiNetPatchExtractor(Extractor, object):
size (256,) will be reshaped to (4,8,8) dimensions. Only 2D and 3D
patches are supported.
Default: None.
+
+ color_input_flag : bool
+ If set to ``True``, the input is considered to be a color image of the
+ size ``(3, H, W)``. The tensor to be passed through the net will be
+ of the size ``(1, 3, H, W)``.
+ If set to ``False``, the input is considered to be a set of BW images
+ of the size ``(n_samples, H, W)``. The tensor to be passed through
+ the net will be of the size ``(n_samples, 1, H, W)``.
+ Default: ``False``.
"""
# =========================================================================
def __init__(self, config_file,
config_group,
model_file,
- function_kwargs,
patches_num,
- patch_reshape_parameters = None):
+ patch_reshape_parameters = None,
+ color_input_flag = False):
"""
Init method.
"""
@@ -90,16 +92,16 @@ class MultiNetPatchExtractor(Extractor, object):
super(MultiNetPatchExtractor, self).__init__(config_file = config_file,
config_group = config_group,
model_file = model_file,
- function_kwargs = function_kwargs,
patches_num = patches_num,
- patch_reshape_parameters = patch_reshape_parameters)
+ patch_reshape_parameters = patch_reshape_parameters,
+ color_input_flag = color_input_flag)
self.config_file = config_file
self.config_group = config_group
self.model_file = model_file
- self.function_kwargs = function_kwargs
self.patches_num = patches_num
self.patch_reshape_parameters = patch_reshape_parameters
+ self.color_input_flag = color_input_flag
# =========================================================================
@@ -121,8 +123,15 @@ class MultiNetPatchExtractor(Extractor, object):
Feature vector.
"""
+ # kwargs for the transform_and_net_forward function:
+ function_kwargs = {}
+ function_kwargs["config_file"] = self.config_file
+ function_kwargs["config_group"] = self.config_group
+ function_kwargs["model_file"] = self.model_file
+ function_kwargs["color_input_flag"] = self.color_input_flag
+
# convert all values in the dictionary to the list if not a list already:
- function_kwargs = {k:[v] if not isinstance(v, list) else v for (k,v) in self.function_kwargs.items()}
+ function_kwargs = {k:[v] if not isinstance(v, list) else v for (k,v) in function_kwargs.items()}
# compute all possible key-value combinations:
function_kwargs = combinations(function_kwargs) # function_kwargs is now a list with kwargs
diff --git a/bob/ip/pytorch_extractor/test.py b/bob/ip/pytorch_extractor/test.py
index 00152049f838b5bc68923a824b735f9ef9051ae6..410b139973494a244c20eccfa6b9d396b02db763 100644
--- a/bob/ip/pytorch_extractor/test.py
+++ b/bob/ip/pytorch_extractor/test.py
@@ -72,23 +72,16 @@ def test_multi_net_patch_extractor():
MODEL_FILE = [pkg_resources.resource_filename('bob.ip.pytorch_extractor',
'test_data/conv_ae_model_pretrain_celeba_tune_batl_full_face.pth')]
- # kwargs for the transform_and_net_forward function:
- FUNCTION_KWARGS = {}
- FUNCTION_KWARGS["config_file"] = CONFIG_FILE
- FUNCTION_KWARGS["config_group"] = CONFIG_GROUP
- FUNCTION_KWARGS["model_file"] = MODEL_FILE
- FUNCTION_KWARGS["color_input_flag"] = True
-
PATCHES_NUM = [0] # patches to be used in the feature extraction process
-
PATCH_RESHAPE_PARAMETERS = [3, 128, 128] # reshape vectorized patches to this dimensions before passing to the Network
+ COLOR_INPUT_FLAG = True
image_extractor = MultiNetPatchExtractor(config_file = CONFIG_FILE,
config_group = CONFIG_GROUP,
model_file = MODEL_FILE,
- function_kwargs = FUNCTION_KWARGS,
patches_num = PATCHES_NUM,
- patch_reshape_parameters = PATCH_RESHAPE_PARAMETERS)
+ patch_reshape_parameters = PATCH_RESHAPE_PARAMETERS,
+ color_input_flag = COLOR_INPUT_FLAG)
# pass through encoder only, compute latent vector:
latent_vector = image_extractor(patch_flat)