From 2381830dc16f480358bdd0a825ce32546ab8deb2 Mon Sep 17 00:00:00 2001
From: "ogueler@idiap.ch" <ogueler@vws110.idiap.ch>
Date: Thu, 16 Mar 2023 14:20:00 +0100
Subject: [PATCH] re-added weight functionality + updated tests

---
 src/ptbench/scripts/train.py | 16 +++++++++++++++-
 tests/test_cli.py            |  8 ++++----
 2 files changed, 19 insertions(+), 5 deletions(-)

diff --git a/src/ptbench/scripts/train.py b/src/ptbench/scripts/train.py
index d8110c67..af7b23df 100644
--- a/src/ptbench/scripts/train.py
+++ b/src/ptbench/scripts/train.py
@@ -363,6 +363,7 @@ def train(
     from ..configs.datasets import get_positive_weights, get_samples_weights
     from ..engine.trainer import run
     from ..utils.checkpointer import Checkpointer
+    from ..utils.download import download_to_tempfile
 
     device = setup_pytorch_device(device)
 
@@ -527,7 +528,20 @@ def train(
     # Initialize epoch information
     arguments = {}
     arguments["epoch"] = 0
-    extra_checkpoint_data = checkpointer.load()
+
+    # Load pretrained weights if needed
+    if weight is not None:
+        if weight.startswith("http"):
+            logger.info(f"Temporarily downloading '{weight}'...")
+            f = download_to_tempfile(weight, progress=True)
+            weight_fullpath = os.path.abspath(f.name)
+        else:
+            weight_fullpath = os.path.abspath(weight)
+        extra_checkpoint_data = checkpointer.load(weight_fullpath, strict=False)
+    else:
+        extra_checkpoint_data = checkpointer.load()
+
+    # Update epoch information with checkpoint data
     arguments.update(extra_checkpoint_data)
     arguments["max_epoch"] = epochs
 
diff --git a/tests/test_cli.py b/tests/test_cli.py
index 9c42d58e..3e7d4b06 100644
--- a/tests/test_cli.py
+++ b/tests/test_cli.py
@@ -423,7 +423,7 @@ def test_train_signstotb_montgomery_rs(temporary_basedir, datadir):
                 "signs_to_tb",
                 "montgomery_rs",
                 "-vv",
-                "--epochs=1",
+                "--epochs=15",
                 "--batch-size=1",
                 f"--weight={str(datadir / 'lfs' / 'models' / 'signstotb.pth')}",
                 f"--output-folder={output_folder}",
@@ -445,7 +445,7 @@ def test_train_signstotb_montgomery_rs(temporary_basedir, datadir):
         keywords = {
             r"^Found \(dedicated\) '__train__' set for training$": 1,
             r"^Found \(dedicated\) '__valid__' set for validation$": 1,
-            r"^Continuing from epoch 0$": 1,
+            r"^Continuing from epoch 14$": 1,
             r"^Saving model summary at.*$": 1,
             r"^Model has.*$": 1,
             r"^Saving checkpoint": 2,
@@ -525,7 +525,7 @@ def test_train_logreg_montgomery_rs(temporary_basedir, datadir):
                 "logistic_regression",
                 "montgomery_rs",
                 "-vv",
-                "--epochs=1",
+                "--epochs=43",
                 "--batch-size=1",
                 f"--weight={str(datadir / 'lfs' / 'models' / 'logreg.pth')}",
                 f"--output-folder={output_folder}",
@@ -547,7 +547,7 @@ def test_train_logreg_montgomery_rs(temporary_basedir, datadir):
         keywords = {
             r"^Found \(dedicated\) '__train__' set for training$": 1,
             r"^Found \(dedicated\) '__valid__' set for validation$": 1,
-            r"^Continuing from epoch 0$": 1,
+            r"^Continuing from epoch 42$": 1,
             r"^Saving model summary at.*$": 1,
             r"^Model has.*$": 1,
             r"^Saving checkpoint": 2,
-- 
GitLab