Skip to content
Snippets Groups Projects
Commit af6bcd49 authored by André Anjos's avatar André Anjos :speech_balloon:
Browse files

[script.predict] Support tempfile URL downloads for weights

parent 11cab326
No related branches found
No related tags found
1 merge request!12Streamlining
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
# coding=utf-8 # coding=utf-8
import os import os
import tempfile
import click import click
import torch import torch
...@@ -16,6 +17,8 @@ from bob.extension.scripts.click_helper import ( ...@@ -16,6 +17,8 @@ from bob.extension.scripts.click_helper import (
from ..engine.predictor import run from ..engine.predictor import run
from ..utils.checkpointer import DetectronCheckpointer from ..utils.checkpointer import DetectronCheckpointer
from .binseg import download_to_tempfile
import logging import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -114,8 +117,13 @@ def predict(output_folder, model, dataset, batch_size, device, weight, ...@@ -114,8 +117,13 @@ def predict(output_folder, model, dataset, batch_size, device, weight,
dataset = dataset if isinstance(dataset, dict) else dict(test=dataset) dataset = dataset if isinstance(dataset, dict) else dict(test=dataset)
# checkpointer, loads pre-fit model if weight.startswith("http"):
weight_fullpath = os.path.abspath(weight) logger.info(f"Temporarily downloading '{weight}'...")
f = download_to_tempfile(weight, progress=True)
weight_fullpath = os.path.abspath(f.name)
else:
weight_fullpath = os.path.abspath(weight)
weight_path = os.path.dirname(weight_fullpath) weight_path = os.path.dirname(weight_fullpath)
weight_name = os.path.basename(weight_fullpath) weight_name = os.path.basename(weight_fullpath)
checkpointer = DetectronCheckpointer(model, save_dir=weight_path, checkpointer = DetectronCheckpointer(model, save_dir=weight_path,
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment