From edadec6d0e76a3c7f05b3c61f773113e4f2b2aa3 Mon Sep 17 00:00:00 2001
From: dcarron <daniel.carron@idiap.ch>
Date: Tue, 11 Jul 2023 19:24:29 +0200
Subject: [PATCH] Save and restore normalizer from checkpoint

---
 src/ptbench/models/pasa.py   |  7 +++++++
 src/ptbench/scripts/train.py | 30 +++++++++++++++++-------------
 2 files changed, 24 insertions(+), 13 deletions(-)

diff --git a/src/ptbench/models/pasa.py b/src/ptbench/models/pasa.py
index 20bbb0dd..d6dd23ee 100644
--- a/src/ptbench/models/pasa.py
+++ b/src/ptbench/models/pasa.py
@@ -185,6 +185,13 @@ class Pasa(pl.LightningModule):
 
         return x
 
+    def on_save_checkpoint(self, checkpoint):
+        checkpoint["normalizer"] = self.normalizer
+
+    def on_load_checkpoint(self, checkpoint):
+        logger.info("Restoring normalizer from checkpoint.")
+        self.normalizer = checkpoint["normalizer"]
+
     def set_normalizer(self, dataloader: torch.utils.data.DataLoader) -> None:
         """Initializes the input normalizer for the current model.
 
diff --git a/src/ptbench/scripts/train.py b/src/ptbench/scripts/train.py
index bffeebdb..d026e922 100644
--- a/src/ptbench/scripts/train.py
+++ b/src/ptbench/scripts/train.py
@@ -252,16 +252,6 @@ def train(
     datamodule.prepare_data()
     datamodule.setup(stage="fit")
 
-    # Sets the model normalizer with the unaugmented-train-subset.
-    # this call may be a NOOP, if the model was pre-trained and expects
-    # different weights for the normalisation layer.
-    if hasattr(model, "set_normalizer"):
-        model.set_normalizer(datamodule.unshuffled_train_dataloader())
-    else:
-        logger.warning(
-            f"Model {model.name} has no 'set_normalizer' method. Skipping."
-        )
-
     # If asked, rebalances the loss criterion based on the relative proportion
     # of class examples available in the training set. Also affects the
     # validation loss if a validation set is available on the data module.
@@ -276,9 +266,23 @@ def train(
     )
 
     logger.info(f"Training for at most {epochs} epochs.")
-    # We only load the checkpoint to get some information about its state. The
-    # actual loading of the model is done in trainer.fit()
-    if checkpoint_file is not None:
+
+    arguments = {}
+    arguments["max_epoch"] = epochs
+    arguments["epoch"] = 0
+
+    if checkpoint_file is None or not hasattr(model, "on_load_checkpoint"):
+        # Sets the model normalizer with the unaugmented-train-subset.
+        # this call may be a NOOP, if the model was pre-trained and expects
+        # different weights for the normalisation layer.
+        if hasattr(model, "set_normalizer"):
+            model.set_normalizer(datamodule.unshuffled_train_dataloader())
+        else:
+            logger.warning(
+                f"Model {model.name} has no 'set_normalizer' method. Skipping."
+            )
+    else:
+        # Normalizer will be loaded during model.on_load_checkpoint
         checkpoint = torch.load(checkpoint_file)
         start_epoch = checkpoint["epoch"]
         logger.info(f"Resuming from epoch {start_epoch}...")
--
GitLab
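
A note on the mechanism: the pasa.py hunk relies on two standard
LightningModule hooks, on_save_checkpoint() and on_load_checkpoint(), which
receive the checkpoint dictionary just before it is written to disk and just
after it is read back. Below is a minimal, self-contained sketch of the same
round-trip pattern; TinyModel and its Identity "normalizer" are hypothetical
stand-ins for ptbench's Pasa model, not actual ptbench code.

    import torch
    import pytorch_lightning as pl  # or: import lightning.pytorch as pl

    class TinyModel(pl.LightningModule):
        def __init__(self):
            super().__init__()
            self.layer = torch.nn.Linear(4, 2)
            # hypothetical stand-in for a normalizer fitted on training data
            self.normalizer = torch.nn.Identity()

        def forward(self, x):
            return self.layer(self.normalizer(x))

        def on_save_checkpoint(self, checkpoint):
            # stash the extra object in the checkpoint dictionary; anything
            # picklable stored here travels with the .ckpt file
            checkpoint["normalizer"] = self.normalizer

        def on_load_checkpoint(self, checkpoint):
            # runs when a checkpoint is loaded, e.g. on resume via
            # trainer.fit(model, ckpt_path=...), so the normalizer is back
            # in place before training continues
            self.normalizer = checkpoint["normalizer"]

This is also why the train.py hunk recomputes the normalizer from the
unaugmented training data only when no checkpoint is supplied: on resume,
on_load_checkpoint() restores the saved one instead.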