Skip to content
Snippets Groups Projects
Commit 9eb1e023 authored by Daniel CARRON's avatar Daniel CARRON :b: Committed by André Anjos
Browse files

Check if set_normalizer method is defined in model

parent 5bb7387f
No related branches found
No related tags found
1 merge request!6Making use of LightningDataModule and simplification of data loading
...@@ -256,7 +256,12 @@ def train( ...@@ -256,7 +256,12 @@ def train(
# Sets the model normalizer with the unaugmented-train-subset. # Sets the model normalizer with the unaugmented-train-subset.
# this call may be a NOOP, if the model was pre-trained and expects # this call may be a NOOP, if the model was pre-trained and expects
# different weights for the normalisation layer. # different weights for the normalisation layer.
model.set_normalizer(datamodule.train_dataloader()) if hasattr(model, "set_normalizer"):
model.set_normalizer(datamodule.train_dataloader())
else:
logger.info(
f"Model {model.name} has no 'set_normalizer' method. No normalization will be applied."
)
# Rebalances the loss criterion based on the relative proportion of class # Rebalances the loss criterion based on the relative proportion of class
# examples available in the training set. Also affects the validation loss # examples available in the training set. Also affects the validation loss
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment