From da13b5b0f46f9d1aa42f948996f2b94541d4033b Mon Sep 17 00:00:00 2001
From: dcarron <daniel.carron@idiap.ch>
Date: Wed, 26 Jun 2024 10:14:23 +0200
Subject: [PATCH] [normalizer] Move normalizer out of common package and
 inside libs

---
 .../models/normalizer.py                      |  0
 src/mednet/libs/classification/models/pasa.py | 17 +++++
 src/mednet/libs/common/models/model.py        | 16 +---
 .../config/data/drive/datamodule.py           |  6 +-
 src/mednet/libs/segmentation/models/driu.py   |  2 +-
 .../libs/segmentation/models/driu_bn.py       |  2 +-
 .../libs/segmentation/models/driu_od.py       |  2 +-
 .../libs/segmentation/models/driu_pix.py      |  2 +-
 src/mednet/libs/segmentation/models/hed.py    |  2 +-
 src/mednet/libs/segmentation/models/lwnet.py  | 20 +++++
 src/mednet/libs/segmentation/models/m2unet.py |  2 +-
 .../libs/segmentation/models/normalizer.py    | 75 +++++++++++++++++++
 src/mednet/libs/segmentation/models/unet.py   |  2 +-
 13 files changed, 123 insertions(+), 25 deletions(-)
 rename src/mednet/libs/{common => classification}/models/normalizer.py (100%)
 create mode 100644 src/mednet/libs/segmentation/models/normalizer.py

diff --git a/src/mednet/libs/common/models/normalizer.py b/src/mednet/libs/classification/models/normalizer.py
similarity index 100%
rename from src/mednet/libs/common/models/normalizer.py
rename to src/mednet/libs/classification/models/normalizer.py
diff --git a/src/mednet/libs/classification/models/pasa.py b/src/mednet/libs/classification/models/pasa.py
index 0b5c24da..313e1b54 100644
--- a/src/mednet/libs/classification/models/pasa.py
+++ b/src/mednet/libs/classification/models/pasa.py
@@ -192,6 +192,23 @@ class Pasa(Model):
 
         # x = F.log_softmax(x, dim=1) # 0 is batch size
 
+    def set_normalizer(self, dataloader: torch.utils.data.DataLoader) -> None:
+        """Initialize the input normalizer for the current model.
+
+        Parameters
+        ----------
+        dataloader
+            A torch Dataloader from which to compute the mean and std.
+        """
+
+        from .normalizer import make_z_normalizer
+
+        logger.info(
+            f"Uninitialised {self.name} model - "
+            f"computing z-norm factors from train dataloader.",
+        )
+        self.normalizer = make_z_normalizer(dataloader)
+
     def training_step(self, batch, _):
         images = batch[0]
         labels = batch[1]["target"]
diff --git a/src/mednet/libs/common/models/model.py b/src/mednet/libs/common/models/model.py
index 4dc98b1d..cc9d5d84 100644
--- a/src/mednet/libs/common/models/model.py
+++ b/src/mednet/libs/common/models/model.py
@@ -130,21 +130,7 @@ class Model(pl.LightningModule):
             self.normalizer = checkpoint["normalizer"]
 
     def set_normalizer(self, dataloader: torch.utils.data.DataLoader) -> None:
-        """Initialize the input normalizer for the current model.
-
-        Parameters
-        ----------
-        dataloader
-            A torch Dataloader from which to compute the mean and std.
-        """
-
-        from .normalizer import make_z_normalizer
-
-        logger.info(
-            f"Uninitialised {self.name} model - "
-            f"computing z-norm factors from train dataloader.",
-        )
-        self.normalizer = make_z_normalizer(dataloader)
+        raise NotImplementedError
 
     def training_step(self, batch, _):
         raise NotImplementedError
diff --git a/src/mednet/libs/segmentation/config/data/drive/datamodule.py b/src/mednet/libs/segmentation/config/data/drive/datamodule.py
index 06014c81..17629343 100644
--- a/src/mednet/libs/segmentation/config/data/drive/datamodule.py
+++ b/src/mednet/libs/segmentation/config/data/drive/datamodule.py
@@ -57,11 +57,11 @@ class SegmentationRawDataLoader(_SegmentationRawDataLoader):
             PIL.Image.open(self.datadir / sample[2]).convert(mode="1", dither=None)
         )
 
-        tensor = tv_tensors.Image(crop_image_to_mask(image, mask))
-        target = tv_tensors.Image(crop_image_to_mask(target, mask))
+        image = tv_tensors.Image(crop_image_to_mask(image, mask))
+        target = tv_tensors.Mask(crop_image_to_mask(target, mask))
         mask = tv_tensors.Mask(crop_image_to_mask(mask, mask))
 
-        return tensor, dict(target=target, mask=mask, name=sample[0])  # type: ignore[arg-type]
+        return dict(image=image, target=target, mask=mask), dict(name=sample[0])  # type: ignore[arg-type]
 
 
 class DataModule(CachingDataModule):
diff --git a/src/mednet/libs/segmentation/models/driu.py b/src/mednet/libs/segmentation/models/driu.py
index eac0afc0..f3a2e0d2 100644
--- a/src/mednet/libs/segmentation/models/driu.py
+++ b/src/mednet/libs/segmentation/models/driu.py
@@ -149,7 +149,7 @@ class DRIU(Model):
             Will not be used if the model is pretrained.
         """
         if self.pretrained:
-            from mednet.libs.common.models.normalizer import make_imagenet_normalizer
+            from .normalizer import make_imagenet_normalizer
 
             logger.warning(
                 f"ImageNet pre-trained {self.name} model - NOT "
diff --git a/src/mednet/libs/segmentation/models/driu_bn.py b/src/mednet/libs/segmentation/models/driu_bn.py
index 7804ed06..3bb93ba9 100644
--- a/src/mednet/libs/segmentation/models/driu_bn.py
+++ b/src/mednet/libs/segmentation/models/driu_bn.py
@@ -152,7 +152,7 @@ class DRIUBN(Model):
             Will not be used if the model is pretrained.
         """
         if self.pretrained:
-            from mednet.libs.common.models.normalizer import make_imagenet_normalizer
+            from .normalizer import make_imagenet_normalizer
 
             logger.warning(
                 f"ImageNet pre-trained {self.name} model - NOT "
diff --git a/src/mednet/libs/segmentation/models/driu_od.py b/src/mednet/libs/segmentation/models/driu_od.py
index be7416c7..98e59623 100644
--- a/src/mednet/libs/segmentation/models/driu_od.py
+++ b/src/mednet/libs/segmentation/models/driu_od.py
@@ -134,7 +134,7 @@ class DRIUOD(Model):
             Will not be used if the model is pretrained.
        """
         if self.pretrained:
-            from mednet.libs.common.models.normalizer import make_imagenet_normalizer
+            from .normalizer import make_imagenet_normalizer
 
             logger.warning(
                 f"ImageNet pre-trained {self.name} model - NOT "
diff --git a/src/mednet/libs/segmentation/models/driu_pix.py b/src/mednet/libs/segmentation/models/driu_pix.py
index 23218a57..6846da5a 100644
--- a/src/mednet/libs/segmentation/models/driu_pix.py
+++ b/src/mednet/libs/segmentation/models/driu_pix.py
@@ -138,7 +138,7 @@ class DRIUPix(Model):
             Will not be used if the model is pretrained.
         """
         if self.pretrained:
-            from mednet.libs.common.models.normalizer import make_imagenet_normalizer
+            from .normalizer import make_imagenet_normalizer
 
             logger.warning(
                 f"ImageNet pre-trained {self.name} model - NOT "
diff --git a/src/mednet/libs/segmentation/models/hed.py b/src/mednet/libs/segmentation/models/hed.py
index 8771f558..7e0b7705 100644
--- a/src/mednet/libs/segmentation/models/hed.py
+++ b/src/mednet/libs/segmentation/models/hed.py
@@ -153,7 +153,7 @@ class HED(Model):
             Will not be used if the model is pretrained.
         """
         if self.pretrained:
-            from mednet.libs.common.models.normalizer import make_imagenet_normalizer
+            from .normalizer import make_imagenet_normalizer
 
             logger.warning(
                 f"ImageNet pre-trained {self.name} model - NOT "
diff --git a/src/mednet/libs/segmentation/models/lwnet.py b/src/mednet/libs/segmentation/models/lwnet.py
index 90507e4d..28bdf498 100644
--- a/src/mednet/libs/segmentation/models/lwnet.py
+++ b/src/mednet/libs/segmentation/models/lwnet.py
@@ -15,6 +15,7 @@ guide segmentation.
 Reference: [GALDRAN-2020]_
 """
 
+import logging
 import typing
 
 import torch
@@ -23,6 +24,8 @@ from mednet.libs.common.data.typing import TransformSequence
 from mednet.libs.common.models.model import Model
 from mednet.libs.segmentation.models.losses import MultiWeightedBCELogitsLoss
 
+logger = logging.getLogger("mednet")
+
 
 def _conv1x1(in_planes, out_planes, stride=1):
     return torch.nn.Conv2d(
@@ -338,6 +341,23 @@ class LittleWNet(Model):
             shortcut=True,
         )
 
+    def set_normalizer(self, dataloader: torch.utils.data.DataLoader) -> None:
+        """Initialize the input normalizer for the current model.
+
+        Parameters
+        ----------
+        dataloader
+            A torch Dataloader from which to compute the mean and std.
+        """
+
+        from .normalizer import make_z_normalizer
+
+        logger.info(
+            f"Uninitialised {self.name} model - "
+            f"computing z-norm factors from train dataloader.",
+        )
+        self.normalizer = make_z_normalizer(dataloader)
+
     def forward(self, x):
         xn = self.normalizer(x)
         x1 = self.unet1(xn)
diff --git a/src/mednet/libs/segmentation/models/m2unet.py b/src/mednet/libs/segmentation/models/m2unet.py
index 60f97967..b3715409 100644
--- a/src/mednet/libs/segmentation/models/m2unet.py
+++ b/src/mednet/libs/segmentation/models/m2unet.py
@@ -201,7 +201,7 @@ class M2UNET(Model):
             Will not be used if the model is pretrained.
         """
         if self.pretrained:
-            from mednet.libs.common.models.normalizer import make_imagenet_normalizer
+            from .normalizer import make_imagenet_normalizer
 
             logger.warning(
                 f"ImageNet pre-trained {self.name} model - NOT "
diff --git a/src/mednet/libs/segmentation/models/normalizer.py b/src/mednet/libs/segmentation/models/normalizer.py
new file mode 100644
index 00000000..df630bd1
--- /dev/null
+++ b/src/mednet/libs/segmentation/models/normalizer.py
@@ -0,0 +1,75 @@
+# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
+#
+# SPDX-License-Identifier: GPL-3.0-or-later
+"""Functions to compute normalisation factors based on dataloaders."""
+
+import logging
+
+import torch
+import torch.nn
+import torch.utils.data
+import torchvision.transforms
+import tqdm
+
+logger = logging.getLogger("mednet")
+
+
+def make_z_normalizer(
+    dataloader: torch.utils.data.DataLoader,
+) -> torchvision.transforms.Normalize:
+    """Compute mean and standard deviation from a dataloader.
+
+    This function will input a dataloader, and compute the mean and standard
+    deviation by image channel. It will work for both monochromatic, and color
+    inputs with 2, 3 or more color planes.
+
+    Parameters
+    ----------
+    dataloader
+        A torch Dataloader from which to compute the mean and std.
+
+    Returns
+    -------
+        An initialized normalizer.
+    """
+
+    # Peek the number of channels of batches in the data loader
+    batch = next(iter(dataloader))
+    channels = batch[0]["image"].shape[1]
+
+    # Initialises accumulators
+    mean = torch.zeros(channels, dtype=batch[0]["image"].dtype)
+    var = torch.zeros(channels, dtype=batch[0]["image"].dtype)
+    num_images = 0
+
+    # Evaluates mean and standard deviation
+    for batch in tqdm.tqdm(dataloader, unit="batch"):
+        data = batch[0]["image"]
+        data = data.view(data.size(0), data.size(1), -1)
+
+        num_images += data.size(0)
+        mean += data.mean(2).sum(0)
+        var += data.var(2).sum(0)
+
+    mean /= num_images
+    var /= num_images
+    std = torch.sqrt(var)
+
+    return torchvision.transforms.Normalize(mean, std)
+
+
+def make_imagenet_normalizer() -> torchvision.transforms.Normalize:
+    """Return the stock ImageNet normalisation weights from torchvision.
+
+    The weights are wrapped in a torch module. This normalizer only works for
+    **RGB (color) images**.
+
+    Returns
+    -------
+        An initialized normalizer.
+    """
+
+    return torchvision.transforms.Normalize(
+        (0.485, 0.456, 0.406),
+        (0.229, 0.224, 0.225),
+    )
diff --git a/src/mednet/libs/segmentation/models/unet.py b/src/mednet/libs/segmentation/models/unet.py
index 7a3ee9bb..98578d1d 100644
--- a/src/mednet/libs/segmentation/models/unet.py
+++ b/src/mednet/libs/segmentation/models/unet.py
@@ -142,7 +142,7 @@ class Unet(Model):
             Will not be used if the model is pretrained.
         """
         if self.pretrained:
-            from mednet.libs.common.models.normalizer import make_imagenet_normalizer
+            from .normalizer import make_imagenet_normalizer
 
             logger.warning(
                 f"ImageNet pre-trained {self.name} model - NOT "
--
GitLab
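
A minimal usage sketch (not part of the patch) of how the relocated segmentation helpers fit together. The attach_normalizer function and the model/train_dataloader names are hypothetical; make_z_normalizer, make_imagenet_normalizer, the pretrained/normalizer attributes and the batch layout (dict(image=..., target=..., mask=...), dict(name=...)) all come from the modules touched above.

# Hypothetical sketch -- not shipped by this patch.
import torch.utils.data

from mednet.libs.segmentation.models.normalizer import (
    make_imagenet_normalizer,
    make_z_normalizer,
)


def attach_normalizer(model, train_dataloader: torch.utils.data.DataLoader) -> None:
    """Attach input-normalisation statistics to a segmentation model.

    ImageNet-pretrained backbones keep the stock torchvision statistics,
    while randomly-initialised models (e.g. lwnet) estimate per-channel
    mean/std from the training batches produced by the DRIVE datamodule.
    """
    if getattr(model, "pretrained", False):
        # Pretrained weights expect the fixed ImageNet RGB statistics.
        model.normalizer = make_imagenet_normalizer()
    else:
        # Compute z-norm factors from the training data (same effect as
        # calling model.set_normalizer(train_dataloader) on lwnet/pasa).
        model.normalizer = make_z_normalizer(train_dataloader)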