From d4d28b03566d2e52a6d0bccee3c1e14262079e3d Mon Sep 17 00:00:00 2001
From: dcarron <daniel.carron@idiap.ch>
Date: Tue, 11 Jun 2024 12:23:10 +0200
Subject: [PATCH] [segmentation.models] Add m2unet model

---
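(Note for reviewers: this free-form section sits between the `---` marker and
the diffstat, so `git am` discards it when applying the patch.)

Once applied, the new configuration can be exercised with a quick smoke test.
This is only a sketch, assuming mednet and its dependencies are installed;
importing the config module is enough to instantiate the model with the
AdaBound optimizer and the SoftJaccardBCELogitsLoss configured below:

    from mednet.libs.segmentation.config.models.m2unet import model

    assert model.name == "m2unet"
    assert type(model).__name__ == "M2UNET"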
 pyproject.toml                                |   1 +
 .../libs/segmentation/config/models/m2unet.py |  47 ++++
 src/mednet/libs/segmentation/models/m2unet.py | 260 ++++++++++++++++++
 3 files changed, 308 insertions(+)
 create mode 100644 src/mednet/libs/segmentation/config/models/m2unet.py
 create mode 100644 src/mednet/libs/segmentation/models/m2unet.py

diff --git a/pyproject.toml b/pyproject.toml
index e6cf3efb..8049272e 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -425,6 +425,7 @@ visceral = "mednet.config.data.visceral.default"
 driu = "mednet.libs.segmentation.config.models.driu"
 hed = "mednet.libs.segmentation.config.models.hed"
 lwnet = "mednet.libs.segmentation.config.models.lwnet"
+m2unet = "mednet.libs.segmentation.config.models.m2unet"
 unet = "mednet.libs.segmentation.config.models.unet"
 
 # chase-db1 - retinography
diff --git a/src/mednet/libs/segmentation/config/models/m2unet.py b/src/mednet/libs/segmentation/config/models/m2unet.py
new file mode 100644
index 00000000..bb360709
--- /dev/null
+++ b/src/mednet/libs/segmentation/config/models/m2unet.py
@@ -0,0 +1,47 @@
+# SPDX-FileCopyrightText: Copyright © 2024 Idiap Research Institute <contact@idiap.ch>
+#
+# SPDX-License-Identifier: GPL-3.0-or-later
+
+"""MobileNetV2 U-Net model for image segmentation.
+
+The MobileNetV2 architecture is based on an inverted residual structure, in
+which the input and output of each residual block are thin bottleneck layers,
+as opposed to traditional residual models, which use expanded representations
+at the input. MobileNetV2 uses lightweight depthwise convolutions to filter
+features in the intermediate expansion layer. This module implements a
+MobileNetV2 U-Net, henceforth named M2U-Net, combining the strengths of U-Net
+for medical image segmentation with the speed of MobileNetV2 networks.
+
+References: [SANDLER-2018]_, [RONNEBERGER-2015]_
+"""
+
+from mednet.libs.segmentation.engine.adabound import AdaBound
+from mednet.libs.segmentation.models.losses import SoftJaccardBCELogitsLoss
+from mednet.libs.segmentation.models.m2unet import M2UNET
+
+lr = 0.001
+alpha = 0.7
+betas = (0.9, 0.999)
+weight_decay = 0
+final_lr = 0.1
+gamma = 1e-3
+eps = 1e-8
+amsbound = False
+crop_size = 544
+
+model = M2UNET(
+    loss_type=SoftJaccardBCELogitsLoss,
+    loss_arguments=dict(alpha=alpha),
+    optimizer_type=AdaBound,
+    optimizer_arguments=dict(
+        lr=lr,
+        betas=betas,
+        final_lr=final_lr,
+        gamma=gamma,
+        eps=eps,
+        weight_decay=weight_decay,
+        amsbound=amsbound,
+    ),
+    augmentation_transforms=[],
+    crop_size=crop_size,
+)
diff --git a/src/mednet/libs/segmentation/models/m2unet.py b/src/mednet/libs/segmentation/models/m2unet.py
new file mode 100644
index 00000000..76042ba2
--- /dev/null
+++ b/src/mednet/libs/segmentation/models/m2unet.py
@@ -0,0 +1,260 @@
+# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
+#
+# SPDX-License-Identifier: GPL-3.0-or-later
+
+import logging
+import typing
+
+import torch
+import torch.nn
+from mednet.libs.common.data.typing import TransformSequence
+from mednet.libs.common.models.model import Model
+from mednet.libs.common.models.transforms import ResizeMaxSide, SquareCenterPad
+from mednet.libs.segmentation.models.losses import SoftJaccardBCELogitsLoss
+from torchvision.models.mobilenetv2 import InvertedResidual
+
+from .backbones.mobilenetv2 import mobilenet_v2_for_segmentation
+from .separate import separate
+
+logger = logging.getLogger("mednet")
+
+
+class DecoderBlock(torch.nn.Module):
+    """Decoder block: upsamples and concatenates with feature maps from the
+    encoder part.
+
+    Parameters
+    ----------
+    up_in_c
+        Number of channels of the (upsampled) decoder input.
+    x_in_c
+        Number of channels of the encoder feature map concatenated to the
+        upsampled input.
+    upsamplemode
+        Mode to use for upsampling.
+    expand_ratio
+        The expand ratio.
+    """
+
+    def __init__(self, up_in_c, x_in_c, upsamplemode="bilinear", expand_ratio=0.15):
+        super().__init__()
+        self.upsample = torch.nn.Upsample(
+            scale_factor=2, mode=upsamplemode, align_corners=False
+        )  # H, W -> 2H, 2W
+        self.ir1 = InvertedResidual(
+            up_in_c + x_in_c,
+            (x_in_c + up_in_c) // 2,
+            stride=1,
+            expand_ratio=expand_ratio,
+        )
+
+    def forward(self, up_in, x_in):
+        up_out = self.upsample(up_in)
+        cat_x = torch.cat([up_out, x_in], dim=1)
+        return self.ir1(cat_x)
+
+
+class LastDecoderBlock(torch.nn.Module):
+    """Last decoder block.
+
+    Parameters
+    ----------
+    x_in_c
+        Total number of channels after concatenation.
+    upsamplemode
+        Mode to use for upsampling.
+    expand_ratio
+        The expand ratio.
+    """
+
+    def __init__(self, x_in_c, upsamplemode="bilinear", expand_ratio=0.15):
+        super().__init__()
+        self.upsample = torch.nn.Upsample(
+            scale_factor=2, mode=upsamplemode, align_corners=False
+        )  # H, W -> 2H, 2W
+        self.ir1 = InvertedResidual(x_in_c, 1, stride=1, expand_ratio=expand_ratio)
+
+    def forward(self, up_in, x_in):
+        up_out = self.upsample(up_in)
+        cat_x = torch.cat([up_out, x_in], dim=1)
+        return self.ir1(cat_x)
+
+
+class M2UNetHead(torch.nn.Module):
+    """M2U-Net head module.
+
+    Parameters
+    ----------
+    in_channels_list
+        Number of channels for each feature map that is returned from the
+        backbone (currently unused - the decoder sizes below are hard-coded).
+    upsamplemode
+        Mode to use for upsampling.
+    expand_ratio
+        The expand ratio.
+    """
+
+    def __init__(
+        self, in_channels_list=None, upsamplemode="bilinear", expand_ratio=0.15
+    ):
+        super().__init__()
+
+        # Decoder
+        self.decode4 = DecoderBlock(96, 32, upsamplemode, expand_ratio)
+        self.decode3 = DecoderBlock(64, 24, upsamplemode, expand_ratio)
+        self.decode2 = DecoderBlock(44, 16, upsamplemode, expand_ratio)
+        self.decode1 = LastDecoderBlock(33, upsamplemode, expand_ratio)
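+        # Channel bookkeeping (sketch): each DecoderBlock halves the sum of
+        # its input channels, so decode4 maps (96 + 32) -> 64, decode3 maps
+        # (64 + 24) -> 44 and decode2 maps (44 + 16) -> 30; decode1 then
+        # fuses those 30 channels with the 3-channel input image (33 in
+        # total, assuming an RGB input) into the single-channel output of
+        # LastDecoderBlock.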
+ """ + + def __init__( + self, in_channels_list=None, upsamplemode="bilinear", expand_ratio=0.15 + ): + super().__init__() + + # Decoder + self.decode4 = DecoderBlock(96, 32, upsamplemode, expand_ratio) + self.decode3 = DecoderBlock(64, 24, upsamplemode, expand_ratio) + self.decode2 = DecoderBlock(44, 16, upsamplemode, expand_ratio) + self.decode1 = LastDecoderBlock(33, upsamplemode, expand_ratio) + + # initilaize weights + self._initialize_weights() + + def _initialize_weights(self): + for m in self.modules(): + if isinstance(m, torch.nn.Conv2d): + torch.nn.init.kaiming_uniform_(m.weight, a=1) + if m.bias is not None: + torch.nn.init.constant_(m.bias, 0) + elif isinstance(m, torch.nn.BatchNorm2d): + m.weight.data.fill_(1) + m.bias.data.zero_() + + def forward(self, x): + decode4 = self.decode4(x[5], x[4]) # 96, 32 + decode3 = self.decode3(decode4, x[3]) # 64, 24 + decode2 = self.decode2(decode3, x[2]) # 44, 16 + return self.decode1(decode2, x[1]) # 30, 3 + + +class M2UNET(Model): + """Implementation of the M2UNET model. + + Parameters + ---------- + loss_type + The loss to be used for training and evaluation. + + .. warning:: + + The loss should be set to always return batch averages (as opposed + to the batch sum), as our logging system expects it so. + loss_arguments + Arguments to the loss. + optimizer_type + The type of optimizer to use for training. + optimizer_arguments + Arguments to the optimizer after ``params``. + augmentation_transforms + An optional sequence of torch modules containing transforms to be + applied on the input **before** it is fed into the network. + num_classes + Number of outputs (classes) for this model. + pretrained + If True, will use VGG16 pretrained weights. + crop_size + The size of the image after center cropping. + """ + + def __init__( + self, + loss_type: torch.nn.Module = SoftJaccardBCELogitsLoss, + loss_arguments: dict[str, typing.Any] = {}, + optimizer_type: type[torch.optim.Optimizer] = torch.optim.Adam, + optimizer_arguments: dict[str, typing.Any] = {}, + augmentation_transforms: TransformSequence = [], + num_classes: int = 1, + pretrained: bool = False, + crop_size: int = 544, + ): + super().__init__( + loss_type, + loss_arguments, + optimizer_type, + optimizer_arguments, + augmentation_transforms, + num_classes, + ) + + self.name = "m2unet" + + resize_transform = ResizeMaxSide(crop_size) + + self.model_transforms = [ + resize_transform, + SquareCenterPad(), + ] + + self.pretrained = pretrained + + self.backbone = mobilenet_v2_for_segmentation( + pretrained=self.pretrained, + return_features=[1, 3, 6, 13], + ) + + self.head = M2UNetHead(in_channels_list=[16, 24, 32, 96]) + + def forward(self, x): + if self.normalizer is not None: + x = self.normalizer(x) + x = self.backbone(x) + return self.head(x) + + def set_normalizer(self, dataloader: torch.utils.data.DataLoader) -> None: + """Initialize the normalizer for the current model. + + This function is NOOP if ``pretrained = True`` (normalizer set to + imagenet weights, during contruction). + + Parameters + ---------- + dataloader + A torch Dataloader from which to compute the mean and std. + Will not be used if the model is pretrained. + """ + if self.pretrained: + from mednet.libs.common.models.normalizer import make_imagenet_normalizer + + logger.warning( + f"ImageNet pre-trained {self.name} model - NOT " + f"computing z-norm factors from train dataloader. 
" + f"Using preset factors from torchvision.", + ) + self.normalizer = make_imagenet_normalizer() + else: + self.normalizer = None + + def training_step(self, batch, batch_idx): + images = batch[0] + ground_truths = batch[1]["target"] + masks = batch[1]["mask"] + + outputs = self(self._augmentation_transforms(images)) + return self._train_loss(outputs, ground_truths, masks) + + def validation_step(self, batch, batch_idx): + images = batch[0] + ground_truths = batch[1]["target"] + masks = batch[1]["mask"] + + outputs = self(images) + return self._validation_loss(outputs, ground_truths, masks) + + def predict_step(self, batch, batch_idx, dataloader_idx=0): + output = self(batch[0])[1] + probabilities = torch.sigmoid(output) + return separate((probabilities, batch[1])) + + def configure_optimizers(self): + return self._optimizer_type(self.parameters(), **self._optimizer_arguments) -- GitLab