Commit d4d28b03, authored by Daniel CARRON, committed by André Anjos

[segmentation.models] Add m2unet model

parent 91fa4ecf
Merge request !46: Create common library
@@ -425,6 +425,7 @@ visceral = "mednet.config.data.visceral.default"
 driu = "mednet.libs.segmentation.config.models.driu"
 hed = "mednet.libs.segmentation.config.models.hed"
 lwnet = "mednet.libs.segmentation.config.models.lwnet"
+m2unet = "mednet.libs.segmentation.config.models.m2unet"
 unet = "mednet.libs.segmentation.config.models.unet"
 # chase-db1 - retinography
# SPDX-FileCopyrightText: Copyright © 2024 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later
"""MobileNetV2 U-Net model for image segmentation.
The MobileNetV2 architecture is based on an inverted residual structure where
the input and output of the residual block are thin bottleneck layers opposite
to traditional residual models which use expanded representations in the input
an MobileNetV2 uses lightweight depthwise convolutions to filter features in
the intermediate expansion layer. This model implements a MobileNetV2 U-Net
model, henceforth named M2U-Net, combining the strenghts of U-Net for medical
segmentation applications and the speed of MobileNetV2 networks.
References: [SANDLER-2018]_, [RONNEBERGER-2015]_
"""
from mednet.libs.segmentation.engine.adabound import AdaBound
from mednet.libs.segmentation.models.losses import SoftJaccardBCELogitsLoss
from mednet.libs.segmentation.models.m2unet import M2UNET
# optimizer and loss hyperparameters (AdaBound + SoftJaccardBCELogitsLoss)
lr = 0.001
alpha = 0.7
betas = (0.9, 0.999)
eps = 1e-8
weight_decay = 0
final_lr = 0.1
gamma = 1e-3
amsbound = False

crop_size = 544
model = M2UNET(
loss_type=SoftJaccardBCELogitsLoss,
loss_arguments=dict(alpha=alpha),
optimizer_type=AdaBound,
optimizer_arguments=dict(
lr=lr,
betas=betas,
final_lr=final_lr,
gamma=gamma,
eps=eps,
weight_decay=weight_decay,
amsbound=amsbound,
),
augmentation_transforms=[],
crop_size=crop_size,
)
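As a quick smoke test, the configuration above can be exercised on a dummy batch. This is an illustrative sketch only; the direct ``normalizer`` assignment is an assumption about the ``Model`` base class, made here just to bypass dataset-dependent normalization:

import torch

# Illustrative smoke test (not part of the configuration file): push a
# dummy batch through the model defined above.
model.normalizer = None  # assumption: skip dataset normalization in forward()
with torch.no_grad():
    logits = model(torch.randn(1, 3, crop_size, crop_size))
print(logits.shape)  # expected: torch.Size([1, 1, 544, 544])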
# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later
import logging
import typing
import torch
import torch.nn
from mednet.libs.common.data.typing import TransformSequence
from mednet.libs.common.models.model import Model
from mednet.libs.common.models.transforms import ResizeMaxSide, SquareCenterPad
from mednet.libs.segmentation.models.losses import SoftJaccardBCELogitsLoss
from torchvision.models.mobilenetv2 import InvertedResidual
from .backbones.mobilenetv2 import mobilenet_v2_for_segmentation
from .separate import separate
logger = logging.getLogger("mednet")
class DecoderBlock(torch.nn.Module):
"""Decoder block: upsample and concatenate with features maps from the
encoder part.
Parameters
----------
up_in_c
Number of input channels.
x_in_c
Number of cat channels.
upsamplemode
Mode to use for upsampling.
expand_ratio
The expand ratio.
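
    Example
    -------
    A small illustrative shape check (inputs chosen for this example only);
    the block doubles spatial resolution and halves the concatenated
    channel count:

    >>> import torch
    >>> block = DecoderBlock(up_in_c=96, x_in_c=32)
    >>> up_in = torch.randn(1, 96, 17, 17)  # deep decoder feature (H, W)
    >>> x_in = torch.randn(1, 32, 34, 34)  # encoder skip at (2H, 2W)
    >>> block(up_in, x_in).shape  # (96 + 32) // 2 = 64 output channels
    torch.Size([1, 64, 34, 34])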
"""
def __init__(self, up_in_c, x_in_c, upsamplemode="bilinear", expand_ratio=0.15):
super().__init__()
self.upsample = torch.nn.Upsample(
scale_factor=2, mode=upsamplemode, align_corners=False
) # H, W -> 2H, 2W
self.ir1 = InvertedResidual(
up_in_c + x_in_c,
(x_in_c + up_in_c) // 2,
stride=1,
expand_ratio=expand_ratio,
)
def forward(self, up_in, x_in):
up_out = self.upsample(up_in)
cat_x = torch.cat([up_out, x_in], dim=1)
return self.ir1(cat_x)
class LastDecoderBlock(torch.nn.Module):
"""Last decoder block.
Parameters
----------
x_in_c
Number of cat channels.
upsamplemode
Mode to use for upsampling.
expand_ratio
The expand ratio.
"""
def __init__(self, x_in_c, upsamplemode="bilinear", expand_ratio=0.15):
super().__init__()
self.upsample = torch.nn.Upsample(
scale_factor=2, mode=upsamplemode, align_corners=False
) # H, W -> 2H, 2W
self.ir1 = InvertedResidual(x_in_c, 1, stride=1, expand_ratio=expand_ratio)
def forward(self, up_in, x_in):
up_out = self.upsample(up_in)
cat_x = torch.cat([up_out, x_in], dim=1)
return self.ir1(cat_x)
class M2UNetHead(torch.nn.Module):
"""M2U-Net head module.
Parameters
----------
in_channels_list
Number of channels for each feature map returned by the backbone.
upsamplemode
Mode to use for upsampling.
expand_ratio
The expand ratio.
"""
def __init__(
self, in_channels_list=None, upsamplemode="bilinear", expand_ratio=0.15
):
super().__init__()
# Decoder
self.decode4 = DecoderBlock(96, 32, upsamplemode, expand_ratio)
self.decode3 = DecoderBlock(64, 24, upsamplemode, expand_ratio)
self.decode2 = DecoderBlock(44, 16, upsamplemode, expand_ratio)
self.decode1 = LastDecoderBlock(33, upsamplemode, expand_ratio)
# initialize weights
self._initialize_weights()
def _initialize_weights(self):
for m in self.modules():
if isinstance(m, torch.nn.Conv2d):
torch.nn.init.kaiming_uniform_(m.weight, a=1)
if m.bias is not None:
torch.nn.init.constant_(m.bias, 0)
elif isinstance(m, torch.nn.BatchNorm2d):
m.weight.data.fill_(1)
m.bias.data.zero_()
def forward(self, x):
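        # x lists the feature maps returned by the backbone: x[1] is the
        # 3-channel, full-resolution network input (used as the final skip
        # connection below), and the deeper entries are MobileNetV2 features
        # with 16, 24, 32 and 96 channels, respectively.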
decode4 = self.decode4(x[5], x[4]) # 96, 32
decode3 = self.decode3(decode4, x[3]) # 64, 24
decode2 = self.decode2(decode3, x[2]) # 44, 16
return self.decode1(decode2, x[1]) # 30, 3
class M2UNET(Model):
"""Implementation of the M2UNET model.
Parameters
----------
loss_type
The loss to be used for training and evaluation.
.. warning::
The loss should be set to always return batch averages (as opposed
to the batch sum), as our logging system expects it so.
loss_arguments
Arguments to the loss.
optimizer_type
The type of optimizer to use for training.
optimizer_arguments
Arguments to the optimizer after ``params``.
augmentation_transforms
An optional sequence of torch modules containing transforms to be
applied on the input **before** it is fed into the network.
num_classes
Number of outputs (classes) for this model.
pretrained
    If True, will use MobileNetV2 weights pretrained on ImageNet.
crop_size
The size of the image after center cropping.
"""
def __init__(
self,
loss_type: torch.nn.Module = SoftJaccardBCELogitsLoss,
loss_arguments: dict[str, typing.Any] = {},
optimizer_type: type[torch.optim.Optimizer] = torch.optim.Adam,
optimizer_arguments: dict[str, typing.Any] = {},
augmentation_transforms: TransformSequence = [],
num_classes: int = 1,
pretrained: bool = False,
crop_size: int = 544,
):
super().__init__(
loss_type,
loss_arguments,
optimizer_type,
optimizer_arguments,
augmentation_transforms,
num_classes,
)
self.name = "m2unet"
resize_transform = ResizeMaxSide(crop_size)
self.model_transforms = [
resize_transform,
SquareCenterPad(),
]
self.pretrained = pretrained
self.backbone = mobilenet_v2_for_segmentation(
pretrained=self.pretrained,
return_features=[1, 3, 6, 13],
)
self.head = M2UNetHead(in_channels_list=[16, 24, 32, 96])
def forward(self, x):
if self.normalizer is not None:
x = self.normalizer(x)
x = self.backbone(x)
return self.head(x)
def set_normalizer(self, dataloader: torch.utils.data.DataLoader) -> None:
"""Initialize the normalizer for the current model.
This function is NOOP if ``pretrained = True`` (normalizer set to
imagenet weights, during contruction).
Parameters
----------
dataloader
A torch Dataloader from which to compute the mean and std.
Will not be used if the model is pretrained.
"""
if self.pretrained:
from mednet.libs.common.models.normalizer import make_imagenet_normalizer
logger.warning(
f"ImageNet pre-trained {self.name} model - NOT "
f"computing z-norm factors from train dataloader. "
f"Using preset factors from torchvision.",
)
self.normalizer = make_imagenet_normalizer()
else:
self.normalizer = None
def training_step(self, batch, batch_idx):
images = batch[0]
ground_truths = batch[1]["target"]
masks = batch[1]["mask"]
outputs = self(self._augmentation_transforms(images))
return self._train_loss(outputs, ground_truths, masks)
def validation_step(self, batch, batch_idx):
images = batch[0]
ground_truths = batch[1]["target"]
masks = batch[1]["mask"]
outputs = self(images)
return self._validation_loss(outputs, ground_truths, masks)
    def predict_step(self, batch, batch_idx, dataloader_idx=0):
        # the head returns a single logits map; turn it into per-pixel
        # probabilities before separating per-sample results
        output = self(batch[0])
        probabilities = torch.sigmoid(output)
        return separate((probabilities, batch[1]))
def configure_optimizers(self):
return self._optimizer_type(self.parameters(), **self._optimizer_arguments)
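Because ``M2UNET`` implements the Lightning-style hooks above (``training_step``, ``validation_step``, ``configure_optimizers``), a synthetic batch is enough to exercise the training path end-to-end. The sketch below is illustrative only; ``ToySegmentationDataset`` is hypothetical and simply mirrors the ``(image, {"target": ..., "mask": ...})`` batch layout consumed by ``training_step``:

import torch

from mednet.libs.segmentation.models.m2unet import M2UNET

class ToySegmentationDataset(torch.utils.data.Dataset):
    """Hypothetical dataset mirroring the batch layout used above."""

    def __len__(self) -> int:
        return 4

    def __getitem__(self, idx: int):
        image = torch.rand(3, 64, 64)
        target = (torch.rand(1, 64, 64) > 0.5).float()  # binary ground truth
        mask = torch.ones(1, 64, 64)  # no pixels excluded from the loss
        return image, {"target": target, "mask": mask}

loader = torch.utils.data.DataLoader(ToySegmentationDataset(), batch_size=2)
model = M2UNET(crop_size=64)  # defaults: Adam + SoftJaccardBCELogitsLoss
model.set_normalizer(loader)  # pretrained=False, so no normalizer is set
loss = model.training_step(next(iter(loader)), 0)
print(float(loss))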