From f31df6a2e2d86ec92b1bf4edb84aad2f752725f3 Mon Sep 17 00:00:00 2001
From: dcarron <daniel.carron@idiap.ch>
Date: Tue, 11 Jun 2024 11:31:18 +0200
Subject: [PATCH] [segmentation.models] Add hed model

---
 pyproject.toml                             |   1 +
 .../libs/segmentation/config/models/hed.py |  34 ++++
 src/mednet/libs/segmentation/models/hed.py | 188 ++++++++++++++++++
 3 files changed, 223 insertions(+)
 create mode 100644 src/mednet/libs/segmentation/config/models/hed.py
 create mode 100644 src/mednet/libs/segmentation/models/hed.py

diff --git a/pyproject.toml b/pyproject.toml
index 0f1681ca..e6cf3efb 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -423,6 +423,7 @@ visceral = "mednet.config.data.visceral.default"
 
 # models
 driu = "mednet.libs.segmentation.config.models.driu"
+hed = "mednet.libs.segmentation.config.models.hed"
 lwnet = "mednet.libs.segmentation.config.models.lwnet"
 unet = "mednet.libs.segmentation.config.models.unet"
 
diff --git a/src/mednet/libs/segmentation/config/models/hed.py b/src/mednet/libs/segmentation/config/models/hed.py
new file mode 100644
index 00000000..0c88c9c2
--- /dev/null
+++ b/src/mednet/libs/segmentation/config/models/hed.py
@@ -0,0 +1,34 @@
+# SPDX-FileCopyrightText: Copyright © 2024 Idiap Research Institute <contact@idiap.ch>
+#
+# SPDX-License-Identifier: GPL-3.0-or-later
+
+from mednet.libs.segmentation.engine.adabound import AdaBound
+from mednet.libs.segmentation.models.hed import HED
+from mednet.libs.segmentation.models.losses import MultiSoftJaccardBCELogitsLoss
+
+lr = 0.001
+alpha = 0.7
+betas = (0.9, 0.999)
+weight_decay = 0
+final_lr = 0.1
+gamma = 1e-3
+eps = 1e-8
+amsbound = False
+crop_size = 544
+
+model = HED(
+    loss_type=MultiSoftJaccardBCELogitsLoss,
+    loss_arguments=dict(alpha=alpha),
+    optimizer_type=AdaBound,
+    optimizer_arguments=dict(
+        lr=lr,
+        betas=betas,
+        final_lr=final_lr,
+        gamma=gamma,
+        eps=eps,
+        weight_decay=weight_decay,
+        amsbound=amsbound,
+    ),
+    augmentation_transforms=[],
+    crop_size=crop_size,
+)
diff --git a/src/mednet/libs/segmentation/models/hed.py b/src/mednet/libs/segmentation/models/hed.py
new file mode 100644
index 00000000..bbefa8a4
--- /dev/null
+++ b/src/mednet/libs/segmentation/models/hed.py
@@ -0,0 +1,188 @@
+# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
+#
+# SPDX-License-Identifier: GPL-3.0-or-later
+
+import logging
+import typing
+
+import torch
+import torch.nn
+from mednet.libs.common.data.typing import TransformSequence
+from mednet.libs.common.models.model import Model
+from mednet.libs.common.models.transforms import ResizeMaxSide, SquareCenterPad
+
+from .backbones.vgg import vgg16_for_segmentation
+from .losses import MultiSoftJaccardBCELogitsLoss
+from .make_layers import UpsampleCropBlock, conv_with_kaiming_uniform
+
+logger = logging.getLogger("mednet")
+
+
+class ConcatFuseBlock(torch.nn.Module):
+    """Take in five feature maps with one channel each, concatenate them, and
+    apply a 1x1 convolution with one output channel.
+    """
+
+    def __init__(self):
+        super().__init__()
+        self.conv = conv_with_kaiming_uniform(5, 1, 1, 1, 0)  # 5->1 channels, 1x1 kernel
+
+    def forward(self, x1, x2, x3, x4, x5):
+        x_cat = torch.cat([x1, x2, x3, x4, x5], dim=1)
+        return self.conv(x_cat)
+
+
+class HEDHead(torch.nn.Module):
+    """HED head module.
+
+    Parameters
+    ----------
+    in_channels_list : list
+        Number of channels of each feature map returned by the backbone.
+    """
+
+    def __init__(self, in_channels_list):
+        super().__init__()
+        (
+            in_conv_1_2_16,
+            in_upsample2,
+            in_upsample_4,
+            in_upsample_8,
+            in_upsample_16,
+        ) = in_channels_list
+
+        self.conv1_2_16 = torch.nn.Conv2d(in_conv_1_2_16, 1, 3, 1, 1)
+        # Upsample
+        self.upsample2 = UpsampleCropBlock(in_upsample2, 1, 4, 2, 0)
+        self.upsample4 = UpsampleCropBlock(in_upsample_4, 1, 8, 4, 0)
+        self.upsample8 = UpsampleCropBlock(in_upsample_8, 1, 16, 8, 0)
+        self.upsample16 = UpsampleCropBlock(in_upsample_16, 1, 32, 16, 0)
+        # Concat and Fuse
+        self.concatfuse = ConcatFuseBlock()
+
+    def forward(self, x):
+        hw = x[0]  # input spatial size, used to crop the upsampled maps
+        conv1_2_16 = self.conv1_2_16(x[1])
+        upsample2 = self.upsample2(x[2], hw)
+        upsample4 = self.upsample4(x[3], hw)
+        upsample8 = self.upsample8(x[4], hw)
+        upsample16 = self.upsample16(x[5], hw)
+        concatfuse = self.concatfuse(
+            conv1_2_16, upsample2, upsample4, upsample8, upsample16
+        )
+
+        return (upsample2, upsample4, upsample8, upsample16, concatfuse)
+
+
+class HED(Model):
+    """Implementation of the HED (holistically-nested edge detection) model.
+
+    Parameters
+    ----------
+    loss_type
+        The loss to be used for training and evaluation.
+
+        .. warning::
+
+           The loss should be set to always return batch averages (as opposed
+           to the batch sum), since this is what our logging system expects.
+    loss_arguments
+        Arguments to the loss.
+    optimizer_type
+        The type of optimizer to use for training.
+    optimizer_arguments
+        Arguments to the optimizer after ``params``.
+    augmentation_transforms
+        An optional sequence of torch modules containing transforms to be
+        applied on the input **before** it is fed into the network.
+    num_classes
+        Number of outputs (classes) for this model.
+    pretrained
+        If True, will use VGG16 weights pre-trained on ImageNet.
+    crop_size
+        The size of the image after resizing (longest side) and square padding.
+    """
+
+    def __init__(
+        self,
+        loss_type: type[torch.nn.Module] = MultiSoftJaccardBCELogitsLoss,
+        loss_arguments: dict[str, typing.Any] = {},
+        optimizer_type: type[torch.optim.Optimizer] = torch.optim.Adam,
+        optimizer_arguments: dict[str, typing.Any] = {},
+        augmentation_transforms: TransformSequence = [],
+        num_classes: int = 1,
+        pretrained: bool = False,
+        crop_size: int = 544,
+    ):
+        super().__init__(
+            loss_type,
+            loss_arguments,
+            optimizer_type,
+            optimizer_arguments,
+            augmentation_transforms,
+            num_classes,
+        )
+
+        self.name = "hed"
+
+        resize_transform = ResizeMaxSide(crop_size)
+
+        self.model_transforms = [
+            resize_transform,
+            SquareCenterPad(),
+        ]
+
+        self.pretrained = pretrained
+
+        self.backbone = vgg16_for_segmentation(
+            pretrained=self.pretrained,
+            return_features=[3, 8, 14, 22, 29],
+        )
+
+        self.head = HEDHead([64, 128, 256, 512, 512])
+
+    def forward(self, x):
+        if self.normalizer is not None:
+            x = self.normalizer(x)
+        x = self.backbone(x)
+        return self.head(x)
+
+    def set_normalizer(self, dataloader: torch.utils.data.DataLoader) -> None:
+        """Initialize the normalizer for the current model.
+
+        This function is a no-op if ``pretrained = True`` (ImageNet
+        normalization factors are used instead of dataloader statistics).
+
+        Parameters
+        ----------
+        dataloader
+            A torch Dataloader from which to compute the mean and std.
+            Will not be used if the model is pretrained.
+        """
+        if self.pretrained:
+            from mednet.libs.common.models.normalizer import make_imagenet_normalizer
+
+            logger.warning(
+                f"ImageNet pre-trained {self.name} model - NOT "
+                "computing z-norm factors from train dataloader. "
+                "Using preset factors from torchvision.",
+            )
+            self.normalizer = make_imagenet_normalizer()
+        else:
+            self.normalizer = None
+
+    def training_step(self, batch, batch_idx):
+        images = batch[0]
+        ground_truths = batch[1]["target"]
+        masks = batch[1]["mask"]
+
+        outputs = self(self._augmentation_transforms(images))
+        return self._train_loss(outputs, ground_truths, masks)
+
+    def validation_step(self, batch, batch_idx):
+        images = batch[0]
+        ground_truths = batch[1]["target"]
+        masks = batch[1]["mask"]
+
+        outputs = self(images)
+        return self._validation_loss(outputs, ground_truths, masks)
-- 
GitLab