From edadec6d0e76a3c7f05b3c61f773113e4f2b2aa3 Mon Sep 17 00:00:00 2001
From: dcarron <daniel.carron@idiap.ch>
Date: Tue, 11 Jul 2023 19:24:29 +0200
Subject: [PATCH] Save and restore normalizer from checkpoint

---
 src/ptbench/models/pasa.py   |  7 +++++++
 src/ptbench/scripts/train.py | 30 +++++++++++++++++-------------
 2 files changed, 24 insertions(+), 13 deletions(-)

diff --git a/src/ptbench/models/pasa.py b/src/ptbench/models/pasa.py
index 20bbb0dd..d6dd23ee 100644
--- a/src/ptbench/models/pasa.py
+++ b/src/ptbench/models/pasa.py
@@ -185,6 +185,13 @@ class Pasa(pl.LightningModule):
 
         return x
 
+    def on_save_checkpoint(self, checkpoint):
+        checkpoint["normalizer"] = self.normalizer
+
+    def on_load_checkpoint(self, checkpoint):
+        logger.info("Restoring normalizer from checkpoint.")
+        self.normalizer = checkpoint["normalizer"]
+
     def set_normalizer(self, dataloader: torch.utils.data.DataLoader) -> None:
         """Initializes the input normalizer for the current model.
 
diff --git a/src/ptbench/scripts/train.py b/src/ptbench/scripts/train.py
index bffeebdb..d026e922 100644
--- a/src/ptbench/scripts/train.py
+++ b/src/ptbench/scripts/train.py
@@ -252,16 +252,6 @@ def train(
     datamodule.prepare_data()
     datamodule.setup(stage="fit")
 
-    # Sets the model normalizer with the unaugmented-train-subset.
-    # this call may be a NOOP, if the model was pre-trained and expects
-    # different weights for the normalisation layer.
-    if hasattr(model, "set_normalizer"):
-        model.set_normalizer(datamodule.unshuffled_train_dataloader())
-    else:
-        logger.warning(
-            f"Model {model.name} has no 'set_normalizer' method. Skipping."
-        )
-
     # If asked, rebalances the loss criterion based on the relative proportion
     # of class examples available in the training set.  Also affects the
     # validation loss if a validation set is available on the data module.
@@ -276,9 +266,23 @@ def train(
         )
 
     logger.info(f"Training for at most {epochs} epochs.")
-    # We only load the checkpoint to get some information about its state. The
-    # actual loading of the model is done in trainer.fit()
-    if checkpoint_file is not None:
+
+    arguments = {}
+    arguments["max_epoch"] = epochs
+    arguments["epoch"] = 0
+
+    if checkpoint_file is None or not hasattr(model, "on_load_checkpoint"):
+        # Sets the model normalizer with the unaugmented-train-subset.
+        # This call may be a NOOP if the model was pre-trained and expects
+        # different weights for the normalisation layer.
+        if hasattr(model, "set_normalizer"):
+            model.set_normalizer(datamodule.unshuffled_train_dataloader())
+        else:
+            logger.warning(
+                f"Model {model.name} has no 'set_normalizer' method. Skipping."
+            )
+    else:
+        # Normalizer will be loaded during model.on_load_checkpoint
         checkpoint = torch.load(checkpoint_file)
         start_epoch = checkpoint["epoch"]
         logger.info(f"Resuming from epoch {start_epoch}...")
-- 
GitLab