Skip to content
Snippets Groups Projects
Commit 3f660309 authored by Daniel CARRON's avatar Daniel CARRON :b: Committed by André Anjos
Browse files

[segmentation.models] Add driu-bn model

parent bb12e640
No related branches found
No related tags found
1 merge request!46Create common library
...@@ -423,6 +423,7 @@ visceral = "mednet.config.data.visceral.default" ...@@ -423,6 +423,7 @@ visceral = "mednet.config.data.visceral.default"
# models # models
driu = "mednet.libs.segmentation.config.models.driu" driu = "mednet.libs.segmentation.config.models.driu"
driu-bn = "mednet.libs.segmentation.config.models.driu_bn"
driu-od = "mednet.libs.segmentation.config.models.driu_od" driu-od = "mednet.libs.segmentation.config.models.driu_od"
hed = "mednet.libs.segmentation.config.models.hed" hed = "mednet.libs.segmentation.config.models.hed"
lwnet = "mednet.libs.segmentation.config.models.lwnet" lwnet = "mednet.libs.segmentation.config.models.lwnet"
......
# SPDX-FileCopyrightText: Copyright © 2024 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later
"""DRIU Network for Vessel Segmentation.
Deep Retinal Image Understanding (DRIU), a unified framework of retinal image
analysis that provides both retinal vessel and optic disc segmentation using
deep Convolutional Neural Networks (CNNs). This configuration instantiates the
batch-normalized variant of the model (DRIU-BN).
Reference: [MANINIS-2016]_
"""
from mednet.libs.segmentation.engine.adabound import AdaBound
from mednet.libs.segmentation.models.driu_bn import DRIUBN
from mednet.libs.segmentation.models.losses import SoftJaccardBCELogitsLoss
# Weight handed to the loss through ``loss_arguments``.
alpha = 0.7

# AdaBound settings, forwarded unchanged through ``optimizer_arguments``.
# ``eps`` used to be assigned twice with the same value; it is now set once.
lr = 0.001
betas = (0.9, 0.999)
eps = 1e-8
weight_decay = 0
final_lr = 0.1
gamma = 1e-3
amsbound = False

# Ready-to-use DRIU-BN model: no augmentation transforms, crop_size of 1024.
model = DRIUBN(
    loss_type=SoftJaccardBCELogitsLoss,
    loss_arguments={"alpha": alpha},
    optimizer_type=AdaBound,
    optimizer_arguments={
        "lr": lr,
        "betas": betas,
        "final_lr": final_lr,
        "gamma": gamma,
        "eps": eps,
        "weight_decay": weight_decay,
        "amsbound": amsbound,
    },
    augmentation_transforms=[],
    crop_size=1024,
)
# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later
import logging
import typing
import torch
import torch.nn
from mednet.libs.common.data.typing import TransformSequence
from mednet.libs.common.models.model import Model
from mednet.libs.common.models.transforms import ResizeMaxSide, SquareCenterPad
from .backbones.vgg import vgg16_for_segmentation
from .losses import SoftJaccardBCELogitsLoss
from .make_layers import UpsampleCropBlock, conv_with_kaiming_uniform
from .separate import separate
# Module-level logger, registered under the fixed name "mednet".
logger = logging.getLogger("mednet")
class ConcatFuseBlock(torch.nn.Module):
    """Fuse four 16-channel feature maps into a single-channel map.

    The four inputs are concatenated along ``dim=1`` and the result is fed
    through a 1x1 convolution with one output channel, followed by a
    ``BatchNorm2d`` layer.
    """

    def __init__(self):
        super().__init__()
        # 4 maps x 16 channels in, 1 channel out, then batch normalization.
        fuse_conv = conv_with_kaiming_uniform(4 * 16, 1, 1, 1, 0)
        self.conv = torch.nn.Sequential(fuse_conv, torch.nn.BatchNorm2d(1))

    def forward(self, x1, x2, x3, x4):
        """Concatenate the four inputs on ``dim=1`` and apply the fuse layers."""
        stacked = torch.cat((x1, x2, x3, x4), dim=1)
        return self.conv(stacked)
class DRIUBNHead(torch.nn.Module):
    """Head module of DRIU with Batch-Normalization.

    Based on paper by [MANINIS-2016]_.

    Parameters
    ----------
    in_channels_list : list
        Number of channels for each of the four feature maps returned by the
        backbone, in the order they are consumed by the head.
    """

    def __init__(self, in_channels_list=None):
        super().__init__()
        # Exactly four channel counts are expected, one per side branch.
        ch_conv, ch_up2, ch_up4, ch_up8 = in_channels_list

        # First branch: plain 3x3 convolution producing 16 channels.
        self.conv1_2_16 = torch.nn.Conv2d(ch_conv, 16, 3, 1, 1)

        # Remaining branches: upsample-and-crop blocks, 16 channels each.
        self.upsample2 = UpsampleCropBlock(ch_up2, 16, 4, 2, 0)
        self.upsample4 = UpsampleCropBlock(ch_up4, 16, 8, 4, 0)
        self.upsample8 = UpsampleCropBlock(ch_up8, 16, 16, 8, 0)

        # Concatenate the four 16-channel maps and fuse them.
        self.concatfuse = ConcatFuseBlock()

    def forward(self, x):
        """Run the four side branches on ``x[1]``..``x[4]`` and fuse them.

        ``x[0]`` is handed to every upsample block as its second argument.
        """
        hw = x[0]
        branches = (
            self.conv1_2_16(x[1]),  # conv1_2_16
            self.upsample2(x[2], hw),  # side-multi2-up
            self.upsample4(x[3], hw),  # side-multi3-up
            self.upsample8(x[4], hw),  # side-multi4-up
        )
        return self.concatfuse(*branches)
