From b06886359c9b4ed4d451f56471c92a148530797c Mon Sep 17 00:00:00 2001 From: dcarron <daniel.carron@idiap.ch> Date: Tue, 11 Jun 2024 12:24:51 +0200 Subject: [PATCH] [segmetnation.models] Add missing functions --- src/mednet/libs/segmentation/models/driu.py | 17 ++++++++++------- src/mednet/libs/segmentation/models/hed.py | 9 +++++++++ src/mednet/libs/segmentation/models/unet.py | 9 +++++++++ 3 files changed, 28 insertions(+), 7 deletions(-) diff --git a/src/mednet/libs/segmentation/models/driu.py b/src/mednet/libs/segmentation/models/driu.py index 3c32aeb8..6afab133 100644 --- a/src/mednet/libs/segmentation/models/driu.py +++ b/src/mednet/libs/segmentation/models/driu.py @@ -15,6 +15,7 @@ from mednet.libs.common.models.transforms import ResizeMaxSide, SquareCenterPad from .backbones.vgg import vgg16_for_segmentation from .losses import SoftJaccardBCELogitsLoss from .make_layers import UpsampleCropBlock, conv_with_kaiming_uniform +from .separate import separate logger = logging.getLogger("mednet") @@ -72,8 +73,7 @@ class DRIUHead(torch.nn.Module): class DRIU(Model): - """Build DRIU for vessel segmentation by adding backbone and head - together. + """Implementation of the DRIU model. Parameters ---------- @@ -99,11 +99,6 @@ class DRIU(Model): If True, will use VGG16 pretrained weights. crop_size The size of the image after center cropping. - - Returns - ------- - module : :py:class:`torch.nn.Module` - Network model for DRIU (vessel segmentation). """ def __init__( @@ -186,3 +181,11 @@ class DRIU(Model): outputs = self(images) return self._validation_loss(outputs, ground_truths, masks) + + def predict_step(self, batch, batch_idx, dataloader_idx=0): + output = self(batch[0])[1] + probabilities = torch.sigmoid(output) + return separate((probabilities, batch[1])) + + def configure_optimizers(self): + return self._optimizer_type(self.parameters(), **self._optimizer_arguments) diff --git a/src/mednet/libs/segmentation/models/hed.py b/src/mednet/libs/segmentation/models/hed.py index bbefa8a4..7ac25fb4 100644 --- a/src/mednet/libs/segmentation/models/hed.py +++ b/src/mednet/libs/segmentation/models/hed.py @@ -14,6 +14,7 @@ from mednet.libs.common.models.transforms import ResizeMaxSide, SquareCenterPad from .backbones.vgg import vgg16_for_segmentation from .losses import MultiSoftJaccardBCELogitsLoss from .make_layers import UpsampleCropBlock, conv_with_kaiming_uniform +from .separate import separate logger = logging.getLogger("mednet") @@ -186,3 +187,11 @@ class HED(Model): outputs = self(images) return self._validation_loss(outputs, ground_truths, masks) + + def predict_step(self, batch, batch_idx, dataloader_idx=0): + output = self(batch[0])[1] + probabilities = torch.sigmoid(output) + return separate((probabilities, batch[1])) + + def configure_optimizers(self): + return self._optimizer_type(self.parameters(), **self._optimizer_arguments) diff --git a/src/mednet/libs/segmentation/models/unet.py b/src/mednet/libs/segmentation/models/unet.py index 23ea8878..cc478a3c 100644 --- a/src/mednet/libs/segmentation/models/unet.py +++ b/src/mednet/libs/segmentation/models/unet.py @@ -14,6 +14,7 @@ from mednet.libs.common.models.transforms import ResizeMaxSide, SquareCenterPad from .backbones.vgg import vgg16_for_segmentation from .losses import SoftJaccardBCELogitsLoss from .make_layers import UnetBlock, conv_with_kaiming_uniform +from .separate import separate logger = logging.getLogger("mednet") @@ -175,3 +176,11 @@ class Unet(Model): outputs = self(images) return self._validation_loss(outputs, ground_truths, masks) + + def predict_step(self, batch, batch_idx, dataloader_idx=0): + output = self(batch[0])[1] + probabilities = torch.sigmoid(output) + return separate((probabilities, batch[1])) + + def configure_optimizers(self): + return self._optimizer_type(self.parameters(), **self._optimizer_arguments) -- GitLab