From 09e50ca46eed7d10d047c02e651d634c62dcc5c4 Mon Sep 17 00:00:00 2001
From: "ogueler@idiap.ch" <ogueler@vws110.idiap.ch>
Date: Thu, 16 Mar 2023 20:53:45 +0100
Subject: [PATCH] removed weight loading functionality

---
 src/ptbench/scripts/train.py | 33 +-------------
 tests/test_cli.py            | 87 +++++++++++++++++++++++++++++++++---
 2 files changed, 82 insertions(+), 38 deletions(-)

diff --git a/src/ptbench/scripts/train.py b/src/ptbench/scripts/train.py
index 206eaf90..bafeb030 100644
--- a/src/ptbench/scripts/train.py
+++ b/src/ptbench/scripts/train.py
@@ -288,13 +288,6 @@ def set_reproducible_cuda():
     default=-1,
     cls=ResourceOption,
 )
-@click.option(
-    "--weight",
-    "-w",
-    help="Path or URL to pretrained model file (.pth extension)",
-    required=False,
-    cls=ResourceOption,
-)
 @click.option(
     "--normalization",
     "-n",
@@ -337,7 +330,6 @@ def train(
     device,
     seed,
     parallel,
-    weight,
     normalization,
     monitoring_interval,
     **_,
@@ -363,7 +355,6 @@ def train(
     from ..configs.datasets import get_positive_weights, get_samples_weights
     from ..engine.trainer import run
     from ..utils.checkpointer import Checkpointer
-    from ..utils.download import download_to_tempfile
 
     device = setup_pytorch_device(device)
 
@@ -528,29 +519,7 @@ def train(
     # Initialize epoch information
     arguments = {}
     arguments["epoch"] = 0
-
-    # Load pretrained weights if needed
-    if weight is not None:
-        if checkpointer.has_checkpoint():
-            logger.warning(
-                "Weights are being ignored because a checkpoint already exists. "
-                "Weights from checkpoint will be loaded instead."
-            )
-            extra_checkpoint_data = checkpointer.load()
-        else:
-            if weight.startswith("http"):
-                logger.info(f"Temporarily downloading '{weight}'...")
-                f = download_to_tempfile(weight, progress=True)
-                weight_fullpath = os.path.abspath(f.name)
-            else:
-                weight_fullpath = os.path.abspath(weight)
-            extra_checkpoint_data = checkpointer.load(
-                weight_fullpath, strict=False
-            )
-    else:
-        extra_checkpoint_data = checkpointer.load()
-
-    # Update epoch information with checkpoint data
+    extra_checkpoint_data = checkpointer.load()
     arguments.update(extra_checkpoint_data)
     arguments["max_epoch"] = epochs
 
diff --git a/tests/test_cli.py b/tests/test_cli.py
index 5bc2fa1d..31edf501 100644
--- a/tests/test_cli.py
+++ b/tests/test_cli.py
@@ -226,6 +226,83 @@ def test_train_pasa_montgomery(temporary_basedir):
             )
 
 
+@pytest.mark.skip_if_rc_var_not_set("datadir.montgomery")
+def test_train_pasa_montgomery_from_checkpoint(temporary_basedir):
+    from ptbench.scripts.train import train
+    # Train once, then re-run in the same folder and expect a resume.
+    runner = CliRunner()
+    # First run: a single epoch written to ``output_folder``.
+    output_folder = str(temporary_basedir / "results/pasa_checkpoint")
+    result0 = runner.invoke(
+        train,
+        [
+            "pasa",
+            "montgomery",
+            "-vv",
+            "--epochs=1",
+            "--batch-size=1",
+            "--normalization=current",
+            f"--output-folder={output_folder}",
+        ],
+    )
+    _assert_exit_0(result0)
+    # The first run must have produced its model files and logs.
+    assert os.path.exists(os.path.join(output_folder, "model_final_epoch.pth"))
+    assert os.path.exists(
+        os.path.join(output_folder, "model_lowest_valid_loss.pth")
+    )
+    assert os.path.exists(os.path.join(output_folder, "last_checkpoint"))
+    assert os.path.exists(os.path.join(output_folder, "constants.csv"))
+    assert os.path.exists(os.path.join(output_folder, "trainlog.csv"))
+    assert os.path.exists(os.path.join(output_folder, "model_summary.txt"))
+    # Second run: same folder, --epochs=2; log output is captured in buf.
+    with stdout_logging() as buf:
+        result = runner.invoke(
+            train,
+            [
+                "pasa",
+                "montgomery",
+                "-vv",
+                "--epochs=2",
+                "--batch-size=1",
+                "--normalization=current",
+                f"--output-folder={output_folder}",
+            ],
+        )
+        _assert_exit_0(result)
+        # The same set of files must exist after the second run.
+        assert os.path.exists(
+            os.path.join(output_folder, "model_final_epoch.pth")
+        )
+        assert os.path.exists(
+            os.path.join(output_folder, "model_lowest_valid_loss.pth")
+        )
+        assert os.path.exists(os.path.join(output_folder, "last_checkpoint"))
+        assert os.path.exists(os.path.join(output_folder, "constants.csv"))
+        assert os.path.exists(os.path.join(output_folder, "trainlog.csv"))
+        assert os.path.exists(os.path.join(output_folder, "model_summary.txt"))
+        # Pattern -> expected count (as returned by _str_counter).
+        keywords = {
+            r"^Found \(dedicated\) '__train__' set for training$": 1,
+            r"^Found \(dedicated\) '__valid__' set for validation$": 1,
+            r"^Continuing from epoch 1$": 1,  # i.e. not from epoch 0
+            r"^Saving model summary at.*$": 1,
+            r"^Model has.*$": 1,
+            r"^Saving checkpoint": 2,
+            r"^Total training time:": 1,
+            r"^Z-normalization with mean": 1,
+        }
+        buf.seek(0)
+        logging_output = buf.read()
+        # Each pattern's count must equal its expected value.
+        for k, v in keywords.items():
+            assert _str_counter(k, logging_output) == v, (
+                f"Count for string '{k}' appeared "
+                f"({_str_counter(k, logging_output)}) "
+                f"instead of the expected {v}:\nOutput:\n{logging_output}"
+            )
+
+
 @pytest.mark.skip_if_rc_var_not_set("datadir.montgomery")
 def test_predict_pasa_montgomery(temporary_basedir, datadir):
     from ptbench.scripts.predict import predict
@@ -423,9 +500,8 @@ def test_train_signstotb_montgomery_rs(temporary_basedir, datadir):
                 "signs_to_tb",
                 "montgomery_rs",
                 "-vv",
-                "--epochs=15",
+                "--epochs=1",
                 "--batch-size=1",
-                f"--weight={str(datadir / 'lfs' / 'models' / 'signstotb.pth')}",
                 f"--output-folder={output_folder}",
             ],
         )
@@ -445,7 +521,7 @@ def test_train_signstotb_montgomery_rs(temporary_basedir, datadir):
         keywords = {
             r"^Found \(dedicated\) '__train__' set for training$": 1,
             r"^Found \(dedicated\) '__valid__' set for validation$": 1,
-            r"^Continuing from epoch 14$": 1,
+            r"^Continuing from epoch 0$": 1,
             r"^Saving model summary at.*$": 1,
             r"^Model has.*$": 1,
             r"^Saving checkpoint": 2,
@@ -525,9 +601,8 @@ def test_train_logreg_montgomery_rs(temporary_basedir, datadir):
                 "logistic_regression",
                 "montgomery_rs",
                 "-vv",
-                "--epochs=43",
+                "--epochs=1",
                 "--batch-size=1",
-                f"--weight={str(datadir / 'lfs' / 'models' / 'logreg.pth')}",
                 f"--output-folder={output_folder}",
             ],
         )
@@ -547,7 +622,7 @@ def test_train_logreg_montgomery_rs(temporary_basedir, datadir):
         keywords = {
             r"^Found \(dedicated\) '__train__' set for training$": 1,
             r"^Found \(dedicated\) '__valid__' set for validation$": 1,
-            r"^Continuing from epoch 42$": 1,
+            r"^Continuing from epoch 0$": 1,
             r"^Saving model summary at.*$": 1,
             r"^Model has.*$": 1,
             r"^Saving checkpoint": 2,
-- 
GitLab