Skip to content
Snippets Groups Projects
Commit bd46b6bd authored by Daniel CARRON's avatar Daniel CARRON :b:
Browse files

Check if set_normalizer method is defined in model

parent 5f28a673
No related branches found
No related tags found
1 merge request!7Reviewed DataModule design+docs+types
......@@ -256,7 +256,12 @@ def train(
# Sets the model normalizer with the unaugmented-train-subset.
# this call may be a NOOP, if the model was pre-trained and expects
# different weights for the normalisation layer.
model.set_normalizer(datamodule.train_dataloader())
if hasattr(model, "set_normalizer"):
model.set_normalizer(datamodule.train_dataloader())
else:
logger.info(
f"Model {model.name} has no 'set_normalizer' method. No normalization will be applied."
)
# Rebalances the loss criterion based on the relative proportion of class
# examples available in the training set. Also affects the validation loss
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment