From d4d28b03566d2e52a6d0bccee3c1e14262079e3d Mon Sep 17 00:00:00 2001
From: dcarron <daniel.carron@idiap.ch>
Date: Tue, 11 Jun 2024 12:23:10 +0200
Subject: [PATCH] [segmentation.models] Add m2unet model

---
 pyproject.toml                                |   1 +
 .../libs/segmentation/config/models/m2unet.py |  48 ++++
 src/mednet/libs/segmentation/models/m2unet.py | 245 ++++++++++++++++++
 3 files changed, 294 insertions(+)
 create mode 100644 src/mednet/libs/segmentation/config/models/m2unet.py
 create mode 100644 src/mednet/libs/segmentation/models/m2unet.py

diff --git a/pyproject.toml b/pyproject.toml
index e6cf3efb..8049272e 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -425,6 +425,7 @@ visceral = "mednet.config.data.visceral.default"
 driu = "mednet.libs.segmentation.config.models.driu"
 hed = "mednet.libs.segmentation.config.models.hed"
 lwnet = "mednet.libs.segmentation.config.models.lwnet"
+m2unet = "mednet.libs.segmentation.config.models.m2unet"
 unet = "mednet.libs.segmentation.config.models.unet"
 
 # chase-db1 - retinography
diff --git a/src/mednet/libs/segmentation/config/models/m2unet.py b/src/mednet/libs/segmentation/config/models/m2unet.py
new file mode 100644
index 00000000..bb360709
--- /dev/null
+++ b/src/mednet/libs/segmentation/config/models/m2unet.py
@@ -0,0 +1,48 @@
+# SPDX-FileCopyrightText: Copyright © 2024 Idiap Research Institute <contact@idiap.ch>
+#
+# SPDX-License-Identifier: GPL-3.0-or-later
+
+"""MobileNetV2 U-Net model for image segmentation.
+
+The MobileNetV2 architecture is based on an inverted residual structure in
+which the input and output of each residual block are thin bottleneck layers,
+as opposed to traditional residual models, which use expanded representations
+at the input.  MobileNetV2 uses lightweight depthwise convolutions to filter
+features in the intermediate expansion layer.  This module implements a
+MobileNetV2 U-Net model, henceforth named M2U-Net, combining the strengths of
+U-Net for medical image segmentation with the speed of MobileNetV2 networks.
+
+References: [SANDLER-2018]_, [RONNEBERGER-2015]_
+"""
+
+from mednet.libs.segmentation.engine.adabound import AdaBound
+from mednet.libs.segmentation.models.losses import SoftJaccardBCELogitsLoss
+from mednet.libs.segmentation.models.m2unet import M2UNET
+
+lr = 0.001
+alpha = 0.7
+betas = (0.9, 0.999)
+eps = 1e-08
+weight_decay = 0
+final_lr = 0.1
+gamma = 1e-3
+amsbound = False
+crop_size = 544
+
+model = M2UNET(
+    loss_type=SoftJaccardBCELogitsLoss,
+    loss_arguments=dict(alpha=alpha),
+    optimizer_type=AdaBound,
+    optimizer_arguments=dict(
+        lr=lr,
+        betas=betas,
+        final_lr=final_lr,
+        gamma=gamma,
+        eps=eps,
+        weight_decay=weight_decay,
+        amsbound=amsbound,
+    ),
+    augmentation_transforms=[],
+    crop_size=crop_size,
+)
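+
+# The same configuration pattern works with other optimizers; a minimal,
+# illustrative variant using plain Adam instead of AdaBound (the values below
+# are assumptions, not tuned settings) could look like:
+#
+#   import torch.optim
+#
+#   model = M2UNET(
+#       loss_type=SoftJaccardBCELogitsLoss,
+#       loss_arguments=dict(alpha=alpha),
+#       optimizer_type=torch.optim.Adam,
+#       optimizer_arguments=dict(lr=lr, betas=betas, eps=eps),
+#       augmentation_transforms=[],
+#       crop_size=crop_size,
+#   )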
diff --git a/src/mednet/libs/segmentation/models/m2unet.py b/src/mednet/libs/segmentation/models/m2unet.py
new file mode 100644
index 00000000..76042ba2
--- /dev/null
+++ b/src/mednet/libs/segmentation/models/m2unet.py
@@ -0,0 +1,245 @@
+# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
+#
+# SPDX-License-Identifier: GPL-3.0-or-later
+
+import logging
+import typing
+
+import torch
+import torch.nn
+from mednet.libs.common.data.typing import TransformSequence
+from mednet.libs.common.models.model import Model
+from mednet.libs.common.models.transforms import ResizeMaxSide, SquareCenterPad
+from mednet.libs.segmentation.models.losses import SoftJaccardBCELogitsLoss
+from torchvision.models.mobilenetv2 import InvertedResidual
+
+from .backbones.mobilenetv2 import mobilenet_v2_for_segmentation
+from .separate import separate
+
+logger = logging.getLogger("mednet")
+
+
+class DecoderBlock(torch.nn.Module):
+    """Decoder block: upsample and concatenate with features maps from the
+    encoder part.
+
+    Parameters
+    ----------
+    up_in_c
+        Number of input channels.
+    x_in_c
+        Number of cat channels.
+    upsamplemode
+        Mode to use for upsampling.
+    expand_ratio
+        The expand ratio.
+    """
+
+    def __init__(self, up_in_c, x_in_c, upsamplemode="bilinear", expand_ratio=0.15):
+        super().__init__()
+        self.upsample = torch.nn.Upsample(
+            scale_factor=2, mode=upsamplemode, align_corners=False
+        )  # H, W -> 2H, 2W
+        self.ir1 = InvertedResidual(
+            up_in_c + x_in_c,
+            (x_in_c + up_in_c) // 2,
+            stride=1,
+            expand_ratio=expand_ratio,
+        )
+
+    def forward(self, up_in, x_in):
+        up_out = self.upsample(up_in)
+        cat_x = torch.cat([up_out, x_in], dim=1)
+        return self.ir1(cat_x)
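+
+# Illustrative shape walk-through for a DecoderBlock (the tensor sizes below
+# are assumptions picked for demonstration only, not values used by the
+# model):
+#
+#   block = DecoderBlock(up_in_c=96, x_in_c=32)  # as used by decode4 below
+#   up_in = torch.rand(2, 96, 17, 17)    # coarse decoder features
+#   x_in = torch.rand(2, 32, 34, 34)     # encoder skip connection, 2x resolution
+#   out = block(up_in, x_in)             # -> (2, 64, 34, 34), i.e.
+#                                        #    (up_in_c + x_in_c) // 2 channels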
+
+
+class LastDecoderBlock(torch.nn.Module):
+    """Last decoder block.
+
+    Parameters
+    ----------
+    x_in_c
+        Number of cat channels.
+    upsamplemode
+        Mode to use for upsampling.
+    expand_ratio
+        The expand ratio.
+    """
+
+    def __init__(self, x_in_c, upsamplemode="bilinear", expand_ratio=0.15):
+        super().__init__()
+        self.upsample = torch.nn.Upsample(
+            scale_factor=2, mode=upsamplemode, align_corners=False
+        )  # H, W -> 2H, 2W
+        self.ir1 = InvertedResidual(x_in_c, 1, stride=1, expand_ratio=expand_ratio)
+
+    def forward(self, up_in, x_in):
+        up_out = self.upsample(up_in)
+        cat_x = torch.cat([up_out, x_in], dim=1)
+        return self.ir1(cat_x)
+
+
+class M2UNetHead(torch.nn.Module):
+    """M2U-Net head module.
+
+    Parameters
+    ----------
+    in_channels_list
+        Number of channels of each feature map returned by the backbone.
+    upsamplemode
+        Mode to use for upsampling.
+    expand_ratio
+        The expand ratio.
+    """
+
+    def __init__(
+        self, in_channels_list=None, upsamplemode="bilinear", expand_ratio=0.15
+    ):
+        super().__init__()
+
+        # Decoder
+        self.decode4 = DecoderBlock(96, 32, upsamplemode, expand_ratio)
+        self.decode3 = DecoderBlock(64, 24, upsamplemode, expand_ratio)
+        self.decode2 = DecoderBlock(44, 16, upsamplemode, expand_ratio)
+        self.decode1 = LastDecoderBlock(33, upsamplemode, expand_ratio)
+
+        # initialize weights
+        self._initialize_weights()
+
+    def _initialize_weights(self):
+        for m in self.modules():
+            if isinstance(m, torch.nn.Conv2d):
+                torch.nn.init.kaiming_uniform_(m.weight, a=1)
+                if m.bias is not None:
+                    torch.nn.init.constant_(m.bias, 0)
+            elif isinstance(m, torch.nn.BatchNorm2d):
+                m.weight.data.fill_(1)
+                m.bias.data.zero_()
+
+    def forward(self, x):
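+        # ``x`` is the feature list produced by the MobileNetV2 backbone;
+        # judging from the channel counts used by the decoder blocks, x[1] is
+        # the 3-channel (possibly normalized) input image and x[2]..x[5] are
+        # feature maps with 16, 24, 32 and 96 channels, as selected by
+        # ``return_features=[1, 3, 6, 13]`` in M2UNET below.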
+        decode4 = self.decode4(x[5], x[4])  # 96, 32
+        decode3 = self.decode3(decode4, x[3])  # 64, 24
+        decode2 = self.decode2(decode3, x[2])  # 44, 16
+        return self.decode1(decode2, x[1])  # 30, 3
+
+
+class M2UNET(Model):
+    """Implementation of the M2UNET model.
+
+    Parameters
+    ----------
+    loss_type
+        The loss to be used for training and evaluation.
+
+        .. warning::
+
+           The loss should be set to always return batch averages (as opposed
+           to the batch sum), as our logging system expects it.
+    loss_arguments
+        Arguments to the loss.
+    optimizer_type
+        The type of optimizer to use for training.
+    optimizer_arguments
+        Arguments to the optimizer after ``params``.
+    augmentation_transforms
+        An optional sequence of torch modules containing transforms to be
+        applied on the input **before** it is fed into the network.
+    num_classes
+        Number of outputs (classes) for this model.
+    pretrained
+        If True, will use MobileNetV2 weights pre-trained on ImageNet.
+    crop_size
+        The size of the image after center cropping.
+    """
+
+    def __init__(
+        self,
+        loss_type: torch.nn.Module = SoftJaccardBCELogitsLoss,
+        loss_arguments: dict[str, typing.Any] = {},
+        optimizer_type: type[torch.optim.Optimizer] = torch.optim.Adam,
+        optimizer_arguments: dict[str, typing.Any] = {},
+        augmentation_transforms: TransformSequence = [],
+        num_classes: int = 1,
+        pretrained: bool = False,
+        crop_size: int = 544,
+    ):
+        super().__init__(
+            loss_type,
+            loss_arguments,
+            optimizer_type,
+            optimizer_arguments,
+            augmentation_transforms,
+            num_classes,
+        )
+
+        self.name = "m2unet"
+
+        resize_transform = ResizeMaxSide(crop_size)
+
+        self.model_transforms = [
+            resize_transform,
+            SquareCenterPad(),
+        ]
+
+        self.pretrained = pretrained
+
+        self.backbone = mobilenet_v2_for_segmentation(
+            pretrained=self.pretrained,
+            return_features=[1, 3, 6, 13],
+        )
+
+        self.head = M2UNetHead(in_channels_list=[16, 24, 32, 96])
+
+    def forward(self, x):
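+        # Produces a single-channel logit map at the input resolution: for the
+        # default crop size, a (N, 3, 544, 544) batch yields a (N, 1, 544, 544)
+        # output, since the decoder upsamples the stride-16 backbone features
+        # back to full size.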
+        if self.normalizer is not None:
+            x = self.normalizer(x)
+        x = self.backbone(x)
+        return self.head(x)
+
+    def set_normalizer(self, dataloader: torch.utils.data.DataLoader) -> None:
+        """Initialize the normalizer for the current model.
+
+        If ``pretrained = True``, the normalizer is set to torchvision's
+        preset ImageNet factors instead of being computed from the train
+        dataloader.
+
+        Parameters
+        ----------
+        dataloader
+            A torch Dataloader from which to compute the mean and std.
+            Will not be used if the model is pretrained.
+        """
+        if self.pretrained:
+            from mednet.libs.common.models.normalizer import make_imagenet_normalizer
+
+            logger.warning(
+                f"ImageNet pre-trained {self.name} model - NOT "
+                f"computing z-norm factors from train dataloader. "
+                f"Using preset factors from torchvision.",
+            )
+            self.normalizer = make_imagenet_normalizer()
+        else:
+            self.normalizer = None
+
+    def training_step(self, batch, batch_idx):
+        images = batch[0]
+        ground_truths = batch[1]["target"]
+        masks = batch[1]["mask"]
+
+        outputs = self(self._augmentation_transforms(images))
+        return self._train_loss(outputs, ground_truths, masks)
+
+    def validation_step(self, batch, batch_idx):
+        images = batch[0]
+        ground_truths = batch[1]["target"]
+        masks = batch[1]["mask"]
+
+        outputs = self(images)
+        return self._validation_loss(outputs, ground_truths, masks)
+
+    def predict_step(self, batch, batch_idx, dataloader_idx=0):
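+        # batch[0] holds the input images; the model emits logits, which are
+        # converted to per-pixel probabilities before being passed to
+        # ``separate()`` (presumably to split the batched predictions back
+        # into per-sample entries).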
+        output = self(batch[0])
+        probabilities = torch.sigmoid(output)
+        return separate((probabilities, batch[1]))
+
+    def configure_optimizers(self):
+        return self._optimizer_type(self.parameters(), **self._optimizer_arguments)
-- 
GitLab