From b42b73d7ab22d6812c26ed2e202453e91b59a734 Mon Sep 17 00:00:00 2001
From: dcarron <daniel.carron@idiap.ch>
Date: Wed, 12 Apr 2023 10:14:18 +0200
Subject: [PATCH] Resume training from checkpoints

---
 src/ptbench/engine/trainer.py |  3 ++-
 src/ptbench/scripts/train.py  | 51 +++++++++++++++++++++++++++++------
 2 files changed, 45 insertions(+), 9 deletions(-)

diff --git a/src/ptbench/engine/trainer.py b/src/ptbench/engine/trainer.py
index ddacf63c..05b37feb 100644
--- a/src/ptbench/engine/trainer.py
+++ b/src/ptbench/engine/trainer.py
@@ -462,6 +462,7 @@ def run(
     output_folder,
     monitoring_interval,
     batch_chunk_count,
+    checkpoint,
 ):
     """Fits a CNN model using supervised learning and save it to disk.
 
@@ -560,7 +561,7 @@ def run(
             callbacks=[LoggingCallback(resource_monitor), checkpoint_callback],
         )
 
-        _ = trainer.fit(model, data_loader, valid_loader)
+        _ = trainer.fit(model, data_loader, valid_loader, ckpt_path=checkpoint)
 
     """# write static information to a CSV file
     static_logfile_name = os.path.join(output_folder, "constants.csv")
diff --git a/src/ptbench/scripts/train.py b/src/ptbench/scripts/train.py
index 12e80ae1..ea82f5a7 100644
--- a/src/ptbench/scripts/train.py
+++ b/src/ptbench/scripts/train.py
@@ -254,13 +254,6 @@ def set_reproducible_cuda():
     default=-1,
     cls=ResourceOption,
 )
-@click.option(
-    "--weight",
-    "-w",
-    help="Path or URL to pretrained model file (.pth extension)",
-    required=False,
-    cls=ResourceOption,
-)
 @click.option(
     "--normalization",
     "-n",
@@ -287,6 +280,14 @@ def set_reproducible_cuda():
     default=5.0,
     cls=ResourceOption,
 )
+@click.option(
+    "--resume-from",
+    help="Which checkpoint to resume training from. Can be one of 'None', 'best', 'last', or a path to a ckpt file.",
+    type=str,
+    required=False,
+    default=None,
+    cls=ResourceOption,
+)
 @verbosity_option(logger=logger, cls=ResourceOption, expose_value=False)
 def train(
     model,
@@ -303,9 +304,9 @@ def train(
     device,
     seed,
     parallel,
-    weight,
     normalization,
     monitoring_interval,
+    resume_from,
     **_,
 ):
     """Trains an CNN to perform tuberculosis detection.
@@ -483,6 +484,39 @@ def train(
     arguments["epoch"] = 0
     arguments["max_epoch"] = epochs
 
+    last_checkpoint_path = os.path.join(output_folder, "model_final_epoch.ckpt")
+    best_checkpoint_path = os.path.join(
+        output_folder, "model_lowest_valid_loss.ckpt"
+    )
+
+    if resume_from == "last":
+        if os.path.isfile(last_checkpoint_path):
+            checkpoint = last_checkpoint_path
+            logger.info(f"Resuming training from {resume_from} checkpoint")
+        else:
+            raise FileNotFoundError(
+                f"Could not find checkpoint {last_checkpoint_path}"
+            )
+
+    elif resume_from == "best":
+        if os.path.isfile(best_checkpoint_path):
+            checkpoint = last_checkpoint_path
+            logger.info(f"Resuming training from {resume_from} checkpoint")
+        else:
+            raise FileNotFoundError(
+                f"Could not find checkpoint {best_checkpoint_path}"
+            )
+
+    elif resume_from is None:
+        checkpoint = None
+
+    else:
+        if os.path.isfile(resume_from):
+            checkpoint = resume_from
+            logger.info(f"Resuming training from checkpoint {resume_from}")
+        else:
+            raise FileNotFoundError(f"Could not find checkpoint {resume_from}")
+
     logger.info("Training for {} epochs".format(arguments["max_epoch"]))
     logger.info("Continuing from epoch {}".format(arguments["epoch"]))
 
@@ -498,4 +532,5 @@ def train(
         output_folder=output_folder,
         monitoring_interval=monitoring_interval,
         batch_chunk_count=batch_chunk_count,
+        checkpoint=checkpoint,
     )
-- 
GitLab