diff --git a/bob/ip/pytorch_extractor/MultiNetPatchExtractor.py b/bob/ip/pytorch_extractor/MultiNetPatchExtractor.py
index 707ad8c919f76659771d11fa3d5c5f483eff12c2..e72516badc9cb6096aad5aa5c8b8c0bcbb0fe382 100644
--- a/bob/ip/pytorch_extractor/MultiNetPatchExtractor.py
+++ b/bob/ip/pytorch_extractor/MultiNetPatchExtractor.py
@@ -41,6 +41,10 @@ class MultiNetPatchExtractor(Extractor, object):
         The path should be relative to ``config_group``,
         for example: "autoencoder/net1_batl_3_layers_partial.py".
 
+
+        ADD THE DOC on what should be in the config!!!!!!!!!!!!!!!!!!!!!!
+
+
     config_group: str
         Group/package name containing the configuration file. Usually all
         configs should be stored in this folder/place.
@@ -48,21 +52,10 @@ class MultiNetPatchExtractor(Extractor, object):
         Both ``config_file`` and ``config_group`` are used to access the
         configuration module.
 
-    model_file : str
-        A path to the model file to be used for network initialization.
+    model_file : [str]
+        A list of paths to the model files to be used for network initialization.
         The network structure is defined in the config file.
 
-    function_kwargs : dict
-
-        UPDATE THIS!!!!
-
-
-        Key-word arguments for the function defined by ``function_name``.
-        Note, that you can also specify one of the values in the dictionary
-        as a list containing multiple elements. Then, ``function_kwargs`` will
-        be different, for each patch you apply function, defined by
-        ``function_name``, to. See the ``__call__`` for more details.
-
     patches_num : [int]
         A list of inices specifying which patches will be selected for
         processing/feature vector extraction.
@@ -74,15 +67,24 @@ class MultiNetPatchExtractor(Extractor, object):
         size (256,) will be reshaped to (4,8,8) dimensions. Only 2D and 3D
         patches are supported.
         Default: None.
+
+    color_input_flag : bool
+        If set to ``True``, the input is considered to be a color image of the
+        size ``(3, H, W)``. The tensor to be passed through the net will be
+        of the size ``(1, 3, H, W)``.
+        If set to ``False``, the input is considered to be a set of BW images
+        of the size ``(n_samples, H, W)``. The tensor to be passed through
+        the net will be of the size ``(n_samples, 1, H, W)``.
+        Default: ``False``.
     """
 
     # =========================================================================
     def __init__(self, config_file,
                  config_group,
                  model_file,
-                 function_kwargs,
                  patches_num,
-                 patch_reshape_parameters = None):
+                 patch_reshape_parameters = None,
+                 color_input_flag = False):
         """
         Init method.
         """
@@ -90,16 +92,16 @@ class MultiNetPatchExtractor(Extractor, object):
         super(MultiNetPatchExtractor, self).__init__(config_file = config_file,
                                                      config_group = config_group,
                                                      model_file = model_file,
-                                                     function_kwargs = function_kwargs,
                                                      patches_num = patches_num,
-                                                     patch_reshape_parameters = patch_reshape_parameters)
+                                                     patch_reshape_parameters = patch_reshape_parameters,
+                                                     color_input_flag = color_input_flag)
 
         self.config_file = config_file
         self.config_group = config_group
         self.model_file = model_file
-        self.function_kwargs = function_kwargs
         self.patches_num = patches_num
         self.patch_reshape_parameters = patch_reshape_parameters
+        self.color_input_flag = color_input_flag
 
 
     # =========================================================================
@@ -121,8 +123,15 @@ class MultiNetPatchExtractor(Extractor, object):
             Feature vector.
         """
 
+        # kwargs for the transform_and_net_forward function:
+        function_kwargs = {}
+        function_kwargs["config_file"] = self.config_file
+        function_kwargs["config_group"] = self.config_group
+        function_kwargs["model_file"] = self.model_file
+        function_kwargs["color_input_flag"] = self.color_input_flag
+
         # convert all values in the dictionary to the list if not a list already:
-        function_kwargs = {k:[v] if not isinstance(v, list) else v for (k,v) in self.function_kwargs.items()}
+        function_kwargs = {k:[v] if not isinstance(v, list) else v for (k,v) in function_kwargs.items()}
 
         # compute all possible key-value combinations:
         function_kwargs = combinations(function_kwargs) # function_kwargs is now a list with kwargs
diff --git a/bob/ip/pytorch_extractor/test.py b/bob/ip/pytorch_extractor/test.py
index 00152049f838b5bc68923a824b735f9ef9051ae6..410b139973494a244c20eccfa6b9d396b02db763 100644
--- a/bob/ip/pytorch_extractor/test.py
+++ b/bob/ip/pytorch_extractor/test.py
@@ -72,23 +72,16 @@ def test_multi_net_patch_extractor():
     MODEL_FILE = [pkg_resources.resource_filename('bob.ip.pytorch_extractor',
                                                   'test_data/conv_ae_model_pretrain_celeba_tune_batl_full_face.pth')]
 
-    # kwargs for the transform_and_net_forward function:
-    FUNCTION_KWARGS = {}
-    FUNCTION_KWARGS["config_file"] = CONFIG_FILE
-    FUNCTION_KWARGS["config_group"] = CONFIG_GROUP
-    FUNCTION_KWARGS["model_file"] = MODEL_FILE
-    FUNCTION_KWARGS["color_input_flag"] = True
-
     PATCHES_NUM = [0] # patches to be used in the feature extraction process
-
     PATCH_RESHAPE_PARAMETERS = [3, 128, 128]  # reshape vectorized patches to this dimensions before passing to the Network
+    COLOR_INPUT_FLAG = True
 
     image_extractor = MultiNetPatchExtractor(config_file = CONFIG_FILE,
                                              config_group = CONFIG_GROUP,
                                              model_file = MODEL_FILE,
-                                             function_kwargs = FUNCTION_KWARGS,
                                              patches_num = PATCHES_NUM,
-                                             patch_reshape_parameters = PATCH_RESHAPE_PARAMETERS)
+                                             patch_reshape_parameters = PATCH_RESHAPE_PARAMETERS,
+                                             color_input_flag = COLOR_INPUT_FLAG)
 
     # pass through encoder only, compute latent vector:
     latent_vector = image_extractor(patch_flat)