diff --git a/src/ptbench/scripts/train.py b/src/ptbench/scripts/train.py index 9b743d64cdeae69cb50e65fedbe169325f353764..2ef766421b1ffe34d6bbfe9caee60db5f00d256f 100644 --- a/src/ptbench/scripts/train.py +++ b/src/ptbench/scripts/train.py @@ -256,7 +256,12 @@ def train( # Sets the model normalizer with the unaugmented-train-subset. # this call may be a NOOP, if the model was pre-trained and expects # different weights for the normalisation layer. - model.set_normalizer(datamodule.train_dataloader()) + if hasattr(model, "set_normalizer"): + model.set_normalizer(datamodule.train_dataloader()) + else: + logger.info( + f"Model {model.name} has no 'set_normalizer' method. No normalization will be applied." + ) # Rebalances the loss criterion based on the relative proportion of class # examples available in the training set. Also affects the validation loss