From 2381830dc16f480358bdd0a825ce32546ab8deb2 Mon Sep 17 00:00:00 2001 From: "ogueler@idiap.ch" <ogueler@vws110.idiap.ch> Date: Thu, 16 Mar 2023 14:20:00 +0100 Subject: [PATCH] re-added weight functionality + updated tests --- src/ptbench/scripts/train.py | 16 +++++++++++++++- tests/test_cli.py | 8 ++++---- 2 files changed, 19 insertions(+), 5 deletions(-) diff --git a/src/ptbench/scripts/train.py b/src/ptbench/scripts/train.py index d8110c67..af7b23df 100644 --- a/src/ptbench/scripts/train.py +++ b/src/ptbench/scripts/train.py @@ -363,6 +363,7 @@ def train( from ..configs.datasets import get_positive_weights, get_samples_weights from ..engine.trainer import run from ..utils.checkpointer import Checkpointer + from ..utils.download import download_to_tempfile device = setup_pytorch_device(device) @@ -527,7 +528,20 @@ def train( # Initialize epoch information arguments = {} arguments["epoch"] = 0 - extra_checkpoint_data = checkpointer.load() + + # Load pretrained weights if needed + if weight is not None: + if weight.startswith("http"): + logger.info(f"Temporarily downloading '{weight}'...") + f = download_to_tempfile(weight, progress=True) + weight_fullpath = os.path.abspath(f.name) + else: + weight_fullpath = os.path.abspath(weight) + extra_checkpoint_data = checkpointer.load(weight_fullpath, strict=False) + else: + extra_checkpoint_data = checkpointer.load() + + # Update epoch information with checkpoint data arguments.update(extra_checkpoint_data) arguments["max_epoch"] = epochs diff --git a/tests/test_cli.py b/tests/test_cli.py index 9c42d58e..3e7d4b06 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -423,7 +423,7 @@ def test_train_signstotb_montgomery_rs(temporary_basedir, datadir): "signs_to_tb", "montgomery_rs", "-vv", - "--epochs=1", + "--epochs=15", "--batch-size=1", f"--weight={str(datadir / 'lfs' / 'models' / 'signstotb.pth')}", f"--output-folder={output_folder}", @@ -445,7 +445,7 @@ def test_train_signstotb_montgomery_rs(temporary_basedir, datadir): keywords = { r"^Found \(dedicated\) '__train__' set for training$": 1, r"^Found \(dedicated\) '__valid__' set for validation$": 1, - r"^Continuing from epoch 0$": 1, + r"^Continuing from epoch 14$": 1, r"^Saving model summary at.*$": 1, r"^Model has.*$": 1, r"^Saving checkpoint": 2, @@ -525,7 +525,7 @@ def test_train_logreg_montgomery_rs(temporary_basedir, datadir): "logistic_regression", "montgomery_rs", "-vv", - "--epochs=1", + "--epochs=43", "--batch-size=1", f"--weight={str(datadir / 'lfs' / 'models' / 'logreg.pth')}", f"--output-folder={output_folder}", @@ -547,7 +547,7 @@ def test_train_logreg_montgomery_rs(temporary_basedir, datadir): keywords = { r"^Found \(dedicated\) '__train__' set for training$": 1, r"^Found \(dedicated\) '__valid__' set for validation$": 1, - r"^Continuing from epoch 0$": 1, + r"^Continuing from epoch 42$": 1, r"^Saving model summary at.*$": 1, r"^Model has.*$": 1, r"^Saving checkpoint": 2, -- GitLab