class DRIUBN(Model):
    """Implementation of the DRIU-BN model.

    Parameters
    ----------
    loss_type
        The loss to be used for training and evaluation.

        .. warning::

           The loss should be set to always return batch averages (as opposed
           to the batch sum), as our logging system expects it so.
    loss_arguments
        Arguments to the loss.
    optimizer_type
        The type of optimizer to use for training.
    optimizer_arguments
        Arguments to the optimizer after ``params``.
    augmentation_transforms
        An optional sequence of torch modules containing transforms to be
        applied on the input **before** it is fed into the network.
    num_classes
        Number of outputs (classes) for this model. Forwarded to the base
        class.
    pretrained
        If True, will use VGG16 pretrained weights.
    crop_size
        Value handed to ``ResizeMaxSide``, the first model transform; it is
        followed by ``SquareCenterPad``.
    """

    # NOTE(review): the ``{}`` / ``[]`` defaults are shared between calls.
    # They are kept to leave the signature untouched -- confirm the base
    # class never mutates them in place.
    def __init__(
        self,
        loss_type: type[torch.nn.Module] = SoftJaccardBCELogitsLoss,
        loss_arguments: dict[str, typing.Any] = {},
        optimizer_type: type[torch.optim.Optimizer] = torch.optim.Adam,
        optimizer_arguments: dict[str, typing.Any] = {},
        augmentation_transforms: TransformSequence = [],
        num_classes: int = 1,
        pretrained: bool = False,
        crop_size: int = 544,
    ):
        super().__init__(
            loss_type,
            loss_arguments,
            optimizer_type,
            optimizer_arguments,
            augmentation_transforms,
            num_classes,
        )
        self.name = "driu-bn"

        resize_transform = ResizeMaxSide(crop_size)
        self.model_transforms = [
            resize_transform,
            SquareCenterPad(),
        ]

        self.pretrained = pretrained
        self.backbone = vgg16_for_segmentation(
            pretrained=self.pretrained,
            return_features=[5, 12, 19, 29],
        )
        self.head = DRIUBNHead([64, 128, 256, 512])

    def forward(self, x):
        """Normalize ``x`` if a normalizer is set, then run backbone and head.

        Returns the single tensor produced by the head.
        """
        if self.normalizer is not None:
            x = self.normalizer(x)
        x = self.backbone(x)
        return self.head(x)

    def set_normalizer(self, dataloader: torch.utils.data.DataLoader) -> None:
        """Initialize the normalizer for the current model.

        If ``pretrained = True``, the normalizer is set to the ImageNet
        preset and a warning is logged. Otherwise the normalizer is set to
        ``None``, so ``forward`` applies no input normalization.

        Parameters
        ----------
        dataloader
            A torch Dataloader from which to compute the mean and std.
            Currently unused in both branches.
        """
        if self.pretrained:
            from mednet.libs.common.models.normalizer import make_imagenet_normalizer

            logger.warning(
                f"ImageNet pre-trained {self.name} model - NOT "
                f"computing z-norm factors from train dataloader. "
                f"Using preset factors from torchvision.",
            )
            self.normalizer = make_imagenet_normalizer()
        else:
            self.normalizer = None

    def training_step(self, batch, batch_idx):
        """Return the training loss for one batch.

        The augmentation transforms are applied to the images before the
        forward pass.
        """
        images = batch[0]
        ground_truths = batch[1]["target"]
        masks = batch[1]["mask"]
        outputs = self(self._augmentation_transforms(images))
        return self._train_loss(outputs, ground_truths, masks)

    def validation_step(self, batch, batch_idx):
        """Return the validation loss for one batch (no augmentation)."""
        images = batch[0]
        ground_truths = batch[1]["target"]
        masks = batch[1]["mask"]
        outputs = self(images)
        return self._validation_loss(outputs, ground_truths, masks)

    def predict_step(self, batch, batch_idx, dataloader_idx=0):
        """Return sigmoid probabilities for the batch, split by ``separate``."""
        # ``forward`` returns one tensor for the whole batch. The former
        # ``[1]`` index selected only the second sample and raised
        # ``IndexError`` for a batch of size one.
        output = self(batch[0])
        probabilities = torch.sigmoid(output)
        return separate((probabilities, batch[1]))

    def configure_optimizers(self):
        """Build the optimizer over all model parameters."""
        return self._optimizer_type(self.parameters(), **self._optimizer_arguments)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment