diff --git a/src/mednet/libs/common/models/normalizer.py b/src/mednet/libs/classification/models/normalizer.py similarity index 100% rename from src/mednet/libs/common/models/normalizer.py rename to src/mednet/libs/classification/models/normalizer.py diff --git a/src/mednet/libs/classification/models/pasa.py b/src/mednet/libs/classification/models/pasa.py index 0b5c24da02b73f1328651612e216cb731b90bbf7..313e1b54eaddc2d455c20d45f5ff6fb0d183ca1a 100644 --- a/src/mednet/libs/classification/models/pasa.py +++ b/src/mednet/libs/classification/models/pasa.py @@ -192,6 +192,23 @@ class Pasa(Model): # x = F.log_softmax(x, dim=1) # 0 is batch size + def set_normalizer(self, dataloader: torch.utils.data.DataLoader) -> None: + """Initialize the input normalizer for the current model. + + Parameters + ---------- + dataloader + A torch Dataloader from which to compute the mean and std. + """ + + from .normalizer import make_z_normalizer + + logger.info( + f"Uninitialised {self.name} model - " + f"computing z-norm factors from train dataloader.", + ) + self.normalizer = make_z_normalizer(dataloader) + def training_step(self, batch, _): images = batch[0] labels = batch[1]["target"] diff --git a/src/mednet/libs/common/models/model.py b/src/mednet/libs/common/models/model.py index 4dc98b1da5e049bf541f751256428c912e8b8a32..cc9d5d84dee0a571d7809e4dfa1832f429c68de3 100644 --- a/src/mednet/libs/common/models/model.py +++ b/src/mednet/libs/common/models/model.py @@ -130,21 +130,7 @@ class Model(pl.LightningModule): self.normalizer = checkpoint["normalizer"] def set_normalizer(self, dataloader: torch.utils.data.DataLoader) -> None: - """Initialize the input normalizer for the current model. - - Parameters - ---------- - dataloader - A torch Dataloader from which to compute the mean and std. - """ - - from .normalizer import make_z_normalizer - - logger.info( - f"Uninitialised {self.name} model - " - f"computing z-norm factors from train dataloader.", - ) - self.normalizer = make_z_normalizer(dataloader) + raise NotImplementedError def training_step(self, batch, _): raise NotImplementedError diff --git a/src/mednet/libs/segmentation/config/data/drive/datamodule.py b/src/mednet/libs/segmentation/config/data/drive/datamodule.py index 06014c81efc338b24b3a30c06e614d5b36281f1c..176293431877ed2dcf6e124c90faf4de0f7944a6 100644 --- a/src/mednet/libs/segmentation/config/data/drive/datamodule.py +++ b/src/mednet/libs/segmentation/config/data/drive/datamodule.py @@ -57,11 +57,11 @@ class SegmentationRawDataLoader(_SegmentationRawDataLoader): PIL.Image.open(self.datadir / sample[2]).convert(mode="1", dither=None) ) - tensor = tv_tensors.Image(crop_image_to_mask(image, mask)) - target = tv_tensors.Image(crop_image_to_mask(target, mask)) + image = tv_tensors.Image(crop_image_to_mask(image, mask)) + target = tv_tensors.Mask(crop_image_to_mask(target, mask)) mask = tv_tensors.Mask(crop_image_to_mask(mask, mask)) - return tensor, dict(target=target, mask=mask, name=sample[0]) # type: ignore[arg-type] + return dict(image=image, target=target, mask=mask), dict(name=sample[0]) # type: ignore[arg-type] class DataModule(CachingDataModule): diff --git a/src/mednet/libs/segmentation/models/driu.py b/src/mednet/libs/segmentation/models/driu.py index eac0afc03965519ac1af65127ac5565f31151e2a..f3a2e0d2d38866f670f051e6b1ea685ab2c3fd1f 100644 --- a/src/mednet/libs/segmentation/models/driu.py +++ b/src/mednet/libs/segmentation/models/driu.py @@ -149,7 +149,7 @@ class DRIU(Model): Will not be used if the model is pretrained. """ if self.pretrained: - from mednet.libs.common.models.normalizer import make_imagenet_normalizer + from .normalizer import make_imagenet_normalizer logger.warning( f"ImageNet pre-trained {self.name} model - NOT " diff --git a/src/mednet/libs/segmentation/models/driu_bn.py b/src/mednet/libs/segmentation/models/driu_bn.py index 7804ed06747d22cd947f4a1ba46d657d9e36c64d..3bb93ba9ee9221938a65af1c2d6fd62a1104d55d 100644 --- a/src/mednet/libs/segmentation/models/driu_bn.py +++ b/src/mednet/libs/segmentation/models/driu_bn.py @@ -152,7 +152,7 @@ class DRIUBN(Model): Will not be used if the model is pretrained. """ if self.pretrained: - from mednet.libs.common.models.normalizer import make_imagenet_normalizer + from .normalizer import make_imagenet_normalizer logger.warning( f"ImageNet pre-trained {self.name} model - NOT " diff --git a/src/mednet/libs/segmentation/models/driu_od.py b/src/mednet/libs/segmentation/models/driu_od.py index be7416c7f34e89a6c826f4b87f0dbafe5c4e1a99..98e596236bddf19fcc1789f136f1e162ed76ecf3 100644 --- a/src/mednet/libs/segmentation/models/driu_od.py +++ b/src/mednet/libs/segmentation/models/driu_od.py @@ -134,7 +134,7 @@ class DRIUOD(Model): Will not be used if the model is pretrained. """ if self.pretrained: - from mednet.libs.common.models.normalizer import make_imagenet_normalizer + from .normalizer import make_imagenet_normalizer logger.warning( f"ImageNet pre-trained {self.name} model - NOT " diff --git a/src/mednet/libs/segmentation/models/driu_pix.py b/src/mednet/libs/segmentation/models/driu_pix.py index 23218a577a34367b090a51dbd7340366f67dbe5c..6846da5a8f2700ecc4f2370cd485875016c1749b 100644 --- a/src/mednet/libs/segmentation/models/driu_pix.py +++ b/src/mednet/libs/segmentation/models/driu_pix.py @@ -138,7 +138,7 @@ class DRIUPix(Model): Will not be used if the model is pretrained. """ if self.pretrained: - from mednet.libs.common.models.normalizer import make_imagenet_normalizer + from .normalizer import make_imagenet_normalizer logger.warning( f"ImageNet pre-trained {self.name} model - NOT " diff --git a/src/mednet/libs/segmentation/models/hed.py b/src/mednet/libs/segmentation/models/hed.py index 8771f55813aab08ccbe853d151f6f28c26ba816b..7e0b770513d0f8d5d94515827160f029cbf3346c 100644 --- a/src/mednet/libs/segmentation/models/hed.py +++ b/src/mednet/libs/segmentation/models/hed.py @@ -153,7 +153,7 @@ class HED(Model): Will not be used if the model is pretrained. """ if self.pretrained: - from mednet.libs.common.models.normalizer import make_imagenet_normalizer + from .normalizer import make_imagenet_normalizer logger.warning( f"ImageNet pre-trained {self.name} model - NOT " diff --git a/src/mednet/libs/segmentation/models/lwnet.py b/src/mednet/libs/segmentation/models/lwnet.py index 90507e4dcd513e9f0a8d12a368f08a010c9ff77c..28bdf498fd74705546bf39a57835b378e4144f4e 100644 --- a/src/mednet/libs/segmentation/models/lwnet.py +++ b/src/mednet/libs/segmentation/models/lwnet.py @@ -15,6 +15,7 @@ guide segmentation. Reference: [GALDRAN-2020]_ """ +import logging import typing import torch @@ -23,6 +24,8 @@ from mednet.libs.common.data.typing import TransformSequence from mednet.libs.common.models.model import Model from mednet.libs.segmentation.models.losses import MultiWeightedBCELogitsLoss +logger = logging.getLogger("mednet") + def _conv1x1(in_planes, out_planes, stride=1): return torch.nn.Conv2d( @@ -338,6 +341,23 @@ class LittleWNet(Model): shortcut=True, ) + def set_normalizer(self, dataloader: torch.utils.data.DataLoader) -> None: + """Initialize the input normalizer for the current model. + + Parameters + ---------- + dataloader + A torch Dataloader from which to compute the mean and std. + """ + + from .normalizer import make_z_normalizer + + logger.info( + f"Uninitialised {self.name} model - " + f"computing z-norm factors from train dataloader.", + ) + self.normalizer = make_z_normalizer(dataloader) + def forward(self, x): xn = self.normalizer(x) x1 = self.unet1(xn) diff --git a/src/mednet/libs/segmentation/models/m2unet.py b/src/mednet/libs/segmentation/models/m2unet.py index 60f97967ef96b25c87afee5570cc81a8d7cab90c..b371540945d8e8b3c481c84abf7ea6c3309995fa 100644 --- a/src/mednet/libs/segmentation/models/m2unet.py +++ b/src/mednet/libs/segmentation/models/m2unet.py @@ -201,7 +201,7 @@ class M2UNET(Model): Will not be used if the model is pretrained. """ if self.pretrained: - from mednet.libs.common.models.normalizer import make_imagenet_normalizer + from .normalizer import make_imagenet_normalizer logger.warning( f"ImageNet pre-trained {self.name} model - NOT " diff --git a/src/mednet/libs/segmentation/models/normalizer.py b/src/mednet/libs/segmentation/models/normalizer.py new file mode 100644 index 0000000000000000000000000000000000000000..df630bd1d61c9b7e10a951a80e8f560248f9a477 --- /dev/null +++ b/src/mednet/libs/segmentation/models/normalizer.py @@ -0,0 +1,75 @@ +# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> +# +# SPDX-License-Identifier: GPL-3.0-or-later +"""Functions to compute normalisation factors based on dataloaders.""" + +import logging + +import torch +import torch.nn +import torch.utils.data +import torchvision.transforms +import tqdm + +logger = logging.getLogger("mednet") + + +def make_z_normalizer( + dataloader: torch.utils.data.DataLoader, +) -> torchvision.transforms.Normalize: + """Compute mean and standard deviation from a dataloader. + + This function will input a dataloader, and compute the mean and standard + deviation by image channel. It will work for both monochromatic, and color + inputs with 2, 3 or more color planes. + + Parameters + ---------- + dataloader + A torch Dataloader from which to compute the mean and std. + + Returns + ------- + An initialized normalizer. + """ + + # Peek the number of channels of batches in the data loader + batch = next(iter(dataloader)) + channels = batch[0]["image"].shape[1] + + # Initialises accumulators + mean = torch.zeros(channels, dtype=batch[0]["image"].dtype) + var = torch.zeros(channels, dtype=batch[0]["image"].dtype) + num_images = 0 + + # Evaluates mean and standard deviation + for batch in tqdm.tqdm(dataloader, unit="batch"): + data = batch[0] + data = data.view(data.size(0), data.size(1), -1) + + num_images += data.size(0) + mean += data.mean(2).sum(0) + var += data.var(2).sum(0) + + mean /= num_images + var /= num_images + std = torch.sqrt(var) + + return torchvision.transforms.Normalize(mean, std) + + +def make_imagenet_normalizer() -> torchvision.transforms.Normalize: + """Return the stock ImageNet normalisation weights from torchvision. + + The weights are wrapped in a torch module. This normalizer only works for + **RGB (color) images**. + + Returns + ------- + An initialized normalizer. + """ + + return torchvision.transforms.Normalize( + (0.485, 0.456, 0.406), + (0.229, 0.224, 0.225), + ) diff --git a/src/mednet/libs/segmentation/models/unet.py b/src/mednet/libs/segmentation/models/unet.py index 7a3ee9bbccde07ef897e1cf03ad3e4efeb44cf17..98578d1d437bbe53fbab48f31764a1580d5ad5eb 100644 --- a/src/mednet/libs/segmentation/models/unet.py +++ b/src/mednet/libs/segmentation/models/unet.py @@ -142,7 +142,7 @@ class Unet(Model): Will not be used if the model is pretrained. """ if self.pretrained: - from mednet.libs.common.models.normalizer import make_imagenet_normalizer + from .normalizer import make_imagenet_normalizer logger.warning( f"ImageNet pre-trained {self.name} model - NOT "