From 7aaec166ef1c3833863f38dd8104d4f028bce239 Mon Sep 17 00:00:00 2001
From: dcarron <daniel.carron@idiap.ch>
Date: Wed, 8 May 2024 11:26:00 +0200
Subject: [PATCH] [normalizer] Ensure targets are int type

---
 src/mednet/libs/common/models/normalizer.py | 17 +++++++++++++++++
 1 file changed, 17 insertions(+)

diff --git a/src/mednet/libs/common/models/normalizer.py b/src/mednet/libs/common/models/normalizer.py
index fc2992c5..63a5f10c 100644
--- a/src/mednet/libs/common/models/normalizer.py
+++ b/src/mednet/libs/common/models/normalizer.py
@@ -3,12 +3,16 @@
 # SPDX-License-Identifier: GPL-3.0-or-later
 """Functions to compute normalisation factors based on dataloaders."""
 
+import logging
+
 import torch
 import torch.nn
 import torch.utils.data
 import torchvision.transforms
 import tqdm
 
+logger = logging.getLogger("mednet")
+
 
 def make_z_normalizer(
     dataloader: torch.utils.data.DataLoader,
@@ -31,6 +35,19 @@ def make_z_normalizer(
 
     # Peek the number of channels of batches in the data loader
     batch = next(iter(dataloader))
+
+    # Ensure targets are ints
+    try:
+        target = batch[1]["label"][0].item()
+        if not isinstance(target, int):
+            logger.info(
+                "Targets are not Integer type, skipping z-normalization."
+            )
+            return None
+    except RuntimeError:
+        logger.info("Targets are not Integer type, skipping z-normalization.")
+        return None
+
     channels = batch[0].shape[1]
 
     # Initialises accumulators
-- 
GitLab