diff --git a/src/mednet/libs/segmentation/models/driu.py b/src/mednet/libs/segmentation/models/driu.py
index 3c32aeb81bcbb5dfd274ec1508f4a81266e5f109..6afab13371324d32816537469e47a691889ac1c8 100644
--- a/src/mednet/libs/segmentation/models/driu.py
+++ b/src/mednet/libs/segmentation/models/driu.py
@@ -15,6 +15,7 @@ from mednet.libs.common.models.transforms import ResizeMaxSide, SquareCenterPad
 from .backbones.vgg import vgg16_for_segmentation
 from .losses import SoftJaccardBCELogitsLoss
 from .make_layers import UpsampleCropBlock, conv_with_kaiming_uniform
+from .separate import separate
 
 logger = logging.getLogger("mednet")
 
@@ -72,8 +73,7 @@ class DRIUHead(torch.nn.Module):
 
 
 class DRIU(Model):
-    """Build DRIU for vessel segmentation by adding backbone and head
-    together.
+    """Implementation of the DRIU model.
 
     Parameters
     ----------
@@ -99,11 +99,6 @@ class DRIU(Model):
         If True, will use VGG16 pretrained weights.
     crop_size
         The size of the image after center cropping.
-
-    Returns
-    -------
-    module : :py:class:`torch.nn.Module`
-        Network model for DRIU (vessel segmentation).
     """
 
     def __init__(
@@ -186,3 +181,11 @@ class DRIU(Model):
 
         outputs = self(images)
         return self._validation_loss(outputs, ground_truths, masks)
+
+    def predict_step(self, batch, batch_idx, dataloader_idx=0):
+        output = self(batch[0])[1]
+        probabilities = torch.sigmoid(output)
+        return separate((probabilities, batch[1]))
+
+    def configure_optimizers(self):
+        return self._optimizer_type(self.parameters(), **self._optimizer_arguments)
diff --git a/src/mednet/libs/segmentation/models/hed.py b/src/mednet/libs/segmentation/models/hed.py
index bbefa8a44501219ff500a2092b833d83de55f414..7ac25fb45980b8997b76161a317c59abbba197fb 100644
--- a/src/mednet/libs/segmentation/models/hed.py
+++ b/src/mednet/libs/segmentation/models/hed.py
@@ -14,6 +14,7 @@ from mednet.libs.common.models.transforms import ResizeMaxSide, SquareCenterPad
 from .backbones.vgg import vgg16_for_segmentation
 from .losses import MultiSoftJaccardBCELogitsLoss
 from .make_layers import UpsampleCropBlock, conv_with_kaiming_uniform
+from .separate import separate
 
 logger = logging.getLogger("mednet")
 
@@ -186,3 +187,11 @@ class HED(Model):
 
         outputs = self(images)
         return self._validation_loss(outputs, ground_truths, masks)
+
+    def predict_step(self, batch, batch_idx, dataloader_idx=0):
+        output = self(batch[0])[1]
+        probabilities = torch.sigmoid(output)
+        return separate((probabilities, batch[1]))
+
+    def configure_optimizers(self):
+        return self._optimizer_type(self.parameters(), **self._optimizer_arguments)
diff --git a/src/mednet/libs/segmentation/models/unet.py b/src/mednet/libs/segmentation/models/unet.py
index 23ea8878fb58f5299263d8b60dc1a78f1cdae2ac..cc478a3c5f7d19709a8f849db1b37db520ac8456 100644
--- a/src/mednet/libs/segmentation/models/unet.py
+++ b/src/mednet/libs/segmentation/models/unet.py
@@ -14,6 +14,7 @@ from mednet.libs.common.models.transforms import ResizeMaxSide, SquareCenterPad
 from .backbones.vgg import vgg16_for_segmentation
 from .losses import SoftJaccardBCELogitsLoss
 from .make_layers import UnetBlock, conv_with_kaiming_uniform
+from .separate import separate
 
 logger = logging.getLogger("mednet")
 
@@ -175,3 +176,11 @@ class Unet(Model):
 
         outputs = self(images)
         return self._validation_loss(outputs, ground_truths, masks)
+
+    def predict_step(self, batch, batch_idx, dataloader_idx=0):
+        output = self(batch[0])[1]
+        probabilities = torch.sigmoid(output)
+        return separate((probabilities, batch[1]))
+
+    def configure_optimizers(self):
+        return self._optimizer_type(self.parameters(), **self._optimizer_arguments)