Skip to content
Snippets Groups Projects
Commit 7aaec166 authored by Daniel CARRON's avatar Daniel CARRON :b: Committed by André Anjos
Browse files

[normalizer] Ensure targets are int type

parent 6dc6aa6a
No related branches found
No related tags found
1 merge request!46Create common library
......@@ -3,12 +3,16 @@
# SPDX-License-Identifier: GPL-3.0-or-later
"""Functions to compute normalisation factors based on dataloaders."""
import logging
import torch
import torch.nn
import torch.utils.data
import torchvision.transforms
import tqdm
logger = logging.getLogger("mednet")
def make_z_normalizer(
dataloader: torch.utils.data.DataLoader,
......@@ -31,6 +35,19 @@ def make_z_normalizer(
# Peek the number of channels of batches in the data loader
batch = next(iter(dataloader))
# Ensure targets are ints
try:
target = batch[1]["label"][0].item()
if not isinstance(target, int):
logger.info(
"Targets are not Integer type, skipping z-normalization."
)
return None
except RuntimeError:
logger.info("Targets are not Integer type, skipping z-normalization.")
return None
channels = batch[0].shape[1]
# Initialises accumulators
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment