Commit 381520b9 authored by Olegs NIKISINS

Added simple functionality for loading pre-trained models from URLs

parent 4a357a02
Merge request !3: Added MultiNetPatchClassifier extractor and utils, temp fix in LightCNN9
Pipeline #26462 passed
@@ -15,6 +15,7 @@ import numpy as np
from bob.ip.pytorch_extractor.utils import reshape_flat_patches
from bob.ip.pytorch_extractor.utils import combinations
from bob.ip.pytorch_extractor.utils import transform_and_net_forward
from bob.ip.pytorch_extractor.utils import load_pretrained_model
# =============================================================================
# Main body:
@@ -86,6 +87,26 @@ class MultiNetPatchExtractor(Extractor, object):
of the size ``(n_samples, H, W)``. The tensor to be passed through
the net will be of the size ``(n_samples, 1, H, W)``.
Default: ``False``.
urls : [str]
List of URLs to download the pretrained models from.
If a model is not available at the location given in the
``model_file`` list, an attempt is made to download it from the
corresponding entry in ``urls``. The downloaded model is saved to
the location specified in the ``model_file`` list.
For example, the pretrained model for the autoencoder pre-trained on
RGB faces of the size (3(channels) x 128 x 128) and fine-tuned
on the BW-NIR-D data can be found here:
["https://www.idiap.ch/software/bob/data/bob/bob.ip.pytorch_extractor/master/"
"conv_ae_model_pretrain_celeba_tune_batl_full_face.pth.tar.gz"]
Default: None

archive_extension : str
Extension of the archived files to be downloaded from the above ``urls``.
Default: '.tar.gz'
"""
# =========================================================================
@@ -94,7 +115,9 @@ class MultiNetPatchExtractor(Extractor, object):
                 model_file,
                 patches_num,
                 patch_reshape_parameters = None,
                 color_input_flag = False):
                 color_input_flag = False,
                 urls = None,
                 archive_extension = '.tar.gz'):
        """
        Init method.
        """
@@ -104,7 +127,9 @@ class MultiNetPatchExtractor(Extractor, object):
                           model_file = model_file,
                           patches_num = patches_num,
                           patch_reshape_parameters = patch_reshape_parameters,
                           color_input_flag = color_input_flag)
                           color_input_flag = color_input_flag,
                           urls = urls,
                           archive_extension = archive_extension)

        self.config_file = config_file
        self.config_group = config_group
@@ -112,6 +137,8 @@ class MultiNetPatchExtractor(Extractor, object):
        self.patches_num = patches_num
        self.patch_reshape_parameters = patch_reshape_parameters
        self.color_input_flag = color_input_flag
        self.urls = urls
        self.archive_extension = archive_extension
# =========================================================================
@@ -154,15 +181,26 @@
        features_all_patches = []

        # wrap ``model_file`` and ``urls`` into single-element lists when they
        # are None, so that they remain indexable below:
        if self.model_file is None:
            self.model_file = [self.model_file]

        if self.urls is None:
            self.urls = [self.urls]

        for idx, patch in enumerate(patches_3d):

            # download the model if it is not yet available locally; do nothing otherwise:
            load_pretrained_model(model_path = self.model_file[self.patches_num[idx]],
                                  url = self.urls[self.patches_num[idx]],
                                  archive_extension = self.archive_extension)

            if len(function_kwargs) == 1:  # patches are passed through the same network:

                features = transform_and_net_forward(feature = patch, **function_kwargs[0])

            else:  # patches are passed through different networks:

                features = transform_and_net_forward(feature = patch, **function_kwargs[self.patches_num(idx)])
                features = transform_and_net_forward(feature = patch, **function_kwargs[self.patches_num[idx]])

            # print("The model we use for patch {} is:".format(str(idx)))
            # print(function_kwargs[self.patches_num(idx)]["model_file"])
......
@@ -80,7 +80,8 @@ def test_multi_net_patch_extractor():
        model_file = MODEL_FILE,
        patches_num = PATCHES_NUM,
        patch_reshape_parameters = PATCH_RESHAPE_PARAMETERS,
        color_input_flag = COLOR_INPUT_FLAG)
        color_input_flag = COLOR_INPUT_FLAG,
        urls = None)

    # pass through encoder only, compute latent vector:
    latent_vector = image_extractor(patch_flat)
......
@@ -24,6 +24,10 @@ from torch.autograd import Variable
import itertools as it
import bob.io.base
from bob.extension.download import download_and_unzip
# =============================================================================
def reshape_flat_patches(patches, patch_reshape_parameters = None):
@@ -287,6 +291,7 @@ def transform_and_net_forward(feature,
    # Load the pre-trained model into the network:
    if model_file is not None:

        model_state = torch.load(model_file, map_location=lambda storage, loc: storage)

        # Initialize the state of the model:
@@ -305,3 +310,37 @@ def transform_and_net_forward(feature,
    net_output = net_output.flatten()

    return net_output.astype(np.float)
def load_pretrained_model(model_path, url, archive_extension = '.tar.gz'):
    """
    Downloads the model from the given ``url``, if the model specified by
    ``model_path`` is not yet available. The model is saved to the location
    defined by the ``model_path`` string, keeping the same filename.

    Arguments
    ---------
    model_path : str
        Absolute file name pointing to the model.

    url : str
        URL to download the model from.

    archive_extension : str
        Extension of the archived file. Default: '.tar.gz'
    """

    if model_path is not None and url is not None:

        if not os.path.exists(model_path):

            print("The model is not available at " + model_path)
            print("Trying to download it from " + url)

            bob.io.base.create_directories_safe(os.path.split(model_path)[0])

            archive_file = os.path.join(os.path.split(model_path)[0],
                                        os.path.splitext(os.path.split(model_path)[1])[0] + archive_extension)

            download_and_unzip(url, archive_file)
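For example, a hedged usage sketch of this helper (the local path is hypothetical; the URL is the one from the class docstring above):

from bob.ip.pytorch_extractor.utils import load_pretrained_model

# downloads and unpacks the archive only if the model file is missing:
load_pretrained_model(
    model_path = "/models/conv_ae_model_pretrain_celeba_tune_batl_full_face.pth",
    url = "https://www.idiap.ch/software/bob/data/bob/bob.ip.pytorch_extractor/master/"
          "conv_ae_model_pretrain_celeba_tune_batl_full_face.pth.tar.gz",
    archive_extension = '.tar.gz')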
Maintainer:

@onikisins Can you add an option to just download as well, without unzipping? It's already available as bob.extension.download.download_file.
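A sketch of what the suggested option could look like (not part of this commit; it assumes ``bob.extension.download.download_file(url, out_file)`` downloads a single file, and treats ``archive_extension = None`` as a hypothetical "no unzipping" convention):

import os

import bob.io.base
from bob.extension.download import download_and_unzip, download_file

def load_pretrained_model(model_path, url, archive_extension = '.tar.gz'):

    if model_path is not None and url is not None and not os.path.exists(model_path):

        bob.io.base.create_directories_safe(os.path.split(model_path)[0])

        if archive_extension is None:
            # plain (non-archived) file: just download it to ``model_path``:
            download_file(url, model_path)

        else:
            # archived file: download next to ``model_path``, then unpack:
            archive_file = os.path.join(os.path.split(model_path)[0],
                                        os.path.splitext(os.path.split(model_path)[1])[0] + archive_extension)
            download_and_unzip(url, archive_file)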