From da13b5b0f46f9d1aa42f948996f2b94541d4033b Mon Sep 17 00:00:00 2001
From: dcarron <daniel.carron@idiap.ch>
Date: Wed, 26 Jun 2024 10:14:23 +0200
Subject: [PATCH] [normalizer] Move normalizer out of common package and inside
 libs

---
 .../models/normalizer.py                      |  0
 src/mednet/libs/classification/models/pasa.py | 17 +++++
 src/mednet/libs/common/models/model.py        | 16 +---
 .../config/data/drive/datamodule.py           |  6 +-
 src/mednet/libs/segmentation/models/driu.py   |  2 +-
 .../libs/segmentation/models/driu_bn.py       |  2 +-
 .../libs/segmentation/models/driu_od.py       |  2 +-
 .../libs/segmentation/models/driu_pix.py      |  2 +-
 src/mednet/libs/segmentation/models/hed.py    |  2 +-
 src/mednet/libs/segmentation/models/lwnet.py  | 20 +++++
 src/mednet/libs/segmentation/models/m2unet.py |  2 +-
 .../libs/segmentation/models/normalizer.py    | 75 +++++++++++++++++++
 src/mednet/libs/segmentation/models/unet.py   |  2 +-
 13 files changed, 123 insertions(+), 25 deletions(-)
 rename src/mednet/libs/{common => classification}/models/normalizer.py (100%)
 create mode 100644 src/mednet/libs/segmentation/models/normalizer.py

diff --git a/src/mednet/libs/common/models/normalizer.py b/src/mednet/libs/classification/models/normalizer.py
similarity index 100%
rename from src/mednet/libs/common/models/normalizer.py
rename to src/mednet/libs/classification/models/normalizer.py
diff --git a/src/mednet/libs/classification/models/pasa.py b/src/mednet/libs/classification/models/pasa.py
index 0b5c24da..313e1b54 100644
--- a/src/mednet/libs/classification/models/pasa.py
+++ b/src/mednet/libs/classification/models/pasa.py
@@ -192,6 +192,23 @@ class Pasa(Model):
 
         # x = F.log_softmax(x, dim=1) # 0 is batch size
 
+    def set_normalizer(self, dataloader: torch.utils.data.DataLoader) -> None:
+        """Initialize the input normalizer for the current model.
+
+        Parameters
+        ----------
+        dataloader
+            A torch Dataloader from which to compute the mean and std.
+        """
+
+        from .normalizer import make_z_normalizer
+
+        logger.info(
+            f"Uninitialised {self.name} model - "
+            f"computing z-norm factors from train dataloader.",
+        )
+        self.normalizer = make_z_normalizer(dataloader)
+
     def training_step(self, batch, _):
         images = batch[0]
         labels = batch[1]["target"]
diff --git a/src/mednet/libs/common/models/model.py b/src/mednet/libs/common/models/model.py
index 4dc98b1d..cc9d5d84 100644
--- a/src/mednet/libs/common/models/model.py
+++ b/src/mednet/libs/common/models/model.py
@@ -130,21 +130,7 @@ class Model(pl.LightningModule):
         self.normalizer = checkpoint["normalizer"]
 
     def set_normalizer(self, dataloader: torch.utils.data.DataLoader) -> None:
-        """Initialize the input normalizer for the current model.
-
-        Parameters
-        ----------
-        dataloader
-            A torch Dataloader from which to compute the mean and std.
-        """
-
-        from .normalizer import make_z_normalizer
-
-        logger.info(
-            f"Uninitialised {self.name} model - "
-            f"computing z-norm factors from train dataloader.",
-        )
-        self.normalizer = make_z_normalizer(dataloader)
+        raise NotImplementedError
 
     def training_step(self, batch, _):
         raise NotImplementedError
diff --git a/src/mednet/libs/segmentation/config/data/drive/datamodule.py b/src/mednet/libs/segmentation/config/data/drive/datamodule.py
index 06014c81..17629343 100644
--- a/src/mednet/libs/segmentation/config/data/drive/datamodule.py
+++ b/src/mednet/libs/segmentation/config/data/drive/datamodule.py
@@ -57,11 +57,11 @@ class SegmentationRawDataLoader(_SegmentationRawDataLoader):
             PIL.Image.open(self.datadir / sample[2]).convert(mode="1", dither=None)
         )
 
-        tensor = tv_tensors.Image(crop_image_to_mask(image, mask))
-        target = tv_tensors.Image(crop_image_to_mask(target, mask))
+        image = tv_tensors.Image(crop_image_to_mask(image, mask))
+        target = tv_tensors.Mask(crop_image_to_mask(target, mask))
         mask = tv_tensors.Mask(crop_image_to_mask(mask, mask))
 
