From 0152c9c37ec9720d3d43f2229b67652dd30cd202 Mon Sep 17 00:00:00 2001 From: Andre Anjos <andre.dos.anjos@gmail.com> Date: Thu, 14 Dec 2023 14:52:31 +0100 Subject: [PATCH] [utils.checkpointer] Refactor checkpoint saving and loading --- src/ptbench/data/datamodule.py | 5 + src/ptbench/engine/trainer.py | 52 +++++---- src/ptbench/scripts/experiment.py | 15 ++- src/ptbench/scripts/train.py | 81 ++++++++------ src/ptbench/utils/checkpointer.py | 172 +++++++++++++++++++++--------- tests/test_cli.py | 61 +++++++---- 6 files changed, 253 insertions(+), 133 deletions(-) diff --git a/src/ptbench/data/datamodule.py b/src/ptbench/data/datamodule.py index 3d7e1a64..b954a000 100644 --- a/src/ptbench/data/datamodule.py +++ b/src/ptbench/data/datamodule.py @@ -580,8 +580,10 @@ class ConcatDataModule(lightning.LightningDataModule): if value < 0: num_workers = 0 + else: num_workers = value or multiprocessing.cpu_count() + self._dataloader_multiproc["num_workers"] = num_workers if num_workers > 0 and sys.platform == "darwin": @@ -589,6 +591,9 @@ class ConcatDataModule(lightning.LightningDataModule): "multiprocessing_context" ] = multiprocessing.get_context("spawn") + # keep workers hanging around if we have multiple + self._dataloader_multiproc["persistent_workers"] = True + @property def model_transforms(self) -> list[Transform] | None: """Transforms required to fit data into the model. diff --git a/src/ptbench/engine/trainer.py b/src/ptbench/engine/trainer.py index fccf47b8..b4f878fe 100644 --- a/src/ptbench/engine/trainer.py +++ b/src/ptbench/engine/trainer.py @@ -13,6 +13,7 @@ import lightning.pytorch.callbacks import lightning.pytorch.loggers import torch.nn +from ..utils.checkpointer import CHECKPOINT_ALIASES from ..utils.resources import ResourceMonitor, cpu_constants, gpu_constants from .callbacks import LoggingCallback from .device import DeviceManager @@ -47,13 +48,13 @@ def save_model_summary( summary_path = output_folder / "model-summary.txt" logger.info(f"Saving model summary at {summary_path}...") with summary_path.open("w") as f: - summary = lightning.pytorch.utilities.model_summary.ModelSummary( + summary = lightning.pytorch.utilities.model_summary.ModelSummary( # type: ignore model, max_depth=-1 ) f.write(str(summary)) return ( summary, - lightning.pytorch.utilities.model_summary.ModelSummary( + lightning.pytorch.utilities.model_summary.ModelSummary( # type: ignore model ).total_parameters, ) @@ -99,13 +100,13 @@ def static_information_to_csv( def run( model: lightning.pytorch.LightningModule, datamodule: lightning.pytorch.LightningDataModule, - checkpoint_period: int, + validation_period: int, device_manager: DeviceManager, max_epochs: int, output_folder: pathlib.Path, monitoring_interval: int | float, batch_chunk_count: int, - checkpoint: str | None, + checkpoint: pathlib.Path | None, ): """Fits a CNN model using supervised learning and save it to disk. @@ -122,9 +123,15 @@ def run( datamodule The lightning datamodule to use for training **and** validation - checkpoint_period - Save a checkpoint every ``n`` epochs. If set to ``0`` (zero), then do - not save intermediary checkpoints. + validation_period + Number of epochs after which validation happens. By default, we run + validation after every training epoch (period=1). You can change this + to make validation more sparse, by increasing the validation period. + Notice that this affects checkpoint saving. 
While checkpoints are + created after every training step (the last training step always + triggers the overriding of latest checkpoint), and that this process is + independent of validation runs, evaluation of the 'best' model obtained + so far based on those will be influenced by this setting. device_manager An internal device representation, to be used for training and @@ -177,17 +184,22 @@ def run( logging_level=logging.ERROR, ) - checkpoint_callback = lightning.pytorch.callbacks.ModelCheckpoint( - output_folder, - "model_lowest_valid_loss", - save_last=True, + # This checkpointer will operate at the end of every validation epoch + # (which happens at each checkpoint period), it will then save the lowest + # validation loss model observed. It will also save the last trained model + checkpoint_minvalloss_callback = lightning.pytorch.callbacks.ModelCheckpoint( + dirpath=output_folder, + filename=CHECKPOINT_ALIASES["best"], + save_last=True, # will (re)create the last trained model, at every iteration monitor="loss/validation", mode="min", - save_on_train_epoch_end=True, - every_n_epochs=checkpoint_period, + save_on_train_epoch_end=True, # run checks at the end of validation + every_n_epochs=validation_period, # frequency at which it would check the "monitor" + enable_version_counter=False, # no versioning of aliased checkpoints ) - - checkpoint_callback.CHECKPOINT_NAME_LAST = "model_final_epoch" + checkpoint_minvalloss_callback.CHECKPOINT_NAME_LAST = CHECKPOINT_ALIASES[ # type: ignore + "periodic" + ] # write static information to a CSV file static_information_to_csv( @@ -204,9 +216,13 @@ def run( max_epochs=max_epochs, accumulate_grad_batches=batch_chunk_count, logger=tensorboard_logger, - check_val_every_n_epoch=1, + check_val_every_n_epoch=validation_period, log_every_n_steps=len(datamodule.train_dataloader()), - callbacks=[LoggingCallback(resource_monitor), checkpoint_callback], + callbacks=[ + LoggingCallback(resource_monitor), + checkpoint_minvalloss_callback, + ], ) - _ = trainer.fit(model, datamodule, ckpt_path=checkpoint) + checkpoint_str = checkpoint if checkpoint is None else str(checkpoint) + _ = trainer.fit(model, datamodule, ckpt_path=checkpoint_str) diff --git a/src/ptbench/scripts/experiment.py b/src/ptbench/scripts/experiment.py index 55c24260..1d520fe3 100644 --- a/src/ptbench/scripts/experiment.py +++ b/src/ptbench/scripts/experiment.py @@ -42,7 +42,7 @@ def experiment( batch_chunk_count, drop_incomplete_batch, datamodule, - checkpoint_period, + validation_period, device, cache_samples, seed, @@ -84,7 +84,7 @@ def experiment( batch_chunk_count=batch_chunk_count, drop_incomplete_batch=drop_incomplete_batch, datamodule=datamodule, - checkpoint_period=checkpoint_period, + validation_period=validation_period, device=device, cache_samples=cache_samples, seed=seed, @@ -111,13 +111,12 @@ def experiment( logger.info("Started predicting") - from .predict import predict + from ..utils.checkpointer import get_checkpoint_to_run_inference + + model_file = get_checkpoint_to_run_inference(train_output_folder) + logger.info(f"Found `{str(model_file)}`. 
Continuing...") - # preferably, we use the best model on the validation set - # otherwise, we get the last saved model - model_file = train_output_folder / "model_lowest_valid_loss.ckpt" - if not model_file.exists(): - model_file = train_output_folder / "model_final_epoch.ckpt" + from .predict import predict predictions_output = output_folder / "predictions.json" diff --git a/src/ptbench/scripts/train.py b/src/ptbench/scripts/train.py index c5c2a3ac..e8149a41 100644 --- a/src/ptbench/scripts/train.py +++ b/src/ptbench/scripts/train.py @@ -125,16 +125,21 @@ def reusable_options(f): cls=ResourceOption, ) @click.option( - "--checkpoint-period", + "--validation-period", "-p", - help="""Number of epochs after which a checkpoint is saved. A value of - zero will disable check-pointing. If checkpointing is enabled and - training stops, it is automatically resumed from the last saved - checkpoint if training is restarted with the same configuration.""", + help="""Number of epochs after which validation happens. By default, + we run validation after every training epoch (period=1). You can + change this to make validation more sparse, by increasing the + validation period. Notice that this affects checkpoint saving. While + checkpoints are created after every training step (the last training + step always triggers the overriding of latest checkpoint), and that + this process is independent of validation runs, evaluation of the + 'best' model obtained so far based on those will be influenced by this + setting.""", show_default=True, - required=False, - default=None, - type=click.IntRange(min=0), + required=True, + default=1, + type=click.IntRange(min=1), cls=ResourceOption, ) @click.option( @@ -183,27 +188,19 @@ def reusable_options(f): "--monitoring-interval", "-I", help="""Time between checks for the use of resources during each training - epoch. An interval of 5 seconds, for example, will lead to CPU and GPU - resources being probed every 5 seconds during each training epoch. - Values registered in the training logs correspond to averages (or maxima) - observed through possibly many probes in each epoch. Notice that setting a - very small value may cause the probing process to become extremely busy, - potentially biasing the overall perception of resource usage.""", + epoch, in seconds. An interval of 5 seconds, for example, will lead to + CPU and GPU resources being probed every 5 seconds during each training + epoch. Values registered in the training logs correspond to averages + (or maxima) observed through possibly many probes in each epoch. + Notice that setting a very small value may cause the probing process to + become extremely busy, potentially biasing the overall perception of + resource usage.""", type=click.FloatRange(min=0.1), show_default=True, required=True, default=5.0, cls=ResourceOption, ) - @click.option( - "--resume-from", - help="""Which checkpoint to resume training from. If set, can be one of - `best`, `last`, or a path to a model checkpoint.""", - type=click.STRING, - required=False, - default=None, - cls=ResourceOption, - ) @click.option( "--balance-classes/--no-balance-classes", "-B/-N", @@ -244,13 +241,12 @@ def train( batch_chunk_count, drop_incomplete_batch, datamodule, - checkpoint_period, + validation_period, device, cache_samples, seed, parallel, monitoring_interval, - resume_from, balance_classes, **_, ) -> None: @@ -263,20 +259,31 @@ def train( resume the procedure in case it stops abruptly. 
""" + import os + import torch from lightning.pytorch import seed_everything from ..engine.device import DeviceManager from ..engine.trainer import run - from ..utils.checkpointer import get_checkpoint + from ..utils.checkpointer import get_checkpoint_to_resume_training from .utils import save_sh_command + checkpoint_file = None + if os.path.isdir(output_folder): + try: + checkpoint_file = get_checkpoint_to_resume_training(output_folder) + except FileNotFoundError: + logger.info( + f"Folder {output_folder} already exists, but I did not" + f" find any usable checkpoint file to resume training" + f" from. Starting from scratch..." + ) + save_sh_command(output_folder / "command.sh") seed_everything(seed) - checkpoint_file = get_checkpoint(output_folder, resume_from) - # reset datamodule with user configurable options datamodule.set_chunk_size(batch_size, batch_chunk_count) datamodule.drop_incomplete_batch = drop_incomplete_batch @@ -307,25 +314,31 @@ def train( arguments["epoch"] = 0 if checkpoint_file is None or not hasattr(model, "on_load_checkpoint"): - # Sets the model normalizer with the unaugmented-train-subset. - # this call may be a NOOP, if the model was pre-trained and expects - # different weights for the normalisation layer. + # Sets the model normalizer with the unaugmented-train-subset if we are + # starting from scratch and/or the model does not contain its own + # checkpoint loading strategy (e.g. a pytorch stock checkpoint). This + # call may be a NOOP, if the model comes from outside this framework, + # and expects different weights for the normalisation layer. if hasattr(model, "set_normalizer"): model.set_normalizer(datamodule.unshuffled_train_dataloader()) else: logger.warning( - f"Model {model.name} has no 'set_normalizer' method. Skipping." + f"Model {model.name} has no `set_normalizer` method. " + "Skipping normalization setup (unsupported external model)." ) else: # Normalizer will be loaded during model.on_load_checkpoint checkpoint = torch.load(checkpoint_file) start_epoch = checkpoint["epoch"] - logger.info(f"Resuming from epoch {start_epoch}...") + logger.info( + f"Resuming from epoch {start_epoch} " + f"(checkpoint file: `{str(checkpoint_file)}`)..." + ) run( model=model, datamodule=datamodule, - checkpoint_period=checkpoint_period, + validation_period=validation_period, device_manager=DeviceManager(device), max_epochs=epochs, output_folder=output_folder, diff --git a/src/ptbench/utils/checkpointer.py b/src/ptbench/utils/checkpointer.py index 8f6685b8..988e611e 100644 --- a/src/ptbench/utils/checkpointer.py +++ b/src/ptbench/utils/checkpointer.py @@ -4,76 +4,146 @@ import logging import pathlib +import re import typing logger = logging.getLogger(__name__) -def get_checkpoint( - output_folder: pathlib.Path, - resume_from: typing.Literal["last", "best"] | str | None, -) -> str | None: - """Gets a checkpoint file. +CHECKPOINT_ALIASES = { + "best": "model-at-lowest-validation-loss-{epoch}", + "periodic": "model-at-{epoch}", +} +"""Standard paths where checkpoints may be (if produced with this +framework).""" - Can return the best or last checkpoint, or a checkpoint at a specific path. - Ensures the checkpoint exists, raising an error if it is not the case. +CHECKPOINT_EXTENSION = ".ckpt" - If ``resume_from`` is ``None``, checks the output directory if a "last" - checkpoint file already exists and returns it. If no checkpoint is found, - returns ``None``. - ``resume_from`` can also be a path to an existing checkpoint file. 
In this
-    case, we check it and return if it exists.
+def _get_checkpoint_from_alias(
+    path: pathlib.Path,
+    alias: typing.Literal["best", "periodic"],
+) -> pathlib.Path:
+    """Gets an existing checkpoint file path.
+
+    This function searches for file names matching the checkpoint alias "stem"
+    (i.e. the prefix), and then assumes a dash "-" and a number follow that
+    prefix before the expected file extension.  The number is parsed and
+    considered to be an epoch number.  The latest file (the file containing
+    the highest epoch number) is returned.
+
+    If only one file is present matching the alias characteristics, then it is
+    returned.
 
     Parameters
     ----------
-    output_folder
-        Folder in which checkpoints are stored.
-    resume_from
-        Which model to get. Can be one of "best", "last", or a path to a checkpoint.
-        If ``None``, gets the last checkpoint if it exists, otherwise returns
-        ``None`` (signal to start from scratch).
+    path
+        Folder which may contain checkpoint files.
+    alias
+        Can be one of "best" or "periodic".
 
     Returns
     -------
-        Path to the requested checkpoint (as a plain string) or ``None`` (start
-        from scratch).
+        Path to the requested checkpoint.  If several files match the alias,
+        the one with the highest epoch number is returned.
 
     Raises
     ------
     FileNotFoundError
-        In case a required file cannot be found.
+        In case it cannot find any file on the provided path matching the given
+        specifications.
     """
-    # standard paths where checkpoints may be (if produced with this framework)
-    last_path = output_folder / "model_final_epoch.ckpt"
-    best_path = output_folder / "model_lowest_valid_loss.ckpt"
-
-    if resume_from in ("last", "best"):
-        use_file = last_path if resume_from == "last" else best_path
-        if use_file.is_file():
-            logger.info(f"Found checkpoint at `{str(use_file)}`")
-            return str(use_file)
-        else:
-            raise FileNotFoundError(
-                f"Could not find a checkpoint file at `{str(use_file)}`"
-            )
-
-    elif resume_from is None:
-        # use-case: user is re-starting a crashed/cancelled job
-        if last_path.is_file():
-            logger.info(f"Found checkpoint at `{str(last_path)}`")
-            return str(last_path)
-        else:
-            return None
-
-    elif isinstance(resume_from, str):
-        if pathlib.Path(resume_from).is_file():
-            logger.info(f"Found checkpoint at `{resume_from}`")
-            return resume_from
-        else:
-            raise FileNotFoundError(
-                f"Could not find a checkpoint file at `{resume_from}`"
-            )
+
+    template = path / (CHECKPOINT_ALIASES[alias] + CHECKPOINT_EXTENSION)
+
+    if template.exists():
+        return template
+
+    # otherwise, we see if we are looking for a template instead, in which case
+    # we must pick the latest.
+    assert "{epoch}" in str(
+        template
+    ), f"Template `{str(template)}` does not contain the keyword `{{epoch}}`"
+
+    pattern = re.compile(
+        template.name.replace("{epoch}", r"epoch=(?P<epoch>\d+)")
+    )
+    highest = -1
+    for f in template.parent.iterdir():
+        match = pattern.match(f.name)
+        if match is not None:
+            value = int(match.group("epoch"))
+            if value > highest:
+                highest = value
+
+    if highest != -1:
+        return template.with_name(
+            template.name.replace("{epoch}", f"epoch={highest}")
+        )
+
+    raise FileNotFoundError(
+        f"A file matching `{str(template)}` specifications was not found"
+    )
+
+
+def get_checkpoint_to_resume_training(
+    path: pathlib.Path,
+):
+    """Returns the checkpoint file path from which to resume training.
+
+    Parameters
+    ----------
+    path
+        The base directory which may contain a "periodic" checkpoint to
+        resume the training session from.
+
+
+    Returns
+    -------
+        Path to a checkpoint file that exists on disk
+
+
+    Raises
+    ------
+    FileNotFoundError
+        If no such checkpoint can be found in the provided directory.
+    """
+
+    return _get_checkpoint_from_alias(path, "periodic")
+
+
+def get_checkpoint_to_run_inference(
+    path: pathlib.Path,
+):
+    """Returns the best checkpoint file path to run inference with.
+
+    Parameters
+    ----------
+    path
+        The base directory which may contain the "best" or "periodic"
+        checkpoint to run inference with.
+
+
+    Returns
+    -------
+        Path to a checkpoint file that exists on disk
+
+
+    Raises
+    ------
+    FileNotFoundError
+        If none of the checkpoints can be found in the provided directory.
+    """
+
+    try:
+        return _get_checkpoint_from_alias(path, "best")
+    except FileNotFoundError:
+        logger.error(
+            "Did not find lowest-validation-loss model to run inference "
+            "from. Trying to search for the last periodically saved model..."
+        )
+
+    return _get_checkpoint_from_alias(path, "periodic")
diff --git a/tests/test_cli.py b/tests/test_cli.py
index 4d408bac..f90462c5 100644
--- a/tests/test_cli.py
+++ b/tests/test_cli.py
@@ -193,11 +193,15 @@ def test_compare_vis_help():
 @pytest.mark.skip_if_rc_var_not_set("datadir.montgomery")
 def test_train_pasa_montgomery(temporary_basedir):
     from ptbench.scripts.train import train
+    from ptbench.utils.checkpointer import (
+        CHECKPOINT_EXTENSION,
+        _get_checkpoint_from_alias,
+    )
 
     runner = CliRunner()
 
     with stdout_logging() as buf:
-        output_folder = str(temporary_basedir / "results")
+        output_folder = temporary_basedir / "results"
         result = runner.invoke(
             train,
             [
@@ -206,17 +210,17 @@ def test_train_pasa_montgomery(temporary_basedir):
                 "-vv",
                 "--epochs=1",
                 "--batch-size=1",
-                f"--output-folder={output_folder}",
+                f"--output-folder={str(output_folder)}",
             ],
         )
         _assert_exit_0(result)
 
-        assert os.path.exists(
-            os.path.join(output_folder, "model_final_epoch.ckpt")
-        )
-        assert os.path.exists(
-            os.path.join(output_folder, "model_lowest_valid_loss.ckpt")
-        )
+        # asserts checkpoints are there, or raises FileNotFoundError
+        last = _get_checkpoint_from_alias(output_folder, "periodic")
+        assert last.name.endswith("epoch=0" + CHECKPOINT_EXTENSION)
+        best = _get_checkpoint_from_alias(output_folder, "best")
+        assert best.name.endswith("epoch=0" + CHECKPOINT_EXTENSION)
+
         assert os.path.exists(os.path.join(output_folder, "constants.csv"))
         assert (
             len(
@@ -254,10 +258,14 @@
 @pytest.mark.skip_if_rc_var_not_set("datadir.montgomery")
 def test_train_pasa_montgomery_from_checkpoint(temporary_basedir):
     from ptbench.scripts.train import train
+    from ptbench.utils.checkpointer import (
+        CHECKPOINT_EXTENSION,
+        _get_checkpoint_from_alias,
+    )
 
     runner = CliRunner()
 
-    output_folder = str(temporary_basedir / "results/pasa_checkpoint")
+    output_folder = temporary_basedir / "results" / "pasa_checkpoint"
     result0 = runner.invoke(
         train,
         [
@@ -266,15 +274,17 @@
             "-vv",
             "--epochs=1",
             "--batch-size=1",
-            f"--output-folder={output_folder}",
+            f"--output-folder={str(output_folder)}",
         ],
     )
     _assert_exit_0(result0)
 
-    assert os.path.exists(os.path.join(output_folder, "model_final_epoch.ckpt"))
-    assert os.path.exists(
-        os.path.join(output_folder, "model_lowest_valid_loss.ckpt")
-    )
+    # asserts checkpoints are there, or raises FileNotFoundError
+    last = _get_checkpoint_from_alias(output_folder, "periodic")
+    assert last.name.endswith("epoch=0" + CHECKPOINT_EXTENSION)
+    best =
_get_checkpoint_from_alias(output_folder, "best") + assert best.name.endswith("epoch=0" + CHECKPOINT_EXTENSION) + assert os.path.exists(os.path.join(output_folder, "constants.csv")) assert ( len( @@ -301,12 +311,11 @@ def test_train_pasa_montgomery_from_checkpoint(temporary_basedir): ) _assert_exit_0(result) - assert os.path.exists( - os.path.join(output_folder, "model_final_epoch.ckpt") - ) - assert os.path.exists( - os.path.join(output_folder, "model_lowest_valid_loss.ckpt") - ) + # asserts checkpoints are there, or raises FileNotFoundError + last = _get_checkpoint_from_alias(output_folder, "periodic") + assert last.name.endswith("epoch=1" + CHECKPOINT_EXTENSION) + best = _get_checkpoint_from_alias(output_folder, "best") + assert os.path.exists(os.path.join(output_folder, "constants.csv")) assert ( @@ -348,11 +357,19 @@ def test_train_pasa_montgomery_from_checkpoint(temporary_basedir): @pytest.mark.skip_if_rc_var_not_set("datadir.montgomery") def test_predict_pasa_montgomery(temporary_basedir, datadir): from ptbench.scripts.predict import predict + from ptbench.utils.checkpointer import ( + CHECKPOINT_EXTENSION, + _get_checkpoint_from_alias, + ) runner = CliRunner() with stdout_logging() as buf: - output = str(temporary_basedir / "predictions") + output = temporary_basedir / "predictions" + last = _get_checkpoint_from_alias( + temporary_basedir / "results", "periodic" + ) + assert last.name.endswith("epoch=0" + CHECKPOINT_EXTENSION) result = runner.invoke( predict, [ @@ -360,7 +377,7 @@ def test_predict_pasa_montgomery(temporary_basedir, datadir): "montgomery", "-vv", "--batch-size=1", - f"--weight={str(temporary_basedir / 'results' / 'model_final_epoch.ckpt')}", + f"--weight={str(last)}", f"--output={output}", ], ) -- GitLab
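
For reference, here is a self-contained sketch (standard library only, no ptbench import) of the alias-resolution logic this patch adds to src/ptbench/utils/checkpointer.py. It assumes Lightning's default expansion of "{epoch}" into "epoch=N" file names; the `resolve` helper and the simulated file names are illustrative only, and the real `_get_checkpoint_from_alias` additionally short-circuits when a literal (un-templated) file name already exists.

    import pathlib
    import re
    import tempfile

    # same templates the patch stores in CHECKPOINT_ALIASES / CHECKPOINT_EXTENSION
    ALIASES = {
        "best": "model-at-lowest-validation-loss-{epoch}",
        "periodic": "model-at-{epoch}",
    }
    EXTENSION = ".ckpt"


    def resolve(path: pathlib.Path, alias: str) -> pathlib.Path:
        """Returns the checkpoint with the highest epoch number for ``alias``."""
        template = path / (ALIASES[alias] + EXTENSION)
        # Lightning expands "{epoch}" to "epoch=<number>" in saved file names
        pattern = re.compile(
            template.name.replace("{epoch}", r"epoch=(?P<epoch>\d+)")
        )
        epochs = [
            int(m.group("epoch"))
            for f in path.iterdir()
            if (m := pattern.match(f.name)) is not None
        ]
        if not epochs:
            raise FileNotFoundError(f"no file matching `{template}` was found")
        return template.with_name(
            template.name.replace("{epoch}", f"epoch={max(epochs)}")
        )


    if __name__ == "__main__":
        with tempfile.TemporaryDirectory() as d:
            folder = pathlib.Path(d)
            # simulate files a ModelCheckpoint callback could leave behind
            for name in (
                "model-at-epoch=0.ckpt",
                "model-at-epoch=9.ckpt",
                "model-at-lowest-validation-loss-epoch=7.ckpt",
            ):
                (folder / name).touch()
            print(resolve(folder, "periodic").name)  # model-at-epoch=9.ckpt
            print(resolve(folder, "best").name)  # model-at-lowest-validation-loss-epoch=7.ckpt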
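
Likewise, a minimal sketch of how a single validation-period value drives both Lightning's validation scheduling and the checkpoint callback configured in src/ptbench/engine/trainer.py. The `make_trainer` helper is illustrative (not part of ptbench), assumes lightning >= 2.0 (where `ModelCheckpoint` accepts `enable_version_counter`), and assumes the attached LightningModule logs a metric named "loss/validation".

    from lightning.pytorch import Trainer
    from lightning.pytorch.callbacks import ModelCheckpoint


    def make_trainer(output_folder, validation_period: int, max_epochs: int) -> Trainer:
        # one callback produces both checkpoints: the "best" one (tracked via
        # the validation loss) and the "periodic"/last one, refreshed at the
        # end of training epochs
        checkpointer = ModelCheckpoint(
            dirpath=output_folder,
            filename="model-at-lowest-validation-loss-{epoch}",  # "best" alias
            save_last=True,  # also keep the most recently trained model
            monitor="loss/validation",
            mode="min",
            save_on_train_epoch_end=True,
            every_n_epochs=validation_period,  # how often "monitor" may be re-checked
            enable_version_counter=False,  # keep stable, alias-like file names
        )
        checkpointer.CHECKPOINT_NAME_LAST = "model-at-{epoch}"  # "periodic" alias

        return Trainer(
            max_epochs=max_epochs,
            check_val_every_n_epoch=validation_period,  # validation frequency
            callbacks=[checkpointer],
        )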