Skip to content
Snippets Groups Projects
Commit 7aaec166 authored by Daniel CARRON's avatar Daniel CARRON :b: Committed by André Anjos
Browse files

[normalizer] Ensure targets are int type

parent 6dc6aa6a
No related branches found
No related tags found
1 merge request!46Create common library
...@@ -3,12 +3,16 @@ ...@@ -3,12 +3,16 @@
# SPDX-License-Identifier: GPL-3.0-or-later # SPDX-License-Identifier: GPL-3.0-or-later
"""Functions to compute normalisation factors based on dataloaders.""" """Functions to compute normalisation factors based on dataloaders."""
import logging
import torch import torch
import torch.nn import torch.nn
import torch.utils.data import torch.utils.data
import torchvision.transforms import torchvision.transforms
import tqdm import tqdm
logger = logging.getLogger("mednet")
def make_z_normalizer( def make_z_normalizer(
dataloader: torch.utils.data.DataLoader, dataloader: torch.utils.data.DataLoader,
...@@ -31,6 +35,19 @@ def make_z_normalizer( ...@@ -31,6 +35,19 @@ def make_z_normalizer(
# Peek the number of channels of batches in the data loader # Peek the number of channels of batches in the data loader
batch = next(iter(dataloader)) batch = next(iter(dataloader))
# Ensure targets are ints
try:
target = batch[1]["label"][0].item()
if not isinstance(target, int):
logger.info(
"Targets are not Integer type, skipping z-normalization."
)
return None
except RuntimeError:
logger.info("Targets are not Integer type, skipping z-normalization.")
return None
channels = batch[0].shape[1] channels = batch[0].shape[1]
# Initialises accumulators # Initialises accumulators
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment