From 09e50ca46eed7d10d047c02e651d634c62dcc5c4 Mon Sep 17 00:00:00 2001 From: "ogueler@idiap.ch" <ogueler@vws110.idiap.ch> Date: Thu, 16 Mar 2023 20:53:45 +0100 Subject: [PATCH] removed weight loading functionality --- src/ptbench/scripts/train.py | 33 +------------- tests/test_cli.py | 87 +++++++++++++++++++++++++++++++++--- 2 files changed, 82 insertions(+), 38 deletions(-) diff --git a/src/ptbench/scripts/train.py b/src/ptbench/scripts/train.py index 206eaf90..bafeb030 100644 --- a/src/ptbench/scripts/train.py +++ b/src/ptbench/scripts/train.py @@ -288,13 +288,6 @@ def set_reproducible_cuda(): default=-1, cls=ResourceOption, ) -@click.option( - "--weight", - "-w", - help="Path or URL to pretrained model file (.pth extension)", - required=False, - cls=ResourceOption, -) @click.option( "--normalization", "-n", @@ -337,7 +330,6 @@ def train( device, seed, parallel, - weight, normalization, monitoring_interval, **_, @@ -363,7 +355,6 @@ def train( from ..configs.datasets import get_positive_weights, get_samples_weights from ..engine.trainer import run from ..utils.checkpointer import Checkpointer - from ..utils.download import download_to_tempfile device = setup_pytorch_device(device) @@ -528,29 +519,7 @@ def train( # Initialize epoch information arguments = {} arguments["epoch"] = 0 - - # Load pretrained weights if needed - if weight is not None: - if checkpointer.has_checkpoint(): - logger.warning( - "Weights are being ignored because a checkpoint already exists. " - "Weights from checkpoint will be loaded instead." - ) - extra_checkpoint_data = checkpointer.load() - else: - if weight.startswith("http"): - logger.info(f"Temporarily downloading '{weight}'...") - f = download_to_tempfile(weight, progress=True) - weight_fullpath = os.path.abspath(f.name) - else: - weight_fullpath = os.path.abspath(weight) - extra_checkpoint_data = checkpointer.load( - weight_fullpath, strict=False - ) - else: - extra_checkpoint_data = checkpointer.load() - - # Update epoch information with checkpoint data + extra_checkpoint_data = checkpointer.load() arguments.update(extra_checkpoint_data) arguments["max_epoch"] = epochs diff --git a/tests/test_cli.py b/tests/test_cli.py index 5bc2fa1d..31edf501 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -226,6 +226,83 @@ def test_train_pasa_montgomery(temporary_basedir): ) +@pytest.mark.skip_if_rc_var_not_set("datadir.montgomery") +def test_train_pasa_montgomery_from_checkpoint(temporary_basedir): + from ptbench.scripts.train import train + + runner = CliRunner() + + output_folder = str(temporary_basedir / "results/pasa_checkpoint") + result0 = runner.invoke( + train, + [ + "pasa", + "montgomery", + "-vv", + "--epochs=1", + "--batch-size=1", + "--normalization=current", + f"--output-folder={output_folder}", + ], + ) + _assert_exit_0(result0) + + assert os.path.exists(os.path.join(output_folder, "model_final_epoch.pth")) + assert os.path.exists( + os.path.join(output_folder, "model_lowest_valid_loss.pth") + ) + assert os.path.exists(os.path.join(output_folder, "last_checkpoint")) + assert os.path.exists(os.path.join(output_folder, "constants.csv")) + assert os.path.exists(os.path.join(output_folder, "trainlog.csv")) + assert os.path.exists(os.path.join(output_folder, "model_summary.txt")) + + with stdout_logging() as buf: + result = runner.invoke( + train, + [ + "pasa", + "montgomery", + "-vv", + "--epochs=2", + "--batch-size=1", + "--normalization=current", + f"--output-folder={output_folder}", + ], + ) + _assert_exit_0(result) + + assert os.path.exists( + os.path.join(output_folder, "model_final_epoch.pth") + ) + assert os.path.exists( + os.path.join(output_folder, "model_lowest_valid_loss.pth") + ) + assert os.path.exists(os.path.join(output_folder, "last_checkpoint")) + assert os.path.exists(os.path.join(output_folder, "constants.csv")) + assert os.path.exists(os.path.join(output_folder, "trainlog.csv")) + assert os.path.exists(os.path.join(output_folder, "model_summary.txt")) + + keywords = { + r"^Found \(dedicated\) '__train__' set for training$": 1, + r"^Found \(dedicated\) '__valid__' set for validation$": 1, + r"^Continuing from epoch 1$": 1, + r"^Saving model summary at.*$": 1, + r"^Model has.*$": 1, + r"^Saving checkpoint": 2, + r"^Total training time:": 1, + r"^Z-normalization with mean": 1, + } + buf.seek(0) + logging_output = buf.read() + + for k, v in keywords.items(): + assert _str_counter(k, logging_output) == v, ( + f"Count for string '{k}' appeared " + f"({_str_counter(k, logging_output)}) " + f"instead of the expected {v}:\nOutput:\n{logging_output}" + ) + + @pytest.mark.skip_if_rc_var_not_set("datadir.montgomery") def test_predict_pasa_montgomery(temporary_basedir, datadir): from ptbench.scripts.predict import predict @@ -423,9 +500,8 @@ def test_train_signstotb_montgomery_rs(temporary_basedir, datadir): "signs_to_tb", "montgomery_rs", "-vv", - "--epochs=15", + "--epochs=1", "--batch-size=1", - f"--weight={str(datadir / 'lfs' / 'models' / 'signstotb.pth')}", f"--output-folder={output_folder}", ], ) @@ -445,7 +521,7 @@ def test_train_signstotb_montgomery_rs(temporary_basedir, datadir): keywords = { r"^Found \(dedicated\) '__train__' set for training$": 1, r"^Found \(dedicated\) '__valid__' set for validation$": 1, - r"^Continuing from epoch 14$": 1, + r"^Continuing from epoch 0$": 1, r"^Saving model summary at.*$": 1, r"^Model has.*$": 1, r"^Saving checkpoint": 2, @@ -525,9 +601,8 @@ def test_train_logreg_montgomery_rs(temporary_basedir, datadir): "logistic_regression", "montgomery_rs", "-vv", - "--epochs=43", + "--epochs=1", "--batch-size=1", - f"--weight={str(datadir / 'lfs' / 'models' / 'logreg.pth')}", f"--output-folder={output_folder}", ], ) @@ -547,7 +622,7 @@ def test_train_logreg_montgomery_rs(temporary_basedir, datadir): keywords = { r"^Found \(dedicated\) '__train__' set for training$": 1, r"^Found \(dedicated\) '__valid__' set for validation$": 1, - r"^Continuing from epoch 42$": 1, + r"^Continuing from epoch 0$": 1, r"^Saving model summary at.*$": 1, r"^Model has.*$": 1, r"^Saving checkpoint": 2, -- GitLab