From 010a014734acaa5a30d00b3986f13623e5338860 Mon Sep 17 00:00:00 2001
From: "ogueler@idiap.ch" <ogueler@vws110.idiap.ch>
Date: Thu, 16 Mar 2023 16:46:40 +0100
Subject: [PATCH] gave precedence to checkpoints over weights

---
 src/ptbench/scripts/train.py | 21 +++++++++++++++------
 tests/test_cli.py            |  4 ++--
 2 files changed, 17 insertions(+), 8 deletions(-)

diff --git a/src/ptbench/scripts/train.py b/src/ptbench/scripts/train.py
index af7b23df..206eaf90 100644
--- a/src/ptbench/scripts/train.py
+++ b/src/ptbench/scripts/train.py
@@ -531,13 +531,22 @@ def train(
 
     # Load pretrained weights if needed
     if weight is not None:
-        if weight.startswith("http"):
-            logger.info(f"Temporarily downloading '{weight}'...")
-            f = download_to_tempfile(weight, progress=True)
-            weight_fullpath = os.path.abspath(f.name)
+        if checkpointer.has_checkpoint():
+            logger.warning(
+                "Weights are being ignored because a checkpoint already exists. "
+                "Weights from checkpoint will be loaded instead."
+            )
+            extra_checkpoint_data = checkpointer.load()
         else:
-            weight_fullpath = os.path.abspath(weight)
-        extra_checkpoint_data = checkpointer.load(weight_fullpath, strict=False)
+            if weight.startswith("http"):
+                logger.info(f"Temporarily downloading '{weight}'...")
+                f = download_to_tempfile(weight, progress=True)
+                weight_fullpath = os.path.abspath(f.name)
+            else:
+                weight_fullpath = os.path.abspath(weight)
+            extra_checkpoint_data = checkpointer.load(
+                weight_fullpath, strict=False
+            )
     else:
         extra_checkpoint_data = checkpointer.load()
 
diff --git a/tests/test_cli.py b/tests/test_cli.py
index 3e7d4b06..5bc2fa1d 100644
--- a/tests/test_cli.py
+++ b/tests/test_cli.py
@@ -416,7 +416,7 @@ def test_train_signstotb_montgomery_rs(temporary_basedir, datadir):
     runner = CliRunner()
 
     with stdout_logging() as buf:
-        output_folder = str(temporary_basedir / "results")
+        output_folder = str(temporary_basedir / "results/signstotb")
         result = runner.invoke(
             train,
             [
@@ -518,7 +518,7 @@ def test_train_logreg_montgomery_rs(temporary_basedir, datadir):
     runner = CliRunner()
 
     with stdout_logging() as buf:
-        output_folder = str(temporary_basedir / "results")
+        output_folder = str(temporary_basedir / "results/logreg")
         result = runner.invoke(
             train,
             [
-- 
GitLab