diff --git a/bob/ip/pytorch_extractor/test.py b/bob/ip/pytorch_extractor/test.py
index 410b139973494a244c20eccfa6b9d396b02db763..8f2253529dfb9ff150ef1af73502b53d1e513e22 100644
--- a/bob/ip/pytorch_extractor/test.py
+++ b/bob/ip/pytorch_extractor/test.py
@@ -69,8 +69,7 @@ def test_multi_net_patch_extractor():
CONFIG_GROUP = "bob.ip.pytorch_extractor"
# use specific/unique model for each patch. Models pre-trained on CelebA and fine-tuned (3 layers) on BATL:
- MODEL_FILE = [pkg_resources.resource_filename('bob.ip.pytorch_extractor',
- 'test_data/conv_ae_model_pretrain_celeba_tune_batl_full_face.pth')]
+ MODEL_FILE = None
PATCHES_NUM = [0] # patches to be used in the feature extraction process
PATCH_RESHAPE_PARAMETERS = [3, 128, 128] # reshape vectorized patches to this dimensions before passing to the Network
diff --git a/bob/ip/pytorch_extractor/test_data/conv_ae_model_pretrain_celeba_tune_batl_full_face.pth b/bob/ip/pytorch_extractor/test_data/conv_ae_model_pretrain_celeba_tune_batl_full_face.pth
deleted file mode 100644
index 1cb97a261f2901a1632e50d651d1c136c8146da8..0000000000000000000000000000000000000000
Binary files a/bob/ip/pytorch_extractor/test_data/conv_ae_model_pretrain_celeba_tune_batl_full_face.pth and /dev/null differ
diff --git a/bob/ip/pytorch_extractor/utils.py b/bob/ip/pytorch_extractor/utils.py
index 5a6ece540e189d17bc9ddb69ed9e2a06d1f58531..5cdc47c59b424fb796ae9f4eb60dc8f616c211bb 100644
--- a/bob/ip/pytorch_extractor/utils.py
+++ b/bob/ip/pytorch_extractor/utils.py
@@ -286,10 +286,11 @@ def transform_and_net_forward(feature,
local_model = _init_the_network(config_module)
# Load the pre-trained model into the network
- model_state = torch.load(model_file, map_location=lambda storage,loc:storage)
+ if model_file is not None:
+ model_state = torch.load(model_file, map_location=lambda storage,loc:storage)
- # Initialize the state of the model:
- local_model.load_state_dict(model_state)
+ # Initialize the state of the model:
+ local_model.load_state_dict(model_state)
# Model is used for evaluation only:
local_model.train(False)