From 7aaec166ef1c3833863f38dd8104d4f028bce239 Mon Sep 17 00:00:00 2001 From: dcarron <daniel.carron@idiap.ch> Date: Wed, 8 May 2024 11:26:00 +0200 Subject: [PATCH] [normalizer] Ensure targets are int type --- src/mednet/libs/common/models/normalizer.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/src/mednet/libs/common/models/normalizer.py b/src/mednet/libs/common/models/normalizer.py index fc2992c5..63a5f10c 100644 --- a/src/mednet/libs/common/models/normalizer.py +++ b/src/mednet/libs/common/models/normalizer.py @@ -3,12 +3,16 @@ # SPDX-License-Identifier: GPL-3.0-or-later """Functions to compute normalisation factors based on dataloaders.""" +import logging + import torch import torch.nn import torch.utils.data import torchvision.transforms import tqdm +logger = logging.getLogger("mednet") + def make_z_normalizer( dataloader: torch.utils.data.DataLoader, @@ -31,6 +35,19 @@ def make_z_normalizer( # Peek the number of channels of batches in the data loader batch = next(iter(dataloader)) + + # Ensure targets are ints + try: + target = batch[1]["label"][0].item() + if not isinstance(target, int): + logger.info( + "Targets are not Integer type, skipping z-normalization." + ) + return None + except RuntimeError: + logger.info("Targets are not Integer type, skipping z-normalization.") + return None + channels = batch[0].shape[1] # Initialises accumulators -- GitLab