From 9eb1e0235b935e4fe7bee02a04494482ccc9dcea Mon Sep 17 00:00:00 2001 From: dcarron <daniel.carron@idiap.ch> Date: Mon, 3 Jul 2023 14:56:55 +0200 Subject: [PATCH] Check if set_normalizer method is defined in model --- src/ptbench/scripts/train.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/ptbench/scripts/train.py b/src/ptbench/scripts/train.py index 9b743d64..2ef76642 100644 --- a/src/ptbench/scripts/train.py +++ b/src/ptbench/scripts/train.py @@ -256,7 +256,12 @@ def train( # Sets the model normalizer with the unaugmented-train-subset. # this call may be a NOOP, if the model was pre-trained and expects # different weights for the normalisation layer. - model.set_normalizer(datamodule.train_dataloader()) + if hasattr(model, "set_normalizer"): + model.set_normalizer(datamodule.train_dataloader()) + else: + logger.info( + f"Model {model.name} has no 'set_normalizer' method. No normalization will be applied." + ) # Rebalances the loss criterion based on the relative proportion of class # examples available in the training set. Also affects the validation loss -- GitLab