diff --git a/pyproject.toml b/pyproject.toml
index 3a548d8ab87f59c4e1a2e6cec6983b1d222102e7..0f1681ca11af5d77db441936b94570df5197f4fc 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -421,6 +421,8 @@ visceral = "mednet.config.data.visceral.default"
 
 [project.entry-points."mednet.libs.segmentation.config"]
 
+# models
+driu = "mednet.libs.segmentation.config.models.driu"
 lwnet = "mednet.libs.segmentation.config.models.lwnet"
 unet = "mednet.libs.segmentation.config.models.unet"
 
diff --git a/src/mednet/libs/segmentation/config/models/driu.py b/src/mednet/libs/segmentation/config/models/driu.py
new file mode 100644
index 0000000000000000000000000000000000000000..498f6324f1eeb775742b7b9c261bbbd19ca93df7
--- /dev/null
+++ b/src/mednet/libs/segmentation/config/models/driu.py
@@ -0,0 +1,39 @@
+# SPDX-FileCopyrightText: Copyright © 2024 Idiap Research Institute <contact@idiap.ch>
+#
+# SPDX-License-Identifier: GPL-3.0-or-later
+"""DRIU network for vessel segmentation.
+
+Deep Retinal Image Understanding (DRIU) combines a VGG-16 backbone with
+task-specific side-output layers for retinal vessel segmentation.
+
+Reference: [MANINIS-2016]_
+"""
+
+from mednet.libs.segmentation.engine.adabound import AdaBound
+from mednet.libs.segmentation.models.driu import DRIU
+from mednet.libs.segmentation.models.losses import SoftJaccardBCELogitsLoss
+
+lr = 0.001
+alpha = 0.7
+betas = (0.9, 0.999)
+eps = 1e-08
+weight_decay = 0
+final_lr = 0.1
+gamma = 1e-3
+amsbound = False
+
+model = DRIU(
+    loss_type=SoftJaccardBCELogitsLoss,
+    loss_arguments=dict(alpha=alpha),
+    optimizer_type=AdaBound,
+    optimizer_arguments=dict(
+        lr=lr,
+        betas=betas,
+        final_lr=final_lr,
+        gamma=gamma,
+        eps=eps,
+        weight_decay=weight_decay,
+        amsbound=amsbound,
+    ),
+    augmentation_transforms=[],
+)
diff --git a/src/mednet/libs/segmentation/config/models/unet.py b/src/mednet/libs/segmentation/config/models/unet.py
index 9336d2ef1e93549a1e468b460998a7eb49aec90e..498f6324f1eeb775742b7b9c261bbbd19ca93df7 100644
--- a/src/mednet/libs/segmentation/config/models/unet.py
+++ b/src/mednet/libs/segmentation/config/models/unet.py
@@ -13,7 +13,7 @@ from mednet.libs.segmentation.engine.adabound import AdaBound
 from mednet.libs.segmentation.models.losses import SoftJaccardBCELogitsLoss
 from mednet.libs.segmentation.models.unet import Unet
 
-lr = 0.01  # start
+lr = 0.001
 alpha = 0.7
 betas = (0.9, 0.999)
 eps = 1e-08
diff --git a/src/mednet/libs/segmentation/models/driu.py b/src/mednet/libs/segmentation/models/driu.py
new file mode 100644
index 0000000000000000000000000000000000000000..74eba7fc6dcd66411f04f43e935fa54ab8bfa0de
--- /dev/null
+++ b/src/mednet/libs/segmentation/models/driu.py
@@ -0,0 +1,179 @@
+# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
+#
+# SPDX-License-Identifier: GPL-3.0-or-later
+
+
+import logging
+import typing
+
+import torch
+import torch.nn
+from mednet.libs.common.data.typing import TransformSequence
+from mednet.libs.common.models.model import Model
+
+from .backbones.vgg import vgg16_for_segmentation
+from .losses import SoftJaccardBCELogitsLoss
+from .make_layers import UpsampleCropBlock, conv_with_kaiming_uniform
+
+logger = logging.getLogger("mednet")
+
+
+class ConcatFuseBlock(torch.nn.Module):
+    """Takes in four feature maps with 16 channels each, concatenates them and
+    applies a 1x1 convolution with 1 output channel.
+    """
+
+    def __init__(self):
+        super().__init__()
+        self.conv = conv_with_kaiming_uniform(4 * 16, 1, 1, 1, 0)
+
+    def forward(self, x1, x2, x3, x4):
+        x_cat = torch.cat([x1, x2, x3, x4], dim=1)
+        return self.conv(x_cat)
+
+
+class DRIUHead(torch.nn.Module):
+    """DRIU head module.
+
+    Based on paper by [MANINIS-2016]_.
+
+    Parameters
+    ----------
+    in_channels_list
+        Number of channels for each feature map that is returned from the backbone.
+    """
+
+    def __init__(self, in_channels_list=None):
+        super().__init__()
+        (
+            in_conv_1_2_16,
+            in_upsample2,
+            in_upsample_4,
+            in_upsample_8,
+        ) = in_channels_list
+
+        self.conv1_2_16 = torch.nn.Conv2d(in_conv_1_2_16, 16, 3, 1, 1)
+        # Upsample layers
+        self.upsample2 = UpsampleCropBlock(in_upsample2, 16, 4, 2, 0)
+        self.upsample4 = UpsampleCropBlock(in_upsample_4, 16, 8, 4, 0)
+        self.upsample8 = UpsampleCropBlock(in_upsample_8, 16, 16, 8, 0)
+
+        # Concat and Fuse
+        self.concatfuse = ConcatFuseBlock()
+
+    def forward(self, x):
+        hw = x[0]
+        conv1_2_16 = self.conv1_2_16(x[1])  # conv1_2_16
+        upsample2 = self.upsample2(x[2], hw)  # side-multi2-up
+        upsample4 = self.upsample4(x[3], hw)  # side-multi3-up
+        upsample8 = self.upsample8(x[4], hw)  # side-multi4-up
+        return self.concatfuse(conv1_2_16, upsample2, upsample4, upsample8)
+
+
+class DRIU(Model):
+    """DRIU network for vessel segmentation, combining a VGG-16 backbone with
+    the DRIU head.
+
+    Parameters
+    ----------
+    loss_type
+        The loss to be used for training and evaluation.
+
+        .. warning::
+
+           The loss should be set to always return batch averages (as opposed
+           to the batch sum), as expected by our logging system.
+    loss_arguments
+        Arguments to the loss.
+    optimizer_type
+        The type of optimizer to use for training.
+    optimizer_arguments
+        Arguments to the optimizer after ``params``.
+    augmentation_transforms
+        An optional sequence of torch modules containing transforms to be
+        applied on the input **before** it is fed into the network.
+    num_classes
+        Number of outputs (classes) for this model.
+    pretrained
+        If True, will use VGG16 pretrained weights.
+
+    Returns
+    -------
+    module : :py:class:`torch.nn.Module`
+        Network model for DRIU (vessel segmentation).
+    """
+
+    def __init__(
+        self,
+        loss_type: torch.nn.Module = SoftJaccardBCELogitsLoss,
+        loss_arguments: dict[str, typing.Any] = {},
+        optimizer_type: type[torch.optim.Optimizer] = torch.optim.Adam,
+        optimizer_arguments: dict[str, typing.Any] = {},
+        augmentation_transforms: TransformSequence = [],
+        num_classes: int = 1,
+        pretrained: bool = False,
+    ):
+        super().__init__(
+            loss_type,
+            loss_arguments,
+            optimizer_type,
+            optimizer_arguments,
+            augmentation_transforms,
+            num_classes,
+        )
+        self.name = "driu"
+        self.model_transforms: TransformSequence = []
+        self.pretrained = pretrained
+
+        self.backbone = vgg16_for_segmentation(
+            pretrained=self.pretrained,
+            return_features=[3, 8, 14, 22],
+        )
+
+        self.head = DRIUHead([64, 128, 256, 512])
+
+    def forward(self, x):
+        if self.normalizer is not None:
+            x = self.normalizer(x)
+        x = self.backbone(x)
+        return self.head(x)
+
+    def set_normalizer(self, dataloader: torch.utils.data.DataLoader) -> None:
+        """Initialize the normalizer for the current model.
+
+        If ``pretrained = True``, a fixed ImageNet normalizer is used;
+        otherwise, input normalization is disabled for this model.
+
+        Parameters
+        ----------
+        dataloader
+            A torch Dataloader from which mean and std could be computed.
+            Currently not used by this model.
+        """
+        if self.pretrained:
+            from mednet.libs.common.models.normalizer import make_imagenet_normalizer
+
+            logger.warning(
+                f"ImageNet pre-trained {self.name} model - NOT "
+                f"computing z-norm factors from train dataloader. "
+                f"Using preset factors from torchvision.",
+            )
+            self.normalizer = make_imagenet_normalizer()
+        else:
+            self.normalizer = None
+
+    def training_step(self, batch, batch_idx):
+        images = batch[0]
+        ground_truths = batch[1]["target"]
+        masks = batch[1]["mask"]
+
+        outputs = self(self._augmentation_transforms(images))
+        return self._train_loss(outputs, ground_truths, masks)
+
+    def validation_step(self, batch, batch_idx):
+        images = batch[0]
+        ground_truths = batch[1]["target"]
+        masks = batch[1]["mask"]
+
+        outputs = self(images)
+        return self._validation_loss(outputs, ground_truths, masks)
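
The snippet below is not part of the patch; it is a minimal, hypothetical smoke test for the additions above. It assumes mednet is installed in editable mode with this patch applied, that the `vgg16_for_segmentation` backbone returns the input spatial size followed by the four requested feature maps (as `DRIUHead.forward` implies), and that a 544x544 RGB input is acceptable (any size divisible by 8 should be compatible with the strides used here). The expected single-channel output follows from `ConcatFuseBlock`.

```python
# Hypothetical smoke test for the new DRIU config entry point and model.
# Not part of the patch; names below are taken from the changes above.
from importlib.metadata import entry_points

import torch

# 1. The new "driu" entry point should resolve to the config module added in
#    src/mednet/libs/segmentation/config/models/driu.py.
ep = entry_points(group="mednet.libs.segmentation.config")["driu"]
config = ep.load()
print(type(config.model).__name__)  # expected: "DRIU"

# 2. A forward pass through a freshly constructed (non-pretrained) model.
from mednet.libs.segmentation.models.driu import DRIU

model = DRIU(pretrained=False)
# With pretrained=False, set_normalizer() in this patch ignores its argument
# and simply disables input normalization, so passing None here is safe.
model.set_normalizer(None)
model.eval()

x = torch.rand(1, 3, 544, 544)  # assumed input size; must be divisible by 8
with torch.no_grad():
    y = model(x)
print(y.shape)  # expected: torch.Size([1, 1, 544, 544]), i.e. one vessel map in logits
```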