Commit 91fa4ecf authored by Daniel CARRON, committed by André Anjos

[segmentation.models] Fix DRIU config

parent f31df6a2
Merge request !46: Create common library
 # SPDX-FileCopyrightText: Copyright © 2024 Idiap Research Institute <contact@idiap.ch>
 #
 # SPDX-License-Identifier: GPL-3.0-or-later
-"""Little W-Net for image segmentation.
-
-The Little W-Net architecture contains roughly around 70k parameters and
-closely matches (or outperforms) other more complex techniques.
-
-Reference: [GALDRAN-2020]_
-"""
+"""DRIU Network for Vessel Segmentation.
+
+Deep Retinal Image Understanding (DRIU), a unified framework of retinal image
+analysis that provides both retinal vessel and optic disc segmentation using
+deep Convolutional Neural Networks (CNNs).
+
+Reference: [MANINIS-2016]_
+"""
 
 from mednet.libs.segmentation.engine.adabound import AdaBound
+from mednet.libs.segmentation.models.driu import DRIU
 from mednet.libs.segmentation.models.losses import SoftJaccardBCELogitsLoss
-from mednet.libs.segmentation.models.unet import Unet
 
 lr = 0.001
 alpha = 0.7
@@ -23,7 +25,7 @@ gamma = 1e-3
 eps = 1e-8
 amsbound = False
 
-model = Unet(
+model = DRIU(
     loss_type=SoftJaccardBCELogitsLoss,
     loss_arguments=dict(alpha=alpha),
     optimizer_type=AdaBound,
...
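For context, the configuration passes ``loss_arguments=dict(alpha=alpha)`` to ``SoftJaccardBCELogitsLoss``. Below is a minimal sketch of such a combined loss; it assumes ``alpha`` weights the soft-Jaccard term and ``1 - alpha`` the BCE term, which may not match the exact convention of mednet's implementation.

import torch
import torch.nn.functional as F


class SoftJaccardBCESketch(torch.nn.Module):
    """Illustrative (hypothetical) mix of a soft Jaccard loss and BCE-with-logits.

    Not the mednet implementation: alpha is assumed to weight the Jaccard
    term, with (1 - alpha) applied to the BCE term.
    """

    def __init__(self, alpha: float = 0.7, eps: float = 1e-8):
        super().__init__()
        self.alpha = alpha
        self.eps = eps

    def forward(self, logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        probs = torch.sigmoid(logits)
        intersection = (probs * targets).sum()
        union = probs.sum() + targets.sum() - intersection
        soft_jaccard = 1.0 - (intersection + self.eps) / (union + self.eps)
        bce = F.binary_cross_entropy_with_logits(logits, targets)
        return self.alpha * soft_jaccard + (1.0 - self.alpha) * bce

With alpha = 0.7 the overlap (Jaccard) term dominates, which tends to help with the strong class imbalance between vessel and background pixels.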
@@ -10,6 +10,7 @@ import torch
 import torch.nn
 from mednet.libs.common.data.typing import TransformSequence
 from mednet.libs.common.models.model import Model
+from mednet.libs.common.models.transforms import ResizeMaxSide, SquareCenterPad
 
 from .backbones.vgg import vgg16_for_segmentation
 from .losses import SoftJaccardBCELogitsLoss
@@ -96,6 +97,8 @@ class DRIU(Model):
         Number of outputs (classes) for this model.
     pretrained
         If True, will use VGG16 pretrained weights.
+    crop_size
+        The size of the image after center cropping.
 
     Returns
     -------
@@ -112,6 +115,7 @@ class DRIU(Model):
         augmentation_transforms: TransformSequence = [],
         num_classes: int = 1,
         pretrained: bool = False,
+        crop_size: int = 1024,
     ):
         super().__init__(
             loss_type,
@@ -122,7 +126,12 @@ class DRIU(Model):
             num_classes,
         )
         self.name = "driu"
-        self.model_transforms: TransformSequence = []
+        resize_transform = ResizeMaxSide(crop_size)
+
+        self.model_transforms = [
+            resize_transform,
+            SquareCenterPad(),
+        ]
         self.pretrained = pretrained
         self.backbone = vgg16_for_segmentation(
...
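The new ``model_transforms`` resize the image so its longest side equals ``crop_size`` and then pad it to a square. The snippet below is an illustrative sketch of that behaviour using plain torchvision calls; it is not the actual ``ResizeMaxSide``/``SquareCenterPad`` code from ``mednet.libs.common.models.transforms``, and the helper names are hypothetical.

import torch
import torchvision.transforms.functional as TF

# Illustrative sketch only -- not the mednet ResizeMaxSide / SquareCenterPad code.


def resize_max_side(image: torch.Tensor, max_side: int) -> torch.Tensor:
    # Scale so the longest side becomes max_side, preserving aspect ratio.
    _, h, w = image.shape
    scale = max_side / max(h, w)
    return TF.resize(image, [round(h * scale), round(w * scale)], antialias=True)


def square_center_pad(image: torch.Tensor) -> torch.Tensor:
    # Zero-pad the shorter dimension so the result is square and centered.
    _, h, w = image.shape
    side = max(h, w)
    left, top = (side - w) // 2, (side - h) // 2
    # torchvision padding order: [left, top, right, bottom]
    return TF.pad(image, [left, top, side - w - left, side - h - top])


# Example: a 3 x 900 x 1440 image becomes 3 x 1024 x 1024 after both steps
# (resized to 640 x 1024, then padded vertically), matching crop_size=1024.
image = torch.rand(3, 900, 1440)
out = square_center_pad(resize_max_side(image, 1024))
assert out.shape == (3, 1024, 1024)

Producing square, fixed-size inputs keeps dimensions consistent across datasets with different native resolutions without distorting the aspect ratio.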