diff --git a/bob/ip/binseg/script/predict.py b/bob/ip/binseg/script/predict.py index bf988ec4a61f8c62cb73d88d347fabd3e7608c54..d06c7557080fda133d1a3fa1e621e672667a49e3 100644 --- a/bob/ip/binseg/script/predict.py +++ b/bob/ip/binseg/script/predict.py @@ -2,6 +2,7 @@ # coding=utf-8 import os +import tempfile import click import torch @@ -16,6 +17,8 @@ from bob.extension.scripts.click_helper import ( from ..engine.predictor import run from ..utils.checkpointer import DetectronCheckpointer +from .binseg import download_to_tempfile + import logging logger = logging.getLogger(__name__) @@ -114,8 +117,13 @@ def predict(output_folder, model, dataset, batch_size, device, weight, dataset = dataset if isinstance(dataset, dict) else dict(test=dataset) - # checkpointer, loads pre-fit model - weight_fullpath = os.path.abspath(weight) + if weight.startswith("http"): + logger.info(f"Temporarily downloading '{weight}'...") + f = download_to_tempfile(weight, progress=True) + weight_fullpath = os.path.abspath(f.name) + else: + weight_fullpath = os.path.abspath(weight) + weight_path = os.path.dirname(weight_fullpath) weight_name = os.path.basename(weight_fullpath) checkpointer = DetectronCheckpointer(model, save_dir=weight_path,