diff --git a/pyproject.toml b/pyproject.toml
index 3a548d8ab87f59c4e1a2e6cec6983b1d222102e7..0f1681ca11af5d77db441936b94570df5197f4fc 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -421,6 +421,8 @@ visceral = "mednet.config.data.visceral.default"
 
 [project.entry-points."mednet.libs.segmentation.config"]
 
+# models
+driu = "mednet.libs.segmentation.config.models.driu"
 lwnet = "mednet.libs.segmentation.config.models.lwnet"
 unet = "mednet.libs.segmentation.config.models.unet"
 
diff --git a/src/mednet/libs/segmentation/config/models/driu.py b/src/mednet/libs/segmentation/config/models/driu.py
new file mode 100644
index 0000000000000000000000000000000000000000..498f6324f1eeb775742b7b9c261bbbd19ca93df7
--- /dev/null
+++ b/src/mednet/libs/segmentation/config/models/driu.py
@@ -0,0 +1,41 @@
+# SPDX-FileCopyrightText: Copyright © 2024 Idiap Research Institute <contact@idiap.ch>
+#
+# SPDX-License-Identifier: GPL-3.0-or-later
+"""Little W-Net for image segmentation.
+
+The Little W-Net architecture contains roughly around 70k parameters and
+closely matches (or outperforms) other more complex techniques.
+
+Reference: [GALDRAN-2020]_
+"""
+
+from mednet.libs.segmentation.engine.adabound import AdaBound
+from mednet.libs.segmentation.models.driu import DRIU
+from mednet.libs.segmentation.models.losses import SoftJaccardBCELogitsLoss
+
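+# hyper-parameters: alpha weighs the soft-Jaccard vs. BCE terms of the loss;
+# the remaining values configure the AdaBound optimizer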
+lr = 0.001
+alpha = 0.7
+betas = (0.9, 0.999)
+eps = 1e-08
+weight_decay = 0
+final_lr = 0.1
+gamma = 1e-3
+amsbound = False
+
+model = DRIU(
+    loss_type=SoftJaccardBCELogitsLoss,
+    loss_arguments=dict(alpha=alpha),
+    optimizer_type=AdaBound,
+    optimizer_arguments=dict(
+        lr=lr,
+        betas=betas,
+        final_lr=final_lr,
+        gamma=gamma,
+        eps=eps,
+        weight_decay=weight_decay,
+        amsbound=amsbound,
+    ),
+    augmentation_transforms=[],
+)
diff --git a/src/mednet/libs/segmentation/config/models/unet.py b/src/mednet/libs/segmentation/config/models/unet.py
index 9336d2ef1e93549a1e468b460998a7eb49aec90e..498f6324f1eeb775742b7b9c261bbbd19ca93df7 100644
--- a/src/mednet/libs/segmentation/config/models/unet.py
+++ b/src/mednet/libs/segmentation/config/models/unet.py
@@ -13,7 +13,7 @@ from mednet.libs.segmentation.engine.adabound import AdaBound
 from mednet.libs.segmentation.models.losses import SoftJaccardBCELogitsLoss
 from mednet.libs.segmentation.models.unet import Unet
 
-lr = 0.01  # start
+lr = 0.001
 alpha = 0.7
 betas = (0.9, 0.999)
 eps = 1e-08
diff --git a/src/mednet/libs/segmentation/models/driu.py b/src/mednet/libs/segmentation/models/driu.py
new file mode 100644
index 0000000000000000000000000000000000000000..74eba7fc6dcd66411f04f43e935fa54ab8bfa0de
--- /dev/null
+++ b/src/mednet/libs/segmentation/models/driu.py
@@ -0,0 +1,179 @@
+# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
+#
+# SPDX-License-Identifier: GPL-3.0-or-later
+
+
+import logging
+import typing
+
+import torch
+import torch.nn
+from mednet.libs.common.data.typing import TransformSequence
+from mednet.libs.common.models.model import Model
+
+from .backbones.vgg import vgg16_for_segmentation
+from .losses import SoftJaccardBCELogitsLoss
+from .make_layers import UpsampleCropBlock, conv_with_kaiming_uniform
+
+logger = logging.getLogger("mednet")
+
+
+class ConcatFuseBlock(torch.nn.Module):
+    """Takes in four feature maps with 16 channels each, concatenates them and
+    applies a 1x1 convolution with 1 output channel.
+    """
+
+    def __init__(self):
+        super().__init__()
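+        # fuse the 4 x 16 = 64 concatenated channels into a single output map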
+        self.conv = conv_with_kaiming_uniform(4 * 16, 1, 1, 1, 0)
+
+    def forward(self, x1, x2, x3, x4):
+        x_cat = torch.cat([x1, x2, x3, x4], dim=1)
+        return self.conv(x_cat)
+
+
+class DRIUHead(torch.nn.Module):
+    """DRIU head module.
+
+    Based on paper by [MANINIS-2016]_.
+
+    Parameters
+    ----------
+    in_channels_list
+        Number of channels of each feature map returned by the backbone.
+    """
+
+    def __init__(self, in_channels_list=None):
+        super().__init__()
+        (
+            in_conv_1_2_16,
+            in_upsample2,
+            in_upsample_4,
+            in_upsample_8,
+        ) = in_channels_list
+
+        self.conv1_2_16 = torch.nn.Conv2d(in_conv_1_2_16, 16, 3, 1, 1)
+        # Upsample layers
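+        # each block upsamples a deeper side output (by 2x, 4x and 8x,
+        # respectively) and crops it to the resolution of the input image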
+        self.upsample2 = UpsampleCropBlock(in_upsample2, 16, 4, 2, 0)
+        self.upsample4 = UpsampleCropBlock(in_upsample_4, 16, 8, 4, 0)
+        self.upsample8 = UpsampleCropBlock(in_upsample_8, 16, 16, 8, 0)
+
+        # Concat and Fuse
+        self.concatfuse = ConcatFuseBlock()
+
+    def forward(self, x):
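+        # x[0] holds the input spatial dimensions used for cropping; the
+        # remaining entries are the backbone feature maps, shallow to deep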
+        hw = x[0]
+        conv1_2_16 = self.conv1_2_16(x[1])  # conv1_2_16
+        upsample2 = self.upsample2(x[2], hw)  # side-multi2-up
+        upsample4 = self.upsample4(x[3], hw)  # side-multi3-up
+        upsample8 = self.upsample8(x[4], hw)  # side-multi4-up
+        return self.concatfuse(conv1_2_16, upsample2, upsample4, upsample8)
+
+
+class DRIU(Model):
+    """Build DRIU for vessel segmentation by adding backbone and head
+    together.
+
+    Parameters
+    ----------
+    loss_type
+        The loss to be used for training and evaluation.
+
+        .. warning::
+
+           The loss should be set to always return batch averages (as opposed
+           to the batch sum), as this is what our logging system expects.
+    loss_arguments
+        Arguments to the loss.
+    optimizer_type
+        The type of optimizer to use for training.
+    optimizer_arguments
+        Arguments to the optimizer after ``params``.
+    augmentation_transforms
+        An optional sequence of torch modules containing transforms to be
+        applied on the input **before** it is fed into the network.
+    num_classes
+        Number of outputs (classes) for this model.
+    pretrained
+        If True, will use VGG16 pretrained weights.
+    """
+
+    def __init__(
+        self,
+        loss_type: torch.nn.Module = SoftJaccardBCELogitsLoss,
+        loss_arguments: dict[str, typing.Any] = {},
+        optimizer_type: type[torch.optim.Optimizer] = torch.optim.Adam,
+        optimizer_arguments: dict[str, typing.Any] = {},
+        augmentation_transforms: TransformSequence = [],
+        num_classes: int = 1,
+        pretrained: bool = False,
+    ):
+        super().__init__(
+            loss_type,
+            loss_arguments,
+            optimizer_type,
+            optimizer_arguments,
+            augmentation_transforms,
+            num_classes,
+        )
+        self.name = "driu"
+        self.model_transforms: TransformSequence = []
+        self.pretrained = pretrained
+
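+        # the selected VGG16 features yield maps with 64, 128, 256 and 512
+        # channels, matching the in_channels_list passed to the head below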
+        self.backbone = vgg16_for_segmentation(
+            pretrained=self.pretrained,
+            return_features=[3, 8, 14, 22],
+        )
+
+        self.head = DRIUHead([64, 128, 256, 512])
+
+    def forward(self, x):
+        if self.normalizer is not None:
+            x = self.normalizer(x)
+        x = self.backbone(x)
+        return self.head(x)
+
+    def set_normalizer(self, dataloader: torch.utils.data.DataLoader) -> None:
+        """Initialize the normalizer for the current model.
+
+        If ``pretrained = True``, the normalizer is set to the ImageNet
+        presets from torchvision; otherwise, no input normalization is
+        applied.
+
+        Parameters
+        ----------
+        dataloader
+            A torch Dataloader from which mean and std statistics could be
+            computed. Unused by this model.
+        """
+        if self.pretrained:
+            from mednet.libs.common.models.normalizer import make_imagenet_normalizer
+
+            logger.warning(
+                f"ImageNet pre-trained {self.name} model - NOT "
+                f"computing z-norm factors from train dataloader. "
+                f"Using preset factors from torchvision.",
+            )
+            self.normalizer = make_imagenet_normalizer()
+        else:
+            self.normalizer = None
+
+    def training_step(self, batch, batch_idx):
+        images = batch[0]
+        ground_truths = batch[1]["target"]
+        masks = batch[1]["mask"]
+
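+        # augmentations are applied at training time only (compare with
+        # validation_step below)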
+        outputs = self(self._augmentation_transforms(images))
+        return self._train_loss(outputs, ground_truths, masks)
+
+    def validation_step(self, batch, batch_idx):
+        images = batch[0]
+        ground_truths = batch[1]["target"]
+        masks = batch[1]["mask"]
+
+        outputs = self(images)
+        return self._validation_loss(outputs, ground_truths, masks)