diff --git a/src/mednet/libs/segmentation/models/driu.py b/src/mednet/libs/segmentation/models/driu.py index 3c32aeb81bcbb5dfd274ec1508f4a81266e5f109..6afab13371324d32816537469e47a691889ac1c8 100644 --- a/src/mednet/libs/segmentation/models/driu.py +++ b/src/mednet/libs/segmentation/models/driu.py @@ -15,6 +15,7 @@ from mednet.libs.common.models.transforms import ResizeMaxSide, SquareCenterPad from .backbones.vgg import vgg16_for_segmentation from .losses import SoftJaccardBCELogitsLoss from .make_layers import UpsampleCropBlock, conv_with_kaiming_uniform +from .separate import separate logger = logging.getLogger("mednet") @@ -72,8 +73,7 @@ class DRIUHead(torch.nn.Module): class DRIU(Model): - """Build DRIU for vessel segmentation by adding backbone and head - together. + """Implementation of the DRIU model. Parameters ---------- @@ -99,11 +99,6 @@ class DRIU(Model): If True, will use VGG16 pretrained weights. crop_size The size of the image after center cropping. - - Returns - ------- - module : :py:class:`torch.nn.Module` - Network model for DRIU (vessel segmentation). """ def __init__( @@ -186,3 +181,11 @@ class DRIU(Model): outputs = self(images) return self._validation_loss(outputs, ground_truths, masks) + + def predict_step(self, batch, batch_idx, dataloader_idx=0): + output = self(batch[0])[1] + probabilities = torch.sigmoid(output) + return separate((probabilities, batch[1])) + + def configure_optimizers(self): + return self._optimizer_type(self.parameters(), **self._optimizer_arguments) diff --git a/src/mednet/libs/segmentation/models/hed.py b/src/mednet/libs/segmentation/models/hed.py index bbefa8a44501219ff500a2092b833d83de55f414..7ac25fb45980b8997b76161a317c59abbba197fb 100644 --- a/src/mednet/libs/segmentation/models/hed.py +++ b/src/mednet/libs/segmentation/models/hed.py @@ -14,6 +14,7 @@ from mednet.libs.common.models.transforms import ResizeMaxSide, SquareCenterPad from .backbones.vgg import vgg16_for_segmentation from .losses import MultiSoftJaccardBCELogitsLoss from .make_layers import UpsampleCropBlock, conv_with_kaiming_uniform +from .separate import separate logger = logging.getLogger("mednet") @@ -186,3 +187,11 @@ class HED(Model): outputs = self(images) return self._validation_loss(outputs, ground_truths, masks) + + def predict_step(self, batch, batch_idx, dataloader_idx=0): + output = self(batch[0])[1] + probabilities = torch.sigmoid(output) + return separate((probabilities, batch[1])) + + def configure_optimizers(self): + return self._optimizer_type(self.parameters(), **self._optimizer_arguments) diff --git a/src/mednet/libs/segmentation/models/unet.py b/src/mednet/libs/segmentation/models/unet.py index 23ea8878fb58f5299263d8b60dc1a78f1cdae2ac..cc478a3c5f7d19709a8f849db1b37db520ac8456 100644 --- a/src/mednet/libs/segmentation/models/unet.py +++ b/src/mednet/libs/segmentation/models/unet.py @@ -14,6 +14,7 @@ from mednet.libs.common.models.transforms import ResizeMaxSide, SquareCenterPad from .backbones.vgg import vgg16_for_segmentation from .losses import SoftJaccardBCELogitsLoss from .make_layers import UnetBlock, conv_with_kaiming_uniform +from .separate import separate logger = logging.getLogger("mednet") @@ -175,3 +176,11 @@ class Unet(Model): outputs = self(images) return self._validation_loss(outputs, ground_truths, masks) + + def predict_step(self, batch, batch_idx, dataloader_idx=0): + output = self(batch[0])[1] + probabilities = torch.sigmoid(output) + return separate((probabilities, batch[1])) + + def configure_optimizers(self): + return self._optimizer_type(self.parameters(), **self._optimizer_arguments)