diff --git a/src/ptbench/models/pasa.py b/src/ptbench/models/pasa.py
index 20bbb0dd9d06a122c81187762b8bbca0d04470fd..d6dd23ee02923490247d03b75fbf2c167aef57dd 100644
--- a/src/ptbench/models/pasa.py
+++ b/src/ptbench/models/pasa.py
@@ -185,6 +185,13 @@ class Pasa(pl.LightningModule):
 
         return x
 
+    def on_save_checkpoint(self, checkpoint):
+        checkpoint["normalizer"] = self.normalizer
+
+    def on_load_checkpoint(self, checkpoint):
+        logger.info("Restoring normalizer from checkpoint.")
+        self.normalizer = checkpoint["normalizer"]
+
     def set_normalizer(self, dataloader: torch.utils.data.DataLoader) -> None:
         """Initializes the input normalizer for the current model.
 
diff --git a/src/ptbench/scripts/train.py b/src/ptbench/scripts/train.py
index bffeebdb5d2cae835b7cd17706ebac4b93ef35fe..d026e92236a21c7663fe884bd49b73079b7b55d2 100644
--- a/src/ptbench/scripts/train.py
+++ b/src/ptbench/scripts/train.py
@@ -252,16 +252,6 @@ def train(
     datamodule.prepare_data()
     datamodule.setup(stage="fit")
 
-    # Sets the model normalizer with the unaugmented-train-subset.
-    # this call may be a NOOP, if the model was pre-trained and expects
-    # different weights for the normalisation layer.
-    if hasattr(model, "set_normalizer"):
-        model.set_normalizer(datamodule.unshuffled_train_dataloader())
-    else:
-        logger.warning(
-            f"Model {model.name} has no 'set_normalizer' method. Skipping."
-        )
-
     # If asked, rebalances the loss criterion based on the relative proportion
     # of class examples available in the training set. Also affects the
     # validation loss if a validation set is available on the data module.
@@ -276,9 +266,23 @@ def train(
     )
 
     logger.info(f"Training for at most {epochs} epochs.")
-    # We only load the checkpoint to get some information about its state. The
-    # actual loading of the model is done in trainer.fit()
-    if checkpoint_file is not None:
+
+    arguments = {}
+    arguments["max_epoch"] = epochs
+    arguments["epoch"] = 0
+
+    if checkpoint_file is None or not hasattr(model, "on_load_checkpoint"):
+        # Sets the model normalizer with the unaugmented-train-subset.
+        # this call may be a NOOP, if the model was pre-trained and expects
+        # different weights for the normalisation layer.
+        if hasattr(model, "set_normalizer"):
+            model.set_normalizer(datamodule.unshuffled_train_dataloader())
+        else:
+            logger.warning(
+                f"Model {model.name} has no 'set_normalizer' method. Skipping."
+            )
+    else:
+        # Normalizer will be loaded during model.on_load_checkpoint
         checkpoint = torch.load(checkpoint_file)
         start_epoch = checkpoint["epoch"]
         logger.info(f"Resuming from epoch {start_epoch}...")
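
For reference, a minimal, self-contained sketch of the hook round-trip this patch relies on. TinyModel and the dict-valued normalizer are hypothetical stand-ins, not ptbench code, and the Trainer is bypassed by invoking the hooks directly, the way Lightning does around torch.save/torch.load:

import pytorch_lightning as pl  # may be "lightning.pytorch" on newer versions
import torch


class TinyModel(pl.LightningModule):
    """Hypothetical stand-in for Pasa, reduced to the two checkpoint hooks."""

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(4, 2)
        self.normalizer = None  # normally computed from the train subset

    def on_save_checkpoint(self, checkpoint):
        # Lightning passes the checkpoint dict before writing it to disk,
        # so extra state can be stored next to the usual "state_dict".
        checkpoint["normalizer"] = self.normalizer

    def on_load_checkpoint(self, checkpoint):
        # Called with the full checkpoint dict on load, before resuming;
        # restoring here makes a fresh set_normalizer() pass unnecessary.
        self.normalizer = checkpoint["normalizer"]


model = TinyModel()
model.normalizer = {"mean": 0.5, "std": 0.25}  # pretend set_normalizer() ran

checkpoint = {"state_dict": model.state_dict()}
model.on_save_checkpoint(checkpoint)
torch.save(checkpoint, "ckpt.pth")

restored = TinyModel()
ckpt = torch.load("ckpt.pth")
restored.load_state_dict(ckpt["state_dict"])
restored.on_load_checkpoint(ckpt)
assert restored.normalizer == {"mean": 0.5, "std": 0.25}

This is why train.py can now skip set_normalizer() when resuming: the normalizer travels inside the checkpoint itself, so recomputing it from the unshuffled train dataloader is only needed for a fresh run (or for models that do not implement on_load_checkpoint).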