Skip to content
Snippets Groups Projects
Commit b0688635 authored by Daniel CARRON's avatar Daniel CARRON :b: Committed by André Anjos
Browse files

[segmetnation.models] Add missing functions

parent 8e7ce463
No related branches found
No related tags found
1 merge request!46Create common library
......@@ -15,6 +15,7 @@ from mednet.libs.common.models.transforms import ResizeMaxSide, SquareCenterPad
from .backbones.vgg import vgg16_for_segmentation
from .losses import SoftJaccardBCELogitsLoss
from .make_layers import UpsampleCropBlock, conv_with_kaiming_uniform
from .separate import separate
logger = logging.getLogger("mednet")
......@@ -72,8 +73,7 @@ class DRIUHead(torch.nn.Module):
class DRIU(Model):
"""Build DRIU for vessel segmentation by adding backbone and head
together.
"""Implementation of the DRIU model.
Parameters
----------
......@@ -99,11 +99,6 @@ class DRIU(Model):
If True, will use VGG16 pretrained weights.
crop_size
The size of the image after center cropping.
Returns
-------
module : :py:class:`torch.nn.Module`
Network model for DRIU (vessel segmentation).
"""
def __init__(
......@@ -186,3 +181,11 @@ class DRIU(Model):
outputs = self(images)
return self._validation_loss(outputs, ground_truths, masks)
def predict_step(self, batch, batch_idx, dataloader_idx=0):
output = self(batch[0])[1]
probabilities = torch.sigmoid(output)
return separate((probabilities, batch[1]))
def configure_optimizers(self):
return self._optimizer_type(self.parameters(), **self._optimizer_arguments)
......@@ -14,6 +14,7 @@ from mednet.libs.common.models.transforms import ResizeMaxSide, SquareCenterPad
from .backbones.vgg import vgg16_for_segmentation
from .losses import MultiSoftJaccardBCELogitsLoss
from .make_layers import UpsampleCropBlock, conv_with_kaiming_uniform
from .separate import separate
logger = logging.getLogger("mednet")
......@@ -186,3 +187,11 @@ class HED(Model):
outputs = self(images)
return self._validation_loss(outputs, ground_truths, masks)
def predict_step(self, batch, batch_idx, dataloader_idx=0):
output = self(batch[0])[1]
probabilities = torch.sigmoid(output)
return separate((probabilities, batch[1]))
def configure_optimizers(self):
return self._optimizer_type(self.parameters(), **self._optimizer_arguments)
......@@ -14,6 +14,7 @@ from mednet.libs.common.models.transforms import ResizeMaxSide, SquareCenterPad
from .backbones.vgg import vgg16_for_segmentation
from .losses import SoftJaccardBCELogitsLoss
from .make_layers import UnetBlock, conv_with_kaiming_uniform
from .separate import separate
logger = logging.getLogger("mednet")
......@@ -175,3 +176,11 @@ class Unet(Model):
outputs = self(images)
return self._validation_loss(outputs, ground_truths, masks)
def predict_step(self, batch, batch_idx, dataloader_idx=0):
output = self(batch[0])[1]
probabilities = torch.sigmoid(output)
return separate((probabilities, batch[1]))
def configure_optimizers(self):
return self._optimizer_type(self.parameters(), **self._optimizer_arguments)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment