diff --git a/src/mednet/libs/segmentation/config/models/driu.py b/src/mednet/libs/segmentation/config/models/driu.py index 498f6324f1eeb775742b7b9c261bbbd19ca93df7..847a7cd38d62c219f2663152a3462f08168a07b0 100644 --- a/src/mednet/libs/segmentation/config/models/driu.py +++ b/src/mednet/libs/segmentation/config/models/driu.py @@ -1,17 +1,19 @@ # SPDX-FileCopyrightText: Copyright © 2024 Idiap Research Institute <contact@idiap.ch> # # SPDX-License-Identifier: GPL-3.0-or-later -"""Little W-Net for image segmentation. -The Little W-Net architecture contains roughly around 70k parameters and -closely matches (or outperforms) other more complex techniques. +"""DRIU Network for Vessel Segmentation. -Reference: [GALDRAN-2020]_ +Deep Retinal Image Understanding (DRIU), a unified framework of retinal image +analysis that provides both retinal vessel and optic disc segmentation using +deep Convolutional Neural Networks (CNNs). + +Reference: [MANINIS-2016]_ """ from mednet.libs.segmentation.engine.adabound import AdaBound +from mednet.libs.segmentation.models.driu import DRIU from mednet.libs.segmentation.models.losses import SoftJaccardBCELogitsLoss -from mednet.libs.segmentation.models.unet import Unet lr = 0.001 alpha = 0.7 @@ -23,7 +25,7 @@ gamma = 1e-3 eps = 1e-8 amsbound = False -model = Unet( +model = DRIU( loss_type=SoftJaccardBCELogitsLoss, loss_arguments=dict(alpha=alpha), optimizer_type=AdaBound, diff --git a/src/mednet/libs/segmentation/models/driu.py b/src/mednet/libs/segmentation/models/driu.py index 74eba7fc6dcd66411f04f43e935fa54ab8bfa0de..3c32aeb81bcbb5dfd274ec1508f4a81266e5f109 100644 --- a/src/mednet/libs/segmentation/models/driu.py +++ b/src/mednet/libs/segmentation/models/driu.py @@ -10,6 +10,7 @@ import torch import torch.nn from mednet.libs.common.data.typing import TransformSequence from mednet.libs.common.models.model import Model +from mednet.libs.common.models.transforms import ResizeMaxSide, SquareCenterPad from .backbones.vgg import vgg16_for_segmentation from .losses import SoftJaccardBCELogitsLoss @@ -96,6 +97,8 @@ class DRIU(Model): Number of outputs (classes) for this model. pretrained If True, will use VGG16 pretrained weights. + crop_size + The size of the image after center cropping. Returns ------- @@ -112,6 +115,7 @@ class DRIU(Model): augmentation_transforms: TransformSequence = [], num_classes: int = 1, pretrained: bool = False, + crop_size: int = 1024, ): super().__init__( loss_type, @@ -122,7 +126,12 @@ class DRIU(Model): num_classes, ) self.name = "driu" - self.model_transforms: TransformSequence = [] + resize_transform = ResizeMaxSide(crop_size) + + self.model_transforms = [ + resize_transform, + SquareCenterPad(), + ] self.pretrained = pretrained self.backbone = vgg16_for_segmentation(