From af6bcd49cd407c90c113779f7959fac790873a2e Mon Sep 17 00:00:00 2001
From: Andre Anjos <andre.dos.anjos@gmail.com>
Date: Mon, 27 Apr 2020 21:05:54 +0200
Subject: [PATCH] [script.predict] Support tempfile URL downloads for weights

---
 bob/ip/binseg/script/predict.py | 12 ++++++++++--
 1 file changed, 10 insertions(+), 2 deletions(-)

diff --git a/bob/ip/binseg/script/predict.py b/bob/ip/binseg/script/predict.py
index bf988ec4..d06c7557 100644
--- a/bob/ip/binseg/script/predict.py
+++ b/bob/ip/binseg/script/predict.py
@@ -2,6 +2,7 @@
 # coding=utf-8
 
 import os
+import tempfile
 
 import click
 import torch
@@ -16,6 +17,8 @@ from bob.extension.scripts.click_helper import (
 from ..engine.predictor import run
 from ..utils.checkpointer import DetectronCheckpointer
 
+from .binseg import download_to_tempfile
+
 import logging
 logger = logging.getLogger(__name__)
 
@@ -114,8 +117,13 @@ def predict(output_folder, model, dataset, batch_size, device, weight,
 
     dataset = dataset if isinstance(dataset, dict) else dict(test=dataset)
 
-    # checkpointer, loads pre-fit model
-    weight_fullpath = os.path.abspath(weight)
+    if weight.startswith("http"):
+        logger.info(f"Temporarily downloading '{weight}'...")
+        f = download_to_tempfile(weight, progress=True)
+        weight_fullpath = os.path.abspath(f.name)
+    else:
+        weight_fullpath = os.path.abspath(weight)
+
     weight_path = os.path.dirname(weight_fullpath)
     weight_name = os.path.basename(weight_fullpath)
     checkpointer = DetectronCheckpointer(model, save_dir=weight_path,
-- 
GitLab