Commit edadec6d authored by Daniel CARRON, committed by André Anjos

Save and restore normalizer from checkpoint

parent 2fbbd899
1 merge request: !6 Making use of LightningDataModule and simplification of data loading
@@ -185,6 +185,13 @@ class Pasa(pl.LightningModule):

         return x

+    def on_save_checkpoint(self, checkpoint):
+        checkpoint["normalizer"] = self.normalizer
+
+    def on_load_checkpoint(self, checkpoint):
+        logger.info("Restoring normalizer from checkpoint.")
+        self.normalizer = checkpoint["normalizer"]
+
     def set_normalizer(self, dataloader: torch.utils.data.DataLoader) -> None:
         """Initializes the input normalizer for the current model.
...
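Lightning passes the checkpoint dictionary to these hooks when writing and reading checkpoint files, so the extra "normalizer" entry is persisted alongside the regular model state and restored before training resumes. A minimal sketch of inspecting a saved checkpoint offline (the file path below is purely illustrative, not part of this commit):

import torch

# Load a Lightning checkpoint produced after this change; the path is a
# placeholder, not something defined by this repository.
checkpoint = torch.load("results/model_final.ckpt", map_location="cpu")

# "state_dict" and "epoch" are written by Lightning itself; "normalizer" is
# the extra entry added by Pasa.on_save_checkpoint() above.
print(sorted(checkpoint.keys()))
normalizer = checkpoint["normalizer"]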
@@ -252,16 +252,6 @@ def train(

     datamodule.prepare_data()
     datamodule.setup(stage="fit")

-    # Sets the model normalizer with the unaugmented-train-subset.
-    # this call may be a NOOP, if the model was pre-trained and expects
-    # different weights for the normalisation layer.
-    if hasattr(model, "set_normalizer"):
-        model.set_normalizer(datamodule.unshuffled_train_dataloader())
-    else:
-        logger.warning(
-            f"Model {model.name} has no 'set_normalizer' method. Skipping."
-        )
-
     # If asked, rebalances the loss criterion based on the relative proportion
     # of class examples available in the training set. Also affects the
     # validation loss if a validation set is available on the data module.
@@ -276,9 +266,23 @@ def train(
     )

     logger.info(f"Training for at most {epochs} epochs.")

-    # We only load the checkpoint to get some information about its state. The
-    # actual loading of the model is done in trainer.fit()
-    if checkpoint_file is not None:
+    arguments = {}
+    arguments["max_epoch"] = epochs
+    arguments["epoch"] = 0
+
+    if checkpoint_file is None or not hasattr(model, "on_load_checkpoint"):
+        # Sets the model normalizer with the unaugmented-train-subset.
+        # this call may be a NOOP, if the model was pre-trained and expects
+        # different weights for the normalisation layer.
+        if hasattr(model, "set_normalizer"):
+            model.set_normalizer(datamodule.unshuffled_train_dataloader())
+        else:
+            logger.warning(
+                f"Model {model.name} has no 'set_normalizer' method. Skipping."
+            )
+    else:
+        # Normalizer will be loaded during model.on_load_checkpoint
         checkpoint = torch.load(checkpoint_file)
         start_epoch = checkpoint["epoch"]
         logger.info(f"Resuming from epoch {start_epoch}...")
...
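The body of set_normalizer() is elided above. As a rough sketch only (the helper name, batch layout, and use of torchvision are assumptions, not taken from this commit), such a normalizer could be built from per-channel statistics of the unaugmented training loader and then assigned to self.normalizer, which is exactly the attribute the new checkpoint hooks save and restore:

import torch
import torchvision.transforms

def make_z_normalizer(
    dataloader: torch.utils.data.DataLoader,
) -> torchvision.transforms.Normalize:
    """Hypothetical helper: computes per-channel z-normalization statistics.

    Assumes every batch is an (images, targets) pair with images shaped
    (N, C, H, W).
    """
    total = None
    squared = None
    count = 0
    for images, _ in dataloader:
        if total is None:
            total = torch.zeros(images.shape[1])
            squared = torch.zeros(images.shape[1])
        total += images.sum(dim=[0, 2, 3])
        squared += images.pow(2).sum(dim=[0, 2, 3])
        count += images.shape[0] * images.shape[2] * images.shape[3]
    mean = total / count
    std = (squared / count - mean.pow(2)).clamp(min=0.0).sqrt()
    return torchvision.transforms.Normalize(mean, std)

# Inside the model, set_normalizer() would then reduce to something like:
#     self.normalizer = make_z_normalizer(dataloader)

Storing the fitted transform in the checkpoint, rather than recomputing it, keeps resumed runs consistent with the data statistics seen during the original fit.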