Skip to content
Snippets Groups Projects
Commit af6bcd49 authored by André Anjos's avatar André Anjos :speech_balloon:
Browse files

[script.predict] Support tempfile URL downloads for weights

parent 11cab326
No related branches found
No related tags found
1 merge request!12Streamlining
......@@ -2,6 +2,7 @@
# coding=utf-8
import os
import tempfile
import click
import torch
......@@ -16,6 +17,8 @@ from bob.extension.scripts.click_helper import (
from ..engine.predictor import run
from ..utils.checkpointer import DetectronCheckpointer
from .binseg import download_to_tempfile
import logging
logger = logging.getLogger(__name__)
......@@ -114,8 +117,13 @@ def predict(output_folder, model, dataset, batch_size, device, weight,
dataset = dataset if isinstance(dataset, dict) else dict(test=dataset)
# checkpointer, loads pre-fit model
weight_fullpath = os.path.abspath(weight)
if weight.startswith("http"):
logger.info(f"Temporarily downloading '{weight}'...")
f = download_to_tempfile(weight, progress=True)
weight_fullpath = os.path.abspath(f.name)
else:
weight_fullpath = os.path.abspath(weight)
weight_path = os.path.dirname(weight_fullpath)
weight_name = os.path.basename(weight_fullpath)
checkpointer = DetectronCheckpointer(model, save_dir=weight_path,
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment