diff --git a/src/mednet/libs/common/models/normalizer.py b/src/mednet/libs/common/models/normalizer.py index fc2992c57bb336ec9625152e4054211293cfee25..63a5f10c67f369a34e1f7ad04ef92b3ee649b2fb 100644 --- a/src/mednet/libs/common/models/normalizer.py +++ b/src/mednet/libs/common/models/normalizer.py @@ -3,12 +3,16 @@ # SPDX-License-Identifier: GPL-3.0-or-later """Functions to compute normalisation factors based on dataloaders.""" +import logging + import torch import torch.nn import torch.utils.data import torchvision.transforms import tqdm +logger = logging.getLogger("mednet") + def make_z_normalizer( dataloader: torch.utils.data.DataLoader, @@ -31,6 +35,19 @@ def make_z_normalizer( # Peek the number of channels of batches in the data loader batch = next(iter(dataloader)) + + # Ensure targets are ints + try: + target = batch[1]["label"][0].item() + if not isinstance(target, int): + logger.info( + "Targets are not Integer type, skipping z-normalization." + ) + return None + except RuntimeError: + logger.info("Targets are not Integer type, skipping z-normalization.") + return None + channels = batch[0].shape[1] # Initialises accumulators