From 1f795f62bf288b03490a0b23436d4d27f71c36e5 Mon Sep 17 00:00:00 2001 From: dcarron <daniel.carron@idiap.ch> Date: Mon, 17 Jul 2023 10:06:11 +0200 Subject: [PATCH] Resume from last checkpoint by default if exists in output_folder --- src/ptbench/utils/checkpointer.py | 21 +++++++++++++++------ 1 file changed, 15 insertions(+), 6 deletions(-) diff --git a/src/ptbench/utils/checkpointer.py b/src/ptbench/utils/checkpointer.py index 81516d28..5c2f272c 100644 --- a/src/ptbench/utils/checkpointer.py +++ b/src/ptbench/utils/checkpointer.py @@ -1,29 +1,34 @@ import logging import os +import typing logger = logging.getLogger(__name__) -def get_checkpoint(output_folder, resume_from): +def get_checkpoint(output_folder: str, resume_from: typing.Literal["last", "best"] | str | None) -> str | None : """Gets a checkpoint file. Can return the best or last checkpoint, or a checkpoint at a specific path. Ensures the checkpoint exists, raising an error if it is not the case. + If resume_from is None, checks whether a checkpoint already exists in the output directory and returns it. + If no checkpoint is found, returns None. + Parameters ---------- - output_folder : :py:class:`str` + output_folder: Directory in which checkpoints are stored. - resume_from : :py:class:`str` + resume_from: Which model to get. Can be one of "best", "last", or a path to a checkpoint. + If None, gets the last checkpoint if it exists, otherwise returns None. Returns ------- - checkpoint_file : :py:class:`str` - The requested model. + checkpoint_file: + Path to the requested checkpoint or None. """ last_checkpoint_path = os.path.join(output_folder, "model_final_epoch.ckpt") best_checkpoint_path = os.path.join( @@ -49,7 +54,11 @@ def get_checkpoint(output_folder, resume_from): ) elif resume_from is None: - checkpoint_file = None + if os.path.isfile(last_checkpoint_path): + checkpoint_file = last_checkpoint_path + logger.info(f"Found existing checkpoint {last_checkpoint_path}. Loading.") + else: + return None else: if os.path.isfile(resume_from): -- GitLab