From 1f795f62bf288b03490a0b23436d4d27f71c36e5 Mon Sep 17 00:00:00 2001
From: dcarron <daniel.carron@idiap.ch>
Date: Mon, 17 Jul 2023 10:06:11 +0200
Subject: [PATCH] Resume from last checkpoint by default if exists in
 output_folder

---
 src/ptbench/utils/checkpointer.py | 21 +++++++++++++++------
 1 file changed, 15 insertions(+), 6 deletions(-)

diff --git a/src/ptbench/utils/checkpointer.py b/src/ptbench/utils/checkpointer.py
index 81516d28..5c2f272c 100644
--- a/src/ptbench/utils/checkpointer.py
+++ b/src/ptbench/utils/checkpointer.py
@@ -1,29 +1,34 @@
 import logging
 import os
 
+import typing
 logger = logging.getLogger(__name__)
 
 
-def get_checkpoint(output_folder, resume_from):
+def get_checkpoint(output_folder: str, resume_from: typing.Literal["last", "best"] | str | None) -> str | None : 
     """Gets a checkpoint file.
 
     Can return the best or last checkpoint, or a checkpoint at a specific path.
     Ensures the checkpoint exists, raising an error if it is not the case.
 
+    If resume_from is None, checks the output directory if a checkpoint already exists and returns it.
+    If no checkpoint is found, returns None.
+
     Parameters
     ----------
 
-    output_folder : :py:class:`str`
+    output_folder:
         Directory in which checkpoints are stored.
 
-    resume_from : :py:class:`str`
+    resume_from:
         Which model to get. Can be one of "best", "last", or a path to a checkpoint.
+        If None, gets the last checkpoint if it exists, otherwise returns None
 
     Returns
     -------
 
-    checkpoint_file : :py:class:`str`
-        The requested model.
+    checkpoint_file:
+        Path to the requested checkpoint or None.
     """
     last_checkpoint_path = os.path.join(output_folder, "model_final_epoch.ckpt")
     best_checkpoint_path = os.path.join(
@@ -49,7 +54,11 @@ def get_checkpoint(output_folder, resume_from):
             )
 
     elif resume_from is None:
-        checkpoint_file = None
+        if os.path.isfile(last_checkpoint_path):
+            checkpoint_file = last_checkpoint_path
+            logger.info(f"Found existing checkpoint {last_checkpoint_path}. Loading.")
+        else:
+            return None
 
     else:
         if os.path.isfile(resume_from):
-- 
GitLab