From 9eb1e0235b935e4fe7bee02a04494482ccc9dcea Mon Sep 17 00:00:00 2001
From: dcarron <daniel.carron@idiap.ch>
Date: Mon, 3 Jul 2023 14:56:55 +0200
Subject: [PATCH] Check if set_normalizer method is defined in model

---
 src/ptbench/scripts/train.py | 7 ++++++-
 1 file changed, 6 insertions(+), 1 deletion(-)

diff --git a/src/ptbench/scripts/train.py b/src/ptbench/scripts/train.py
index 9b743d64..2ef76642 100644
--- a/src/ptbench/scripts/train.py
+++ b/src/ptbench/scripts/train.py
@@ -256,7 +256,12 @@ def train(
     # Sets the model normalizer with the unaugmented-train-subset.
     # this call may be a NOOP, if the model was pre-trained and expects
     # different weights for the normalisation layer.
-    model.set_normalizer(datamodule.train_dataloader())
+    if hasattr(model, "set_normalizer"):
+        model.set_normalizer(datamodule.train_dataloader())
+    else:
+        logger.info(
+            f"Model {model.name} has no 'set_normalizer' method. No normalization will be applied."
+        )
 
     # Rebalances the loss criterion based on the relative proportion of class
     # examples available in the training set.  Also affects the validation loss
-- 
GitLab