-
- Downloads
Added a simple functionality allowing to load pre-trained models from urls
parent
4a357a02
No related branches found
No related tags found
Showing
- bob/ip/pytorch_extractor/MultiNetPatchExtractor.py 41 additions, 3 deletionsbob/ip/pytorch_extractor/MultiNetPatchExtractor.py
- bob/ip/pytorch_extractor/test.py 2 additions, 1 deletionbob/ip/pytorch_extractor/test.py
- bob/ip/pytorch_extractor/utils.py 39 additions, 0 deletionsbob/ip/pytorch_extractor/utils.py
... | ... | @@ -24,6 +24,10 @@ from torch.autograd import Variable |
import itertools as it | ||
import bob.io.base | ||
from bob.extension.download import download_and_unzip | ||
# ============================================================================= | ||
def reshape_flat_patches(patches, patch_reshape_parameters = None): | ||
... | ... | @@ -287,6 +291,7 @@ def transform_and_net_forward(feature, |
# Load the pre-trained model into the network | ||
if model_file is not None: | ||
model_state = torch.load(model_file, map_location=lambda storage,loc:storage) | ||
# Initialize the state of the model: | ||
... | ... | @@ -305,3 +310,37 @@ def transform_and_net_forward(feature, |
net_output = net_output.flatten() | ||
return net_output.astype(np.float) | ||
def load_pretrained_model(model_path, url, archive_extension = '.tar.gz'): | ||
""" | ||
Loads the model from the given ``url``, if the model specified in the | ||
``model_path`` is unavailable. The model will be saved to the location | ||
defined in the ``model_path`` string, and will have the same filename. | ||
Arguments | ||
--------- | ||
model_path : str | ||
Absolute file name pointing to the model. | ||
url : str | ||
URL to download the model from. | ||
archive_extension : str | ||
Extension of the archived file. Default: '.tar.gz' | ||
""" | ||
if model_path is not None and url is not None: | ||
if not os.path.exists(model_path): | ||
print ("The model is not available at " + model_path) | ||
print ("Trying to load it from " + url) | ||
bob.io.base.create_directories_safe(os.path.split(model_path)[0]) | ||
archive_file = os.path.join(os.path.split(model_path)[0], | ||
os.path.splitext(os.path.split(model_path)[1])[0] + archive_extension) | ||
download_and_unzip(url, archive_file) | ||
|