diff --git a/src/mednet/libs/segmentation/config/models/driu.py b/src/mednet/libs/segmentation/config/models/driu.py
index 498f6324f1eeb775742b7b9c261bbbd19ca93df7..847a7cd38d62c219f2663152a3462f08168a07b0 100644
--- a/src/mednet/libs/segmentation/config/models/driu.py
+++ b/src/mednet/libs/segmentation/config/models/driu.py
@@ -1,17 +1,19 @@
 # SPDX-FileCopyrightText: Copyright © 2024 Idiap Research Institute <contact@idiap.ch>
 #
 # SPDX-License-Identifier: GPL-3.0-or-later
-"""Little W-Net for image segmentation.
 
-The Little W-Net architecture contains roughly around 70k parameters and
-closely matches (or outperforms) other more complex techniques.
+"""DRIU Network for Vessel Segmentation.
 
-Reference: [GALDRAN-2020]_
+Deep Retinal Image Understanding (DRIU), a unified framework of retinal image
+analysis that provides both retinal vessel and optic disc segmentation using
+deep Convolutional Neural Networks (CNNs).
+
+Reference: [MANINIS-2016]_
 """
 
 from mednet.libs.segmentation.engine.adabound import AdaBound
+from mednet.libs.segmentation.models.driu import DRIU
 from mednet.libs.segmentation.models.losses import SoftJaccardBCELogitsLoss
-from mednet.libs.segmentation.models.unet import Unet
 
 lr = 0.001
 alpha = 0.7
@@ -23,7 +25,7 @@ gamma = 1e-3
 eps = 1e-8
 amsbound = False
 
-model = Unet(
+model = DRIU(
     loss_type=SoftJaccardBCELogitsLoss,
     loss_arguments=dict(alpha=alpha),
     optimizer_type=AdaBound,
diff --git a/src/mednet/libs/segmentation/models/driu.py b/src/mednet/libs/segmentation/models/driu.py
index 74eba7fc6dcd66411f04f43e935fa54ab8bfa0de..3c32aeb81bcbb5dfd274ec1508f4a81266e5f109 100644
--- a/src/mednet/libs/segmentation/models/driu.py
+++ b/src/mednet/libs/segmentation/models/driu.py
@@ -10,6 +10,7 @@ import torch
 import torch.nn
 from mednet.libs.common.data.typing import TransformSequence
 from mednet.libs.common.models.model import Model
+from mednet.libs.common.models.transforms import ResizeMaxSide, SquareCenterPad
 
 from .backbones.vgg import vgg16_for_segmentation
 from .losses import SoftJaccardBCELogitsLoss
@@ -96,6 +97,8 @@ class DRIU(Model):
         Number of outputs (classes) for this model.
     pretrained
         If True, will use VGG16 pretrained weights.
+    crop_size
+        Maximum side length to which the image is resized; the result is then center-padded to a square.
 
     Returns
     -------
@@ -112,6 +115,7 @@ class DRIU(Model):
         augmentation_transforms: TransformSequence = [],
         num_classes: int = 1,
         pretrained: bool = False,
+        crop_size: int = 1024,
     ):
         super().__init__(
             loss_type,
@@ -122,7 +126,12 @@ class DRIU(Model):
             num_classes,
         )
         self.name = "driu"
-        self.model_transforms: TransformSequence = []
+        resize_transform = ResizeMaxSide(crop_size)
+
+        self.model_transforms = [
+            resize_transform,
+            SquareCenterPad(),
+        ]
         self.pretrained = pretrained
 
         self.backbone = vgg16_for_segmentation(