From bf25f8063976026f1bbcd3a7ebce015b1b9dedf4 Mon Sep 17 00:00:00 2001
From: dcarron <daniel.carron@idiap.ch>
Date: Wed, 12 Apr 2023 15:57:58 +0200
Subject: [PATCH] Get epoch from checkpoint and move checkpoint selection
 code to a new checkpointer module

---
 src/ptbench/scripts/train.py      | 45 ++++++++------------------------
 src/ptbench/utils/checkpointer.py | 47 +++++++++++++++++++++++++++++++++
 2 files changed, 58 insertions(+), 34 deletions(-)
 create mode 100644 src/ptbench/utils/checkpointer.py

diff --git a/src/ptbench/scripts/train.py b/src/ptbench/scripts/train.py
index 6003eaff..f5bd3a7a 100644
--- a/src/ptbench/scripts/train.py
+++ b/src/ptbench/scripts/train.py
@@ -2,14 +2,15 @@
 #
 # SPDX-License-Identifier: GPL-3.0-or-later
 
-import os
-
 import click
+import torch
 
 from clapper.click import ConfigCommand, ResourceOption, verbosity_option
 from clapper.logging import setup
 from pytorch_lightning import seed_everything
 
+from ..utils.checkpointer import get_checkpoint
+
 logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
 
 
@@ -432,41 +433,17 @@ def train(
 
     arguments = {}
     arguments["max_epoch"] = epochs
+    arguments["epoch"] = 0
 
-    last_checkpoint_path = os.path.join(output_folder, "model_final_epoch.ckpt")
-    best_checkpoint_path = os.path.join(
-        output_folder, "model_lowest_valid_loss.ckpt"
-    )
-
-    if resume_from == "last":
-        if os.path.isfile(last_checkpoint_path):
-            checkpoint = last_checkpoint_path
-            logger.info(f"Resuming training from {resume_from} checkpoint")
-        else:
-            raise FileNotFoundError(
-                f"Could not find checkpoint {last_checkpoint_path}"
-            )
-
-    elif resume_from == "best":
-        if os.path.isfile(best_checkpoint_path):
-            checkpoint = last_checkpoint_path
-            logger.info(f"Resuming training from {resume_from} checkpoint")
-        else:
-            raise FileNotFoundError(
-                f"Could not find checkpoint {best_checkpoint_path}"
-            )
-
-    elif resume_from is None:
-        checkpoint = None
+    checkpoint_file = get_checkpoint(output_folder, resume_from)
 
-    else:
-        if os.path.isfile(resume_from):
-            checkpoint = resume_from
-            logger.info(f"Resuming training from checkpoint {resume_from}")
-        else:
-            raise FileNotFoundError(f"Could not find checkpoint {resume_from}")
+    # The checkpoint is only loaded here to recover the last epoch; the model state itself is restored by trainer.fit()
+    if checkpoint_file is not None:
+        checkpoint = torch.load(checkpoint_file)
+        arguments["epoch"] = checkpoint["epoch"]
 
     logger.info("Training for {} epochs".format(arguments["max_epoch"]))
+    logger.info("Continuing from epoch {}".format(arguments["epoch"]))
 
     run(
         model=model,
@@ -479,5 +456,5 @@ def train(
         output_folder=output_folder,
         monitoring_interval=monitoring_interval,
         batch_chunk_count=batch_chunk_count,
-        checkpoint=checkpoint,
+        checkpoint=checkpoint_file,
     )
diff --git a/src/ptbench/utils/checkpointer.py b/src/ptbench/utils/checkpointer.py
new file mode 100644
index 00000000..a30cba94
--- /dev/null
+++ b/src/ptbench/utils/checkpointer.py
@@ -0,0 +1,47 @@
+import logging
+import os
+
+logger = logging.getLogger(__name__)
+
+
+def get_checkpoint(output_folder, resume_from):
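+    """Gets the path to the checkpoint to resume training from.
+
+    ``resume_from`` may be ``"last"``, ``"best"``, ``None`` (to train from
+    scratch) or the path to an existing checkpoint file. A
+    ``FileNotFoundError`` is raised if the requested checkpoint is missing.
+    """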
+    last_checkpoint_path = os.path.join(output_folder, "model_final_epoch.ckpt")
+    best_checkpoint_path = os.path.join(
+        output_folder, "model_lowest_valid_loss.ckpt"
+    )
+
+    if resume_from == "last":
+        if os.path.isfile(last_checkpoint_path):
+            checkpoint_file = last_checkpoint_path
+            logger.info(f"Resuming training from {resume_from} checkpoint")
+        else:
+            raise FileNotFoundError(
+                f"Could not find checkpoint {last_checkpoint_path}"
+            )
+
+    elif resume_from == "best":
+        if os.path.isfile(best_checkpoint_path):
+            checkpoint_file = best_checkpoint_path
+            logger.info(f"Resuming training from {resume_from} checkpoint")
+        else:
+            raise FileNotFoundError(
+                f"Could not find checkpoint {best_checkpoint_path}"
+            )
+
+    elif resume_from is None:
+        checkpoint_file = None
+
+    else:
+        if os.path.isfile(resume_from):
+            checkpoint_file = resume_from
+            logger.info(f"Resuming training from checkpoint {resume_from}")
+        else:
+            raise FileNotFoundError(f"Could not find checkpoint {resume_from}")
+
+    return checkpoint_file
-- 
GitLab