From 88c3ab909b2acb739d352f4efacd57236690bc29 Mon Sep 17 00:00:00 2001 From: dcarron <daniel.carron@idiap.ch> Date: Tue, 11 Jun 2024 12:47:05 +0200 Subject: [PATCH] [segmentation.models] Add driu-pix model --- pyproject.toml | 1 + .../segmentation/config/models/driu_pix.py | 43 +++++ .../libs/segmentation/models/driu_pix.py | 180 ++++++++++++++++++ 3 files changed, 224 insertions(+) create mode 100644 src/mednet/libs/segmentation/config/models/driu_pix.py create mode 100644 src/mednet/libs/segmentation/models/driu_pix.py diff --git a/pyproject.toml b/pyproject.toml index 5761780f..fa28c5ab 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -425,6 +425,7 @@ visceral = "mednet.config.data.visceral.default" driu = "mednet.libs.segmentation.config.models.driu" driu-bn = "mednet.libs.segmentation.config.models.driu_bn" driu-od = "mednet.libs.segmentation.config.models.driu_od" +driu-pix = "mednet.libs.segmentation.config.models.driu_pix" hed = "mednet.libs.segmentation.config.models.hed" lwnet = "mednet.libs.segmentation.config.models.lwnet" m2unet = "mednet.libs.segmentation.config.models.m2unet" diff --git a/src/mednet/libs/segmentation/config/models/driu_pix.py b/src/mednet/libs/segmentation/config/models/driu_pix.py new file mode 100644 index 00000000..e0365c88 --- /dev/null +++ b/src/mednet/libs/segmentation/config/models/driu_pix.py @@ -0,0 +1,43 @@ +# SPDX-FileCopyrightText: Copyright © 2024 Idiap Research Institute <contact@idiap.ch> +# +# SPDX-License-Identifier: GPL-3.0-or-later + +"""DRIU Network for Vessel Segmentation. + +Deep Retinal Image Understanding (DRIU), a unified framework of retinal image +analysis that provides both retinal vessel and optic disc segmentation using +deep Convolutional Neural Networks (CNNs). + +Reference: [MANINIS-2016]_ +""" + +from mednet.libs.segmentation.engine.adabound import AdaBound +from mednet.libs.segmentation.models.driu_pix import DRIUPix +from mednet.libs.segmentation.models.losses import SoftJaccardBCELogitsLoss + +lr = 0.001 +alpha = 0.7 +betas = (0.9, 0.999) +eps = 1e-08 +weight_decay = 0 +final_lr = 0.1 +gamma = 1e-3 +eps = 1e-8 +amsbound = False + +model = DRIUPix( + loss_type=SoftJaccardBCELogitsLoss, + loss_arguments=dict(alpha=alpha), + optimizer_type=AdaBound, + optimizer_arguments=dict( + lr=lr, + betas=betas, + final_lr=final_lr, + gamma=gamma, + eps=eps, + weight_decay=weight_decay, + amsbound=amsbound, + ), + augmentation_transforms=[], + crop_size=1024, +) diff --git a/src/mednet/libs/segmentation/models/driu_pix.py b/src/mednet/libs/segmentation/models/driu_pix.py new file mode 100644 index 00000000..08f721d0 --- /dev/null +++ b/src/mednet/libs/segmentation/models/driu_pix.py @@ -0,0 +1,180 @@ +# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> +# +# SPDX-License-Identifier: GPL-3.0-or-later + + +import logging +import typing + +import torch +import torch.nn +from mednet.libs.common.data.typing import TransformSequence +from mednet.libs.common.models.model import Model +from mednet.libs.common.models.transforms import ResizeMaxSide, SquareCenterPad + +from .backbones.vgg import vgg16_for_segmentation +from .driu import ConcatFuseBlock +from .losses import SoftJaccardBCELogitsLoss +from .make_layers import UpsampleCropBlock +from .separate import separate + +logger = logging.getLogger("mednet") + + +class DRIUPIXHead(torch.nn.Module): + """DRIUPIX head module. DRIU with pixelshuffle instead of ConvTrans2D. + + Parameters + ---------- + in_channels_list : list + Number of channels for each feature map that is returned from backbone. + """ + + def __init__(self, in_channels_list=None): + super().__init__() + ( + in_conv_1_2_16, + in_upsample2, + in_upsample_4, + in_upsample_8, + ) = in_channels_list + + self.conv1_2_16 = torch.nn.Conv2d(in_conv_1_2_16, 16, 3, 1, 1) + # Upsample layers + self.upsample2 = UpsampleCropBlock(in_upsample2, 16, 4, 2, 0, pixelshuffle=True) + self.upsample4 = UpsampleCropBlock( + in_upsample_4, 16, 8, 4, 0, pixelshuffle=True + ) + self.upsample8 = UpsampleCropBlock( + in_upsample_8, 16, 16, 8, 0, pixelshuffle=True + ) + + # Concat and Fuse + self.concatfuse = ConcatFuseBlock() + + def forward(self, x): + hw = x[0] + conv1_2_16 = self.conv1_2_16(x[1]) # conv1_2_16 + upsample2 = self.upsample2(x[2], hw) # side-multi2-up + upsample4 = self.upsample4(x[3], hw) # side-multi3-up + upsample8 = self.upsample8(x[4], hw) # side-multi4-up + return self.concatfuse(conv1_2_16, upsample2, upsample4, upsample8) + + +class DRIUPix(Model): + """Implementation of the DRIU-BN model. + + Parameters + ---------- + loss_type + The loss to be used for training and evaluation. + + .. warning:: + + The loss should be set to always return batch averages (as opposed + to the batch sum), as our logging system expects it so. + loss_arguments + Arguments to the loss. + optimizer_type + The type of optimizer to use for training. + optimizer_arguments + Arguments to the optimizer after ``params``. + augmentation_transforms + An optional sequence of torch modules containing transforms to be + applied on the input **before** it is fed into the network. + num_classes + Number of outputs (classes) for this model. + pretrained + If True, will use VGG16 pretrained weights. + crop_size + The size of the image after center cropping. + """ + + def __init__( + self, + loss_type: torch.nn.Module = SoftJaccardBCELogitsLoss, + loss_arguments: dict[str, typing.Any] = {}, + optimizer_type: type[torch.optim.Optimizer] = torch.optim.Adam, + optimizer_arguments: dict[str, typing.Any] = {}, + augmentation_transforms: TransformSequence = [], + num_classes: int = 1, + pretrained: bool = False, + crop_size: int = 544, + ): + super().__init__( + loss_type, + loss_arguments, + optimizer_type, + optimizer_arguments, + augmentation_transforms, + num_classes, + ) + self.name = "driu-pix" + resize_transform = ResizeMaxSide(crop_size) + + self.model_transforms = [ + resize_transform, + SquareCenterPad(), + ] + self.pretrained = pretrained + + self.backbone = vgg16_for_segmentation( + pretrained=self.pretrained, + return_features=[3, 8, 14, 22], + ) + + self.head = DRIUPIXHead([64, 128, 256, 512]) + + def forward(self, x): + if self.normalizer is not None: + x = self.normalizer(x) + x = self.backbone(x) + return self.head(x) + + def set_normalizer(self, dataloader: torch.utils.data.DataLoader) -> None: + """Initialize the normalizer for the current model. + + This function is NOOP if ``pretrained = True`` (normalizer set to + imagenet weights, during contruction). + + Parameters + ---------- + dataloader + A torch Dataloader from which to compute the mean and std. + Will not be used if the model is pretrained. + """ + if self.pretrained: + from mednet.libs.common.models.normalizer import make_imagenet_normalizer + + logger.warning( + f"ImageNet pre-trained {self.name} model - NOT " + f"computing z-norm factors from train dataloader. " + f"Using preset factors from torchvision.", + ) + self.normalizer = make_imagenet_normalizer() + else: + self.normalizer = None + + def training_step(self, batch, batch_idx): + images = batch[0] + ground_truths = batch[1]["target"] + masks = batch[1]["mask"] + + outputs = self(self._augmentation_transforms(images)) + return self._train_loss(outputs, ground_truths, masks) + + def validation_step(self, batch, batch_idx): + images = batch[0] + ground_truths = batch[1]["target"] + masks = batch[1]["mask"] + + outputs = self(images) + return self._validation_loss(outputs, ground_truths, masks) + + def predict_step(self, batch, batch_idx, dataloader_idx=0): + output = self(batch[0])[1] + probabilities = torch.sigmoid(output) + return separate((probabilities, batch[1])) + + def configure_optimizers(self): + return self._optimizer_type(self.parameters(), **self._optimizer_arguments) -- GitLab