diff --git a/src/ptbench/scripts/train.py b/src/ptbench/scripts/train.py index af7b23dfd93277625b9dc84bc3858f93591f7237..206eaf90fbc6030e2eb0a30b232196d94315e214 100644 --- a/src/ptbench/scripts/train.py +++ b/src/ptbench/scripts/train.py @@ -531,13 +531,22 @@ def train( # Load pretrained weights if needed if weight is not None: - if weight.startswith("http"): - logger.info(f"Temporarily downloading '{weight}'...") - f = download_to_tempfile(weight, progress=True) - weight_fullpath = os.path.abspath(f.name) + if checkpointer.has_checkpoint(): + logger.warning( + "Weights are being ignored because a checkpoint already exists. " + "Weights from checkpoint will be loaded instead." + ) + extra_checkpoint_data = checkpointer.load() else: - weight_fullpath = os.path.abspath(weight) - extra_checkpoint_data = checkpointer.load(weight_fullpath, strict=False) + if weight.startswith("http"): + logger.info(f"Temporarily downloading '{weight}'...") + f = download_to_tempfile(weight, progress=True) + weight_fullpath = os.path.abspath(f.name) + else: + weight_fullpath = os.path.abspath(weight) + extra_checkpoint_data = checkpointer.load( + weight_fullpath, strict=False + ) else: extra_checkpoint_data = checkpointer.load() diff --git a/tests/test_cli.py b/tests/test_cli.py index 3e7d4b0622c8184d3a74842829c015c0ebecb688..5bc2fa1dc11ab3bc624e030265bbdd6c21b1fa37 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -416,7 +416,7 @@ def test_train_signstotb_montgomery_rs(temporary_basedir, datadir): runner = CliRunner() with stdout_logging() as buf: - output_folder = str(temporary_basedir / "results") + output_folder = str(temporary_basedir / "results/signstotb") result = runner.invoke( train, [ @@ -518,7 +518,7 @@ def test_train_logreg_montgomery_rs(temporary_basedir, datadir): runner = CliRunner() with stdout_logging() as buf: - output_folder = str(temporary_basedir / "results") + output_folder = str(temporary_basedir / "results/logreg") result = runner.invoke( train, [