From 010a014734acaa5a30d00b3986f13623e5338860 Mon Sep 17 00:00:00 2001 From: "ogueler@idiap.ch" <ogueler@vws110.idiap.ch> Date: Thu, 16 Mar 2023 16:46:40 +0100 Subject: [PATCH] gave precedence to checkpoints over weights --- src/ptbench/scripts/train.py | 21 +++++++++++++++------ tests/test_cli.py | 4 ++-- 2 files changed, 17 insertions(+), 8 deletions(-) diff --git a/src/ptbench/scripts/train.py b/src/ptbench/scripts/train.py index af7b23df..206eaf90 100644 --- a/src/ptbench/scripts/train.py +++ b/src/ptbench/scripts/train.py @@ -531,13 +531,22 @@ def train( # Load pretrained weights if needed if weight is not None: - if weight.startswith("http"): - logger.info(f"Temporarily downloading '{weight}'...") - f = download_to_tempfile(weight, progress=True) - weight_fullpath = os.path.abspath(f.name) + if checkpointer.has_checkpoint(): + logger.warning( + "Weights are being ignored because a checkpoint already exists. " + "Weights from checkpoint will be loaded instead." + ) + extra_checkpoint_data = checkpointer.load() else: - weight_fullpath = os.path.abspath(weight) - extra_checkpoint_data = checkpointer.load(weight_fullpath, strict=False) + if weight.startswith("http"): + logger.info(f"Temporarily downloading '{weight}'...") + f = download_to_tempfile(weight, progress=True) + weight_fullpath = os.path.abspath(f.name) + else: + weight_fullpath = os.path.abspath(weight) + extra_checkpoint_data = checkpointer.load( + weight_fullpath, strict=False + ) else: extra_checkpoint_data = checkpointer.load() diff --git a/tests/test_cli.py b/tests/test_cli.py index 3e7d4b06..5bc2fa1d 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -416,7 +416,7 @@ def test_train_signstotb_montgomery_rs(temporary_basedir, datadir): runner = CliRunner() with stdout_logging() as buf: - output_folder = str(temporary_basedir / "results") + output_folder = str(temporary_basedir / "results/signstotb") result = runner.invoke( train, [ @@ -518,7 +518,7 @@ def test_train_logreg_montgomery_rs(temporary_basedir, datadir): runner = CliRunner() with stdout_logging() as buf: - output_folder = str(temporary_basedir / "results") + output_folder = str(temporary_basedir / "results/logreg") result = runner.invoke( train, [ -- GitLab