-        return tensor, dict(target=target, mask=mask, name=sample[0])  # type: ignore[arg-type]
+        return dict(image=image, target=target, mask=mask), dict(name=sample[0])  # type: ignore[arg-type]
 
 
 class DataModule(CachingDataModule):
diff --git a/src/mednet/libs/segmentation/models/driu.py b/src/mednet/libs/segmentation/models/driu.py
index eac0afc0..f3a2e0d2 100644
--- a/src/mednet/libs/segmentation/models/driu.py
+++ b/src/mednet/libs/segmentation/models/driu.py
@@ -149,7 +149,7 @@ class DRIU(Model):
             Will not be used if the model is pretrained.
         """
         if self.pretrained:
-            from mednet.libs.common.models.normalizer import make_imagenet_normalizer
+            from .normalizer import make_imagenet_normalizer
 
             logger.warning(
                 f"ImageNet pre-trained {self.name} model - NOT "
diff --git a/src/mednet/libs/segmentation/models/driu_bn.py b/src/mednet/libs/segmentation/models/driu_bn.py
index 7804ed06..3bb93ba9 100644
--- a/src/mednet/libs/segmentation/models/driu_bn.py
+++ b/src/mednet/libs/segmentation/models/driu_bn.py
@@ -152,7 +152,7 @@ class DRIUBN(Model):
             Will not be used if the model is pretrained.
         """
         if self.pretrained:
-            from mednet.libs.common.models.normalizer import make_imagenet_normalizer
+            from .normalizer import make_imagenet_normalizer
 
             logger.warning(
                 f"ImageNet pre-trained {self.name} model - NOT "
diff --git a/src/mednet/libs/segmentation/models/driu_od.py b/src/mednet/libs/segmentation/models/driu_od.py
index be7416c7..98e59623 100644
--- a/src/mednet/libs/segmentation/models/driu_od.py
+++ b/src/mednet/libs/segmentation/models/driu_od.py
@@ -134,7 +134,7 @@ class DRIUOD(Model):
             Will not be used if the model is pretrained.
         """
         if self.pretrained:
-            from mednet.libs.common.models.normalizer import make_imagenet_normalizer
+            from .normalizer import make_imagenet_normalizer
 
             logger.warning(
                 f"ImageNet pre-trained {self.name} model - NOT "
diff --git a/src/mednet/libs/segmentation/models/driu_pix.py b/src/mednet/libs/segmentation/models/driu_pix.py
index 23218a57..6846da5a 100644
--- a/src/mednet/libs/segmentation/models/driu_pix.py
+++ b/src/mednet/libs/segmentation/models/driu_pix.py
@@ -138,7 +138,7 @@ class DRIUPix(Model):
             Will not be used if the model is pretrained.
         """
         if self.pretrained:
-            from mednet.libs.common.models.normalizer import make_imagenet_normalizer
+            from .normalizer import make_imagenet_normalizer
 
             logger.warning(
                 f"ImageNet pre-trained {self.name} model - NOT "
diff --git a/src/mednet/libs/segmentation/models/hed.py b/src/mednet/libs/segmentation/models/hed.py
index 8771f558..7e0b7705 100644
--- a/src/mednet/libs/segmentation/models/hed.py
+++ b/src/mednet/libs/segmentation/models/hed.py
@@ -153,7 +153,7 @@ class HED(Model):
             Will not be used if the model is pretrained.
         """
         if self.pretrained:
-            from mednet.libs.common.models.normalizer import make_imagenet_normalizer
+            from .normalizer import make_imagenet_normalizer
 
             logger.warning(
                 f"ImageNet pre-trained {self.name} model - NOT "
diff --git a/src/mednet/libs/segmentation/models/lwnet.py b/src/mednet/libs/segmentation/models/lwnet.py
index 90507e4d..28bdf498 100644
--- a/src/mednet/libs/segmentation/models/lwnet.py
+++ b/src/mednet/libs/segmentation/models/lwnet.py
@@ -15,6 +15,7 @@ guide segmentation.
 Reference: [GALDRAN-2020]_
 """
 
+import logging
 import typing
 
 import torch
@@ -23,6 +24,8 @@ from mednet.libs.common.data.typing import TransformSequence
 from mednet.libs.common.models.model import Model
 from mednet.libs.segmentation.models.losses import MultiWeightedBCELogitsLoss
 
+logger = logging.getLogger("mednet")
+
 
 def _conv1x1(in_planes, out_planes, stride=1):
     return torch.nn.Conv2d(
@@ -338,6 +341,23 @@ class LittleWNet(Model):
             shortcut=True,
         )
 
+    def set_normalizer(self, dataloader: torch.utils.data.DataLoader) -> None:
+        """Initialize the input normalizer for the current model.
+
+        Parameters
+        ----------
+        dataloader
+            A torch Dataloader from which to compute the mean and std.
+        """
+
+        from .normalizer import make_z_normalizer
+
+        logger.info(
+            f"Uninitialised {self.name} model - "
+            f"computing z-norm factors from train dataloader.",
+        )
+        self.normalizer = make_z_normalizer(dataloader)
+
     def forward(self, x):
         xn = self.normalizer(x)
         x1 = self.unet1(xn)
diff --git a/src/mednet/libs/segmentation/models/m2unet.py b/src/mednet/libs/segmentation/models/m2unet.py
index 60f97967..b3715409 100644
--- a/src/mednet/libs/segmentation/models/m2unet.py
+++ b/src/mednet/libs/segmentation/models/m2unet.py
@@ -201,7 +201,7 @@ class M2UNET(Model):
             Will not be used if the model is pretrained.
         """
         if self.pretrained:
-            from mednet.libs.common.models.normalizer import make_imagenet_normalizer
+            from .normalizer import make_imagenet_normalizer
 
             logger.warning(
                 f"ImageNet pre-trained {self.name} model - NOT "
diff --git a/src/mednet/libs/segmentation/models/normalizer.py b/src/mednet/libs/segmentation/models/normalizer.py
new file mode 100644
index 00000000..df630bd1
--- /dev/null
+++ b/src/mednet/libs/segmentation/models/normalizer.py
@@ -0,0 +1,75 @@
+# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
+#
+# SPDX-License-Identifier: GPL-3.0-or-later
+"""Functions to compute normalisation factors based on dataloaders."""
+
+import logging
+
+import torch
+import torch.nn
+import torch.utils.data
+import torchvision.transforms
+import tqdm
+
+logger = logging.getLogger("mednet")
+
+
+def make_z_normalizer(
+    dataloader: torch.utils.data.DataLoader,
+) -> torchvision.transforms.Normalize:
+    """Compute mean and standard deviation from a dataloader.
+
+    This function will input a dataloader, and compute the mean and standard
+    deviation by image channel.  It will work for both monochromatic, and color
+    inputs with 2, 3 or more color planes.
+
+    Parameters
+    ----------
+    dataloader
+        A torch Dataloader from which to compute the mean and std.
+
+    Returns
+    -------
+        An initialized normalizer.
+    """
+
+    # Peek the number of channels of batches in the data loader
+    batch = next(iter(dataloader))
+    channels = batch[0]["image"].shape[1]
+
+    # Initialises accumulators
+    mean = torch.zeros(channels, dtype=batch[0]["image"].dtype)
+    var = torch.zeros(channels, dtype=batch[0]["image"].dtype)
+    num_images = 0
+
+    # Evaluates mean and standard deviation
+    for batch in tqdm.tqdm(dataloader, unit="batch"):
+        data = batch[0]
+        data = data.view(data.size(0), data.size(1), -1)
+
+        num_images += data.size(0)
+        mean += data.mean(2).sum(0)
+        var += data.var(2).sum(0)
+
+    mean /= num_images
+    var /= num_images
+    std = torch.sqrt(var)
+
+    return torchvision.transforms.Normalize(mean, std)
+
+
+def make_imagenet_normalizer() -> torchvision.transforms.Normalize:
+    """Return the stock ImageNet normalisation weights from torchvision.
+
+    The weights are wrapped in a torch module.  This normalizer only works for
+    **RGB (color) images**.
+
+    Returns
+    -------
+        An initialized normalizer.
+    """
+
+    return torchvision.transforms.Normalize(
+        (0.485, 0.456, 0.406),
+        (0.229, 0.224, 0.225),
+    )
diff --git a/src/mednet/libs/segmentation/models/unet.py b/src/mednet/libs/segmentation/models/unet.py
index 7a3ee9bb..98578d1d 100644
--- a/src/mednet/libs/segmentation/models/unet.py
+++ b/src/mednet/libs/segmentation/models/unet.py
@@ -142,7 +142,7 @@ class Unet(Model):
             Will not be used if the model is pretrained.
         """
         if self.pretrained:
-            from mednet.libs.common.models.normalizer import make_imagenet_normalizer
+            from .normalizer import make_imagenet_normalizer
 
             logger.warning(
                 f"ImageNet pre-trained {self.name} model - NOT "
-- 
GitLab