diff --git a/src/ptbench/scripts/train.py b/src/ptbench/scripts/train.py index d8110c67adf96be33baa669c4091b37e319fdaca..af7b23dfd93277625b9dc84bc3858f93591f7237 100644 --- a/src/ptbench/scripts/train.py +++ b/src/ptbench/scripts/train.py @@ -363,6 +363,7 @@ def train( from ..configs.datasets import get_positive_weights, get_samples_weights from ..engine.trainer import run from ..utils.checkpointer import Checkpointer + from ..utils.download import download_to_tempfile device = setup_pytorch_device(device) @@ -527,7 +528,20 @@ def train( # Initialize epoch information arguments = {} arguments["epoch"] = 0 - extra_checkpoint_data = checkpointer.load() + + # Load pretrained weights if needed + if weight is not None: + if weight.startswith("http"): + logger.info(f"Temporarily downloading '{weight}'...") + f = download_to_tempfile(weight, progress=True) + weight_fullpath = os.path.abspath(f.name) + else: + weight_fullpath = os.path.abspath(weight) + extra_checkpoint_data = checkpointer.load(weight_fullpath, strict=False) + else: + extra_checkpoint_data = checkpointer.load() + + # Update epoch information with checkpoint data arguments.update(extra_checkpoint_data) arguments["max_epoch"] = epochs diff --git a/tests/test_cli.py b/tests/test_cli.py index 9c42d58e1d6faae7bffb0c5c448e00e38e95b770..3e7d4b0622c8184d3a74842829c015c0ebecb688 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -423,7 +423,7 @@ def test_train_signstotb_montgomery_rs(temporary_basedir, datadir): "signs_to_tb", "montgomery_rs", "-vv", - "--epochs=1", + "--epochs=15", "--batch-size=1", f"--weight={str(datadir / 'lfs' / 'models' / 'signstotb.pth')}", f"--output-folder={output_folder}", @@ -445,7 +445,7 @@ def test_train_signstotb_montgomery_rs(temporary_basedir, datadir): keywords = { r"^Found \(dedicated\) '__train__' set for training$": 1, r"^Found \(dedicated\) '__valid__' set for validation$": 1, - r"^Continuing from epoch 0$": 1, + r"^Continuing from epoch 14$": 1, r"^Saving model summary at.*$": 1, r"^Model has.*$": 1, r"^Saving checkpoint": 2, @@ -525,7 +525,7 @@ def test_train_logreg_montgomery_rs(temporary_basedir, datadir): "logistic_regression", "montgomery_rs", "-vv", - "--epochs=1", + "--epochs=43", "--batch-size=1", f"--weight={str(datadir / 'lfs' / 'models' / 'logreg.pth')}", f"--output-folder={output_folder}", @@ -547,7 +547,7 @@ def test_train_logreg_montgomery_rs(temporary_basedir, datadir): keywords = { r"^Found \(dedicated\) '__train__' set for training$": 1, r"^Found \(dedicated\) '__valid__' set for validation$": 1, - r"^Continuing from epoch 0$": 1, + r"^Continuing from epoch 42$": 1, r"^Saving model summary at.*$": 1, r"^Model has.*$": 1, r"^Saving checkpoint": 2,