Skip to content
Snippets Groups Projects
Commit 91fa4ecf authored by Daniel CARRON's avatar Daniel CARRON :b: Committed by André Anjos
Browse files

[segmentation.models] Fix DRIU config

parent f31df6a2
No related branches found
No related tags found
1 merge request!46Create common library
# SPDX-FileCopyrightText: Copyright © 2024 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later
"""Little W-Net for image segmentation.
The Little W-Net architecture contains roughly around 70k parameters and
closely matches (or outperforms) other more complex techniques.
"""DRIU Network for Vessel Segmentation.
Reference: [GALDRAN-2020]_
Deep Retinal Image Understanding (DRIU), a unified framework of retinal image
analysis that provides both retinal vessel and optic disc segmentation using
deep Convolutional Neural Networks (CNNs).
Reference: [MANINIS-2016]_
"""
from mednet.libs.segmentation.engine.adabound import AdaBound
from mednet.libs.segmentation.models.driu import DRIU
from mednet.libs.segmentation.models.losses import SoftJaccardBCELogitsLoss
from mednet.libs.segmentation.models.unet import Unet
lr = 0.001
alpha = 0.7
......@@ -23,7 +25,7 @@ gamma = 1e-3
eps = 1e-8
amsbound = False
model = Unet(
model = DRIU(
loss_type=SoftJaccardBCELogitsLoss,
loss_arguments=dict(alpha=alpha),
optimizer_type=AdaBound,
......
......@@ -10,6 +10,7 @@ import torch
import torch.nn
from mednet.libs.common.data.typing import TransformSequence
from mednet.libs.common.models.model import Model
from mednet.libs.common.models.transforms import ResizeMaxSide, SquareCenterPad
from .backbones.vgg import vgg16_for_segmentation
from .losses import SoftJaccardBCELogitsLoss
......@@ -96,6 +97,8 @@ class DRIU(Model):
Number of outputs (classes) for this model.
pretrained
If True, will use VGG16 pretrained weights.
crop_size
The size of the image after center cropping.
Returns
-------
......@@ -112,6 +115,7 @@ class DRIU(Model):
augmentation_transforms: TransformSequence = [],
num_classes: int = 1,
pretrained: bool = False,
crop_size: int = 1024,
):
super().__init__(
loss_type,
......@@ -122,7 +126,12 @@ class DRIU(Model):
num_classes,
)
self.name = "driu"
self.model_transforms: TransformSequence = []
resize_transform = ResizeMaxSide(crop_size)
self.model_transforms = [
resize_transform,
SquareCenterPad(),
]
self.pretrained = pretrained
self.backbone = vgg16_for_segmentation(
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment