Commit edadec6d authored by Daniel CARRON, committed by André Anjos

Save and restore normalizer from checkpoint

parent 2fbbd899
Merge request !6: Making use of LightningDataModule and simplification of data loading
@@ -185,6 +185,13 @@ class Pasa(pl.LightningModule):
         return x
 
+    def on_save_checkpoint(self, checkpoint):
+        checkpoint["normalizer"] = self.normalizer
+
+    def on_load_checkpoint(self, checkpoint):
+        logger.info("Restoring normalizer from checkpoint.")
+        self.normalizer = checkpoint["normalizer"]
+
     def set_normalizer(self, dataloader: torch.utils.data.DataLoader) -> None:
         """Initializes the input normalizer for the current model.
...
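Lightning calls these two hooks automatically: `on_save_checkpoint` runs whenever the Trainer writes a checkpoint, and `on_load_checkpoint` runs when one is restored (e.g. when resuming via `trainer.fit(..., ckpt_path=...)` or via `load_from_checkpoint`). A minimal sketch of the round-trip follows; `TinyModel` and its `Identity` normalizer are hypothetical stand-ins for `Pasa`, not part of this commit:

    import torch
    import pytorch_lightning as pl


    class TinyModel(pl.LightningModule):
        def __init__(self):
            super().__init__()
            self.layer = torch.nn.Linear(4, 1)
            self.normalizer = torch.nn.Identity()  # stand-in input normalizer

        def training_step(self, batch, batch_idx):
            x, y = batch
            return torch.nn.functional.mse_loss(self.layer(self.normalizer(x)), y)

        def configure_optimizers(self):
            return torch.optim.SGD(self.parameters(), lr=1e-2)

        def on_save_checkpoint(self, checkpoint):
            # the whole module is pickled into the checkpoint dictionary
            checkpoint["normalizer"] = self.normalizer

        def on_load_checkpoint(self, checkpoint):
            # runs both when resuming with trainer.fit(..., ckpt_path=...)
            # and when calling TinyModel.load_from_checkpoint(...)
            self.normalizer = checkpoint["normalizer"]

Because the normalizer object is stored directly in the checkpoint dictionary, `torch.save` pickles it whole, so it comes back intact without being recomputed from the training data.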
@@ -252,16 +252,6 @@ def train(
     datamodule.prepare_data()
     datamodule.setup(stage="fit")
 
-    # Sets the model normalizer with the unaugmented-train-subset.
-    # this call may be a NOOP, if the model was pre-trained and expects
-    # different weights for the normalisation layer.
-    if hasattr(model, "set_normalizer"):
-        model.set_normalizer(datamodule.unshuffled_train_dataloader())
-    else:
-        logger.warning(
-            f"Model {model.name} has no 'set_normalizer' method. Skipping."
-        )
-
     # If asked, rebalances the loss criterion based on the relative proportion
     # of class examples available in the training set. Also affects the
     # validation loss if a validation set is available on the data module.
@@ -276,9 +266,23 @@
     )
     logger.info(f"Training for at most {epochs} epochs.")
 
-    # We only load the checkpoint to get some information about its state. The
-    # actual loading of the model is done in trainer.fit()
-    if checkpoint_file is not None:
+    arguments = {}
+    arguments["max_epoch"] = epochs
+    arguments["epoch"] = 0
+
+    if checkpoint_file is None or not hasattr(model, "on_load_checkpoint"):
+        # Sets the model normalizer with the unaugmented-train-subset.
+        # this call may be a NOOP, if the model was pre-trained and expects
+        # different weights for the normalisation layer.
+        if hasattr(model, "set_normalizer"):
+            model.set_normalizer(datamodule.unshuffled_train_dataloader())
+        else:
+            logger.warning(
+                f"Model {model.name} has no 'set_normalizer' method. Skipping."
+            )
+    else:
+        # Normalizer will be loaded during model.on_load_checkpoint
         checkpoint = torch.load(checkpoint_file)
         start_epoch = checkpoint["epoch"]
         logger.info(f"Resuming from epoch {start_epoch}...")
...
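As an aside (not part of the commit): `epoch` is one of the standard keys Lightning stores in every checkpoint, which is why `train()` can report the resume epoch with a plain `torch.load` before `trainer.fit()` performs the actual restore. A quick way to inspect such a file, assuming a hypothetical path:

    import torch

    # hypothetical path; "epoch" is a standard Lightning checkpoint key and
    # "normalizer" is the entry added by on_save_checkpoint in this commit
    checkpoint = torch.load("results/model.ckpt", map_location="cpu")
    print(checkpoint["epoch"])           # last completed training epoch
    print("normalizer" in checkpoint)    # True if saved by this model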