diff --git a/src/ptbench/scripts/predict.py b/src/ptbench/scripts/predict.py index 41b10a291e008250747d34a76169de9be57d32fc..21a3ae6636b5392e844a2da27bcec86119dcb45b 100644 --- a/src/ptbench/scripts/predict.py +++ b/src/ptbench/scripts/predict.py @@ -63,7 +63,6 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") ) @click.option( "--device", - "-d", help='A string indicating the device to use (e.g. "cpu" or "cuda:0")', show_default=True, required=True, diff --git a/src/ptbench/scripts/train.py b/src/ptbench/scripts/train.py index 07577a3474e32e9556da44f024e61c5691c649f0..bafeb0303899086cd1b00cfe52d4d0896f7045b8 100644 --- a/src/ptbench/scripts/train.py +++ b/src/ptbench/scripts/train.py @@ -259,7 +259,6 @@ def set_reproducible_cuda(): ) @click.option( "--device", - "-d", help='A string indicating the device to use (e.g. "cpu" or "cuda:0")', show_default=True, required=True, @@ -289,13 +288,6 @@ def set_reproducible_cuda(): default=-1, cls=ResourceOption, ) -@click.option( - "--weight", - "-w", - help="Path or URL to pretrained model file (.pth extension)", - required=False, - cls=ResourceOption, -) @click.option( "--normalization", "-n", @@ -338,7 +330,6 @@ def train( device, seed, parallel, - weight, normalization, monitoring_interval, **_, @@ -364,7 +355,6 @@ def train( from ..configs.datasets import get_positive_weights, get_samples_weights from ..engine.trainer import run from ..utils.checkpointer import Checkpointer - from ..utils.download import download_to_tempfile device = setup_pytorch_device(device) @@ -526,18 +516,11 @@ def train( # Checkpointer checkpointer = Checkpointer(model, optimizer, path=output_folder) - # Load pretrained weights if needed - if weight is not None: - if weight.startswith("http"): - logger.info(f"Temporarily downloading '{weight}'...") - f = download_to_tempfile(weight, progress=True) - weight_fullpath = os.path.abspath(f.name) - else: - weight_fullpath = os.path.abspath(weight) - checkpointer.load(weight_fullpath, strict=False) - + # Initialize epoch information arguments = {} arguments["epoch"] = 0 + extra_checkpoint_data = checkpointer.load() + arguments.update(extra_checkpoint_data) arguments["max_epoch"] = epochs logger.info("Training for {} epochs".format(arguments["max_epoch"])) diff --git a/tests/test_cli.py b/tests/test_cli.py index 9c42d58e1d6faae7bffb0c5c448e00e38e95b770..31edf50194435d9ba064557c16892ebff467cfd4 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -226,6 +226,83 @@ def test_train_pasa_montgomery(temporary_basedir): ) +@pytest.mark.skip_if_rc_var_not_set("datadir.montgomery") +def test_train_pasa_montgomery_from_checkpoint(temporary_basedir): + from ptbench.scripts.train import train + + runner = CliRunner() + + output_folder = str(temporary_basedir / "results/pasa_checkpoint") + result0 = runner.invoke( + train, + [ + "pasa", + "montgomery", + "-vv", + "--epochs=1", + "--batch-size=1", + "--normalization=current", + f"--output-folder={output_folder}", + ], + ) + _assert_exit_0(result0) + + assert os.path.exists(os.path.join(output_folder, "model_final_epoch.pth")) + assert os.path.exists( + os.path.join(output_folder, "model_lowest_valid_loss.pth") + ) + assert os.path.exists(os.path.join(output_folder, "last_checkpoint")) + assert os.path.exists(os.path.join(output_folder, "constants.csv")) + assert os.path.exists(os.path.join(output_folder, "trainlog.csv")) + assert os.path.exists(os.path.join(output_folder, "model_summary.txt")) + + with stdout_logging() as buf: + result = runner.invoke( + train, + [ + "pasa", + "montgomery", + "-vv", + "--epochs=2", + "--batch-size=1", + "--normalization=current", + f"--output-folder={output_folder}", + ], + ) + _assert_exit_0(result) + + assert os.path.exists( + os.path.join(output_folder, "model_final_epoch.pth") + ) + assert os.path.exists( + os.path.join(output_folder, "model_lowest_valid_loss.pth") + ) + assert os.path.exists(os.path.join(output_folder, "last_checkpoint")) + assert os.path.exists(os.path.join(output_folder, "constants.csv")) + assert os.path.exists(os.path.join(output_folder, "trainlog.csv")) + assert os.path.exists(os.path.join(output_folder, "model_summary.txt")) + + keywords = { + r"^Found \(dedicated\) '__train__' set for training$": 1, + r"^Found \(dedicated\) '__valid__' set for validation$": 1, + r"^Continuing from epoch 1$": 1, + r"^Saving model summary at.*$": 1, + r"^Model has.*$": 1, + r"^Saving checkpoint": 2, + r"^Total training time:": 1, + r"^Z-normalization with mean": 1, + } + buf.seek(0) + logging_output = buf.read() + + for k, v in keywords.items(): + assert _str_counter(k, logging_output) == v, ( + f"Count for string '{k}' appeared " + f"({_str_counter(k, logging_output)}) " + f"instead of the expected {v}:\nOutput:\n{logging_output}" + ) + + @pytest.mark.skip_if_rc_var_not_set("datadir.montgomery") def test_predict_pasa_montgomery(temporary_basedir, datadir): from ptbench.scripts.predict import predict @@ -416,7 +493,7 @@ def test_train_signstotb_montgomery_rs(temporary_basedir, datadir): runner = CliRunner() with stdout_logging() as buf: - output_folder = str(temporary_basedir / "results") + output_folder = str(temporary_basedir / "results/signstotb") result = runner.invoke( train, [ @@ -425,7 +502,6 @@ def test_train_signstotb_montgomery_rs(temporary_basedir, datadir): "-vv", "--epochs=1", "--batch-size=1", - f"--weight={str(datadir / 'lfs' / 'models' / 'signstotb.pth')}", f"--output-folder={output_folder}", ], ) @@ -518,7 +594,7 @@ def test_train_logreg_montgomery_rs(temporary_basedir, datadir): runner = CliRunner() with stdout_logging() as buf: - output_folder = str(temporary_basedir / "results") + output_folder = str(temporary_basedir / "results/logreg") result = runner.invoke( train, [ @@ -527,7 +603,6 @@ def test_train_logreg_montgomery_rs(temporary_basedir, datadir): "-vv", "--epochs=1", "--batch-size=1", - f"--weight={str(datadir / 'lfs' / 'models' / 'logreg.pth')}", f"--output-folder={output_folder}", ], )