diff --git a/pyproject.toml b/pyproject.toml
index 11c85b858868c9e8f0f2d27e67b607a4b2b7bebb..3a548d8ab87f59c4e1a2e6cec6983b1d222102e7 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -422,6 +422,7 @@ visceral = "mednet.config.data.visceral.default"
 [project.entry-points."mednet.libs.segmentation.config"]
 
 lwnet = "mednet.libs.segmentation.config.models.lwnet"
+unet = "mednet.libs.segmentation.config.models.unet"
 
 # chase-db1 - retinography
 chasedb1 = "mednet.libs.segmentation.config.data.chasedb1.first_annotator"
@@ -546,6 +547,7 @@ exclude = [ # don't report on objects that match any of these regex
   '\.__len__$',
   '\.__getitem__$',
   '\.__iter__$',
+  '\.__setstate__$',
   '\.__exit__$',
 ]
 
diff --git a/src/mednet/libs/segmentation/config/models/unet.py b/src/mednet/libs/segmentation/config/models/unet.py
new file mode 100644
index 0000000000000000000000000000000000000000..9336d2ef1e93549a1e468b460998a7eb49aec90e
--- /dev/null
+++ b/src/mednet/libs/segmentation/config/models/unet.py
@@ -0,0 +1,41 @@
+# SPDX-FileCopyrightText: Copyright © 2024 Idiap Research Institute <contact@idiap.ch>
+#
+# SPDX-License-Identifier: GPL-3.0-or-later
+"""Little W-Net for image segmentation.
+
+The Little W-Net architecture contains roughly around 70k parameters and
+closely matches (or outperforms) other more complex techniques.
+
+Reference: [GALDRAN-2020]_
+"""
+
+from mednet.libs.segmentation.engine.adabound import AdaBound
+from mednet.libs.segmentation.models.losses import SoftJaccardBCELogitsLoss
+from mednet.libs.segmentation.models.unet import Unet
+
+lr = 0.01  # initial (Adam-phase) learning rate
+alpha = 0.7
+betas = (0.9, 0.999)
+eps = 1e-08
+weight_decay = 0
+final_lr = 0.1
+gamma = 1e-3
+amsbound = False
+
+model = Unet(
+    loss_type=SoftJaccardBCELogitsLoss,
+    loss_arguments=dict(alpha=alpha),
+    optimizer_type=AdaBound,
+    optimizer_arguments=dict(
+        lr=lr,
+        betas=betas,
+        final_lr=final_lr,
+        gamma=gamma,
+        eps=eps,
+        weight_decay=weight_decay,
+        amsbound=amsbound,
+    ),
+    augmentation_transforms=[],
+    crop_size=1024,
+)
diff --git a/src/mednet/libs/segmentation/engine/adabound.py b/src/mednet/libs/segmentation/engine/adabound.py
new file mode 100644
index 0000000000000000000000000000000000000000..19f6261e1218fb3cca73555403673d38cea7d2a5
--- /dev/null
+++ b/src/mednet/libs/segmentation/engine/adabound.py
@@ -0,0 +1,314 @@
+# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
+#
+# SPDX-License-Identifier: GPL-3.0-or-later
+
+"""Implementation of the AdaBound optimizer.
+
+<https://github.com/Luolc/AdaBound/blob/master/adabound/adabound.py>::
+
+    @inproceedings{Luo2019AdaBound,
+      author = {Luo, Liangchen and Xiong, Yuanhao and Liu, Yan and Sun, Xu},
+      title = {Adaptive Gradient Methods with Dynamic Bound of Learning Rate},
+      booktitle = {Proceedings of the 7th International Conference on Learning Representations},
+      month = {May},
+      year = {2019},
+      address = {New Orleans, Louisiana}
+    }
+"""
+
+import math
+import typing
+
+import torch
+import torch.optim
+
+
+class AdaBound(torch.optim.Optimizer):
+    """Implement the AdaBound algorithm.
+
+    Parameters
+    ----------
+    params
+        Iterable of parameters to optimize or dicts defining parameter groups.
+    lr
+        Adam learning rate.
+    betas
+        Coefficients (as a 2-tuple of floats) used for computing running
+        averages of gradient and its square.
+    final_lr
+        Final (SGD) learning rate.
+    gamma
+        Convergence speed of the bound functions.
+    eps
+        Term added to the denominator to improve numerical stability.
+    weight_decay
+        Weight decay (L2 penalty).
+    amsbound
+        Whether to use the AMSBound variant of this algorithm.
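+
+    Examples
+    --------
+    A minimal usage sketch (the linear module below is an illustrative
+    placeholder, not part of this package):
+
+    .. code-block:: python
+
+        import torch
+
+        module = torch.nn.Linear(10, 1)
+        optimizer = AdaBound(list(module.parameters()), lr=1e-3, final_lr=0.1)
+
+        loss = module(torch.randn(4, 10)).sum()
+        loss.backward()
+        optimizer.step()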
+    """
+
+    def __init__(
+        self,
+        params: list,
+        lr: float = 1e-3,
+        betas: tuple[float, float] = (0.9, 0.999),
+        final_lr: float = 0.1,
+        gamma: float = 1e-3,
+        eps: float = 1e-8,
+        weight_decay: float = 0.0,
+        amsbound: bool = False,
+    ):
+        if not 0.0 <= lr:
+            raise ValueError(f"Invalid learning rate: {lr}")
+        if not 0.0 <= eps:
+            raise ValueError(f"Invalid epsilon value: {eps}")
+        if not 0.0 <= betas[0] < 1.0:
+            raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}")
+        if not 0.0 <= betas[1] < 1.0:
+            raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}")
+        if not 0.0 <= final_lr:
+            raise ValueError(f"Invalid final learning rate: {final_lr}")
+        if not 0.0 <= gamma < 1.0:
+            raise ValueError(f"Invalid gamma parameter: {gamma}")
+        defaults = dict(
+            lr=lr,
+            betas=betas,
+            final_lr=final_lr,
+            gamma=gamma,
+            eps=eps,
+            weight_decay=weight_decay,
+            amsbound=amsbound,
+        )
+        super().__init__(params, defaults)
+
+        self.base_lrs = list(map(lambda group: group["lr"], self.param_groups))
+
+    def __setstate__(self, state):
+        super().__setstate__(state)
+        for group in self.param_groups:
+            group.setdefault("amsbound", False)
+
+    def step(self, closure: typing.Callable | None = None):
+        """Perform a single optimization step.
+
+        Parameters
+        ----------
+        closure
+            A closure that reevaluates the model and returns the loss.
+
+        Returns
+        -------
+            The loss.
+        """
+        loss = None
+        if closure is not None:
+            loss = closure()
+
+        for group, base_lr in zip(self.param_groups, self.base_lrs):
+            for p in group["params"]:
+                if p.grad is None:
+                    continue
+                grad = p.grad.data
+                if grad.is_sparse:
+                    raise RuntimeError(
+                        "Adam does not support sparse gradients, please consider SparseAdam instead"
+                    )
+                amsbound = group["amsbound"]
+
+                state = self.state[p]
+
+                # State initialization
+                if len(state) == 0:
+                    state["step"] = 0
+                    # Exponential moving average of gradient values
+                    state["exp_avg"] = torch.zeros_like(p.data)
+                    # Exponential moving average of squared gradient values
+                    state["exp_avg_sq"] = torch.zeros_like(p.data)
+                    if amsbound:
+                        # Maintains max of all exp. moving avg. of sq. grad. values
+                        state["max_exp_avg_sq"] = torch.zeros_like(p.data)
+
+                exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
+                if amsbound:
+                    max_exp_avg_sq = state["max_exp_avg_sq"]
+                beta1, beta2 = group["betas"]
+
+                state["step"] += 1
+
+                if group["weight_decay"] != 0:
+                    grad = grad.add(p.data, alpha=group["weight_decay"])
+
+                # Decay the first and second moment running average coefficient
+                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
+                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
+                if amsbound:
+                    # Maintains the maximum of all 2nd moment running avg. till now
+                    torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
+                    # Use the max. for normalizing running avg. of gradient
+                    denom = max_exp_avg_sq.sqrt().add_(group["eps"])
+                else:
+                    denom = exp_avg_sq.sqrt().add_(group["eps"])
+
+                bias_correction1 = 1 - beta1 ** state["step"]
+                bias_correction2 = 1 - beta2 ** state["step"]
+                step_size = group["lr"] * math.sqrt(bias_correction2) / bias_correction1
+
+                # Applies bounds on actual learning rate
+                # lr_scheduler cannot affect final_lr, this is a workaround to apply lr decay
+                final_lr = group["final_lr"] * group["lr"] / base_lr
+                lower_bound = final_lr * (1 - 1 / (group["gamma"] * state["step"] + 1))
+                upper_bound = final_lr * (1 + 1 / (group["gamma"] * state["step"]))
+                step_size = torch.full_like(denom, step_size)
+                step_size.div_(denom).clamp_(lower_bound, upper_bound).mul_(exp_avg)
+
+                p.data.add_(-step_size)
+
+        return loss
+
+
+class AdaBoundW(torch.optim.Optimizer):
+    """Implement AdaBound algorithm with Decoupled Weight Decay (See
+    https://arxiv.org/abs/1711.05101).
+
+    Parameters
+    ----------
+    params
+        Iterable of parameters to optimize or dicts defining parameter groups.
+    lr
+        Adam learning rate.
+    betas
+        Coefficients (as a 2-tuple of floats) used for computing running
+        averages of gradient and its square.
+    final_lr
+        Final (SGD) learning rate.
+    gamma
+        Convergence speed of the bound functions.
+    eps
+        Term added to the denominator to improve numerical stability.
+    weight_decay
+        Weight decay (L2 penalty).
+    amsbound
+        Whether to use the AMSBound variant of this algorithm.
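+
+    Examples
+    --------
+    A minimal usage sketch (the linear module is an illustrative placeholder);
+    here the weight decay is applied directly to the weights, as in AdamW,
+    instead of being folded into the gradient:
+
+    .. code-block:: python
+
+        import torch
+
+        module = torch.nn.Linear(10, 1)
+        optimizer = AdaBoundW(
+            list(module.parameters()), lr=1e-3, final_lr=0.1, weight_decay=1e-2
+        )
+
+        loss = module(torch.randn(4, 10)).sum()
+        loss.backward()
+        optimizer.step()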
+    """
+
+    def __init__(
+        self,
+        params: list,
+        lr: float = 1e-3,
+        betas: tuple[float, float] = (0.9, 0.999),
+        final_lr: float = 0.1,
+        gamma: float = 1e-3,
+        eps: float = 1e-8,
+        weight_decay: float = 0.0,
+        amsbound: bool = False,
+    ):
+        if not 0.0 <= lr:
+            raise ValueError(f"Invalid learning rate: {lr}")
+        if not 0.0 <= eps:
+            raise ValueError(f"Invalid epsilon value: {eps}")
+        if not 0.0 <= betas[0] < 1.0:
+            raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}")
+        if not 0.0 <= betas[1] < 1.0:
+            raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}")
+        if not 0.0 <= final_lr:
+            raise ValueError(f"Invalid final learning rate: {final_lr}")
+        if not 0.0 <= gamma < 1.0:
+            raise ValueError(f"Invalid gamma parameter: {gamma}")
+        defaults = dict(
+            lr=lr,
+            betas=betas,
+            final_lr=final_lr,
+            gamma=gamma,
+            eps=eps,
+            weight_decay=weight_decay,
+            amsbound=amsbound,
+        )
+        super().__init__(params, defaults)
+
+        self.base_lrs = list(map(lambda group: group["lr"], self.param_groups))
+
+    def __setstate__(self, state):
+        super().__setstate__(state)
+        for group in self.param_groups:
+            group.setdefault("amsbound", False)
+
+    def step(self, closure: typing.Callable | None = None):
+        """Perform a single optimization step.
+
+        Parameters
+        ----------
+        closure
+            A closure that reevaluates the model and returns the loss.
+
+        Returns
+        -------
+            The loss.
+        """
+
+        loss = None
+        if closure is not None:
+            loss = closure()
+
+        for group, base_lr in zip(self.param_groups, self.base_lrs):
+            for p in group["params"]:
+                if p.grad is None:
+                    continue
+                grad = p.grad.data
+                if grad.is_sparse:
+                    raise RuntimeError(
+                        "Adam does not support sparse gradients, please consider SparseAdam instead"
+                    )
+                amsbound = group["amsbound"]
+
+                state = self.state[p]
+
+                # State initialization
+                if len(state) == 0:
+                    state["step"] = 0
+                    # Exponential moving average of gradient values
+                    state["exp_avg"] = torch.zeros_like(p.data)
+                    # Exponential moving average of squared gradient values
+                    state["exp_avg_sq"] = torch.zeros_like(p.data)
+                    if amsbound:
+                        # Maintains max of all exp. moving avg. of sq. grad. values
+                        state["max_exp_avg_sq"] = torch.zeros_like(p.data)
+
+                exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
+                if amsbound:
+                    max_exp_avg_sq = state["max_exp_avg_sq"]
+                beta1, beta2 = group["betas"]
+
+                state["step"] += 1
+
+                # Decay the first and second moment running average coefficient
+                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
+                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
+                if amsbound:
+                    # Maintains the maximum of all 2nd moment running avg. till now
+                    torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
+                    # Use the max. for normalizing running avg. of gradient
+                    denom = max_exp_avg_sq.sqrt().add_(group["eps"])
+                else:
+                    denom = exp_avg_sq.sqrt().add_(group["eps"])
+
+                bias_correction1 = 1 - beta1 ** state["step"]
+                bias_correction2 = 1 - beta2 ** state["step"]
+                step_size = group["lr"] * math.sqrt(bias_correction2) / bias_correction1
+
+                # Applies bounds on actual learning rate
+                # lr_scheduler cannot affect final_lr, this is a workaround to
+                # apply lr decay
+                final_lr = group["final_lr"] * group["lr"] / base_lr
+                lower_bound = final_lr * (1 - 1 / (group["gamma"] * state["step"] + 1))
+                upper_bound = final_lr * (1 + 1 / (group["gamma"] * state["step"]))
+                step_size = torch.full_like(denom, step_size)
+                step_size.div_(denom).clamp_(lower_bound, upper_bound).mul_(exp_avg)
+
+                if group["weight_decay"] != 0:
+                    decayed_weights = torch.mul(p.data, group["weight_decay"])
+                    p.data.add_(-step_size)
+                    p.data.sub_(decayed_weights)
+                else:
+                    p.data.add_(-step_size)
+
+        return loss
diff --git a/src/mednet/libs/segmentation/models/backbones/__init__.py b/src/mednet/libs/segmentation/models/backbones/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/src/mednet/libs/segmentation/models/backbones/mobilenetv2.py b/src/mednet/libs/segmentation/models/backbones/mobilenetv2.py
new file mode 100644
index 0000000000000000000000000000000000000000..0340b3b895603406a3e3d3c8965ee756a5a0edf5
--- /dev/null
+++ b/src/mednet/libs/segmentation/models/backbones/mobilenetv2.py
@@ -0,0 +1,92 @@
+# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
+#
+# SPDX-License-Identifier: GPL-3.0-or-later
+
+import torchvision.models
+import torchvision.models.mobilenetv2
+
+try:
+    # pytorch >= 1.12
+    from torch.hub import load_state_dict_from_url
+except ImportError:
+    # pytorch < 1.12
+    from torchvision.models.utils import load_state_dict_from_url
+
+
+class MobileNetV24Segmentation(torchvision.models.mobilenetv2.MobileNetV2):
+    """Adaptation of base MobileNetV2 functionality to U-Net style
+    segmentation.
+
+    This version of MobileNetV2 is slightly modified so it can be used through
+    torchvision's API.  It outputs intermediate features which are normally not
+    output by the base MobileNetV2 implementation, but are required for
+    segmentation operations.
+
+    Parameters
+    ----------
+    *args
+        Arguments to be passed to the parent MobileNetV2 model.
+    **kwargs
+        Keyword arguments to be passed to the parent MobileNetV2 model.
+        return_features : :py:class:`list`, Optional
+            A list of integers indicating the feature layers to be returned from
+            the original module.
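+
+    Examples
+    --------
+    An illustrative sketch (the layer indices below are arbitrary examples of
+    valid positions within ``self.features``):
+
+    .. code-block:: python
+
+        import torch
+
+        backbone = mobilenet_v2_for_segmentation(
+            pretrained=False, return_features=[1, 3, 6, 13]
+        )
+        outputs = backbone(torch.rand(1, 3, 64, 64))
+        # outputs[0]: input height/width, outputs[1]: the input itself,
+        # outputs[2:]: one feature map per requested layer index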
+    """
+
+    def __init__(self, *args, **kwargs):
+        self._return_features = kwargs.pop("return_features")
+        super().__init__(*args, **kwargs)
+
+    def forward(self, x):
+        outputs = []
+        # height/width of input, needed for DRIU and HED
+        outputs.append(x.shape[2:4])
+        outputs.append(x)
+        for index, m in enumerate(self.features):
+            x = m(x)
+            # extract layers
+            if index in self._return_features:
+                outputs.append(x)
+        return outputs
+
+
+def mobilenet_v2_for_segmentation(pretrained=False, progress=True, **kwargs):
+    """Create MobileNetV2 model for segmentation task.
+
+    Parameters
+    ----------
+    pretrained
+        If True, uses MobileNetV2 pretrained weights.
+    progress
+        If True, shows a progress bar when downloading the pretrained weights.
+    **kwargs
+        Keyword arguments to be passed to the parent MobileNetV2 model.
+        return_features : :py:class:`list`, Optional
+            A list of integers indicating the feature layers to be returned from
+            the original module.
+
+    Returns
+    -------
+        Instance of the MobileNetV2 model for segmentation.
+    """
+
+    model = MobileNetV24Segmentation(**kwargs)
+
+    if pretrained:
+        state_dict = load_state_dict_from_url(
+            torchvision.models.mobilenetv2.MobileNet_V2_Weights.DEFAULT.url,
+            progress=progress,
+        )
+        model.load_state_dict(state_dict)
+
+    # erase MobileNetV2 head (for classification), not used for segmentation
+    delattr(model, "classifier")
+
+    return_features = kwargs.get("return_features")
+    if return_features is not None:
+        model.features = model.features[: (max(return_features) + 1)]
+
+    return model
+
+
+mobilenet_v2_for_segmentation.__doc__ = torchvision.models.mobilenetv2.__doc__
diff --git a/src/mednet/libs/segmentation/models/backbones/resnet.py b/src/mednet/libs/segmentation/models/backbones/resnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..3bd2b3f35a7f9e109fad85ebdac317a30abd3791
--- /dev/null
+++ b/src/mednet/libs/segmentation/models/backbones/resnet.py
@@ -0,0 +1,87 @@
+# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
+#
+# SPDX-License-Identifier: GPL-3.0-or-later
+
+import torchvision.models
+
+try:
+    # pytorch >= 1.12
+    from torch.hub import load_state_dict_from_url
+except ImportError:
+    # pytorch < 1.12
+    from torchvision.models.utils import load_state_dict_from_url
+
+
+class ResNet4Segmentation(torchvision.models.resnet.ResNet):
+    """Adaptation of base ResNet functionality to U-Net style segmentation.
+
+    This version of ResNet is slightly modified so it can be used through
+    torchvision's API.  It outputs intermediate features which are normally not
+    output by the base ResNet implementation, but are required for segmentation
+    operations.
+
+    Parameters
+    ----------
+    *args
+        Arguments to be passed to the parent ResNet model.
+    **kwargs
+        Keyword arguments to be passed to the parent ResNet model.
+        return_features : :py:class:`list`, Optional
+            A list of integers indicating the feature layers to be returned from
+            the original module.
+    """
+
+    def __init__(self, *args, **kwargs):
+        self._return_features = kwargs.pop("return_features")
+        super().__init__(*args, **kwargs)
+
+    def forward(self, x):
+        outputs = []
+        # height/width of input, needed for DRIU and HED
+        outputs.append(x.shape[2:4])
+        # torchvision's ResNet does not expose a ``features`` sequence (unlike
+        # VGG and MobileNetV2), so enumerate the feature-extraction layers
+        # explicitly, in execution order
+        features = [
+            self.conv1, self.bn1, self.relu, self.maxpool,
+            self.layer1, self.layer2, self.layer3, self.layer4,
+        ]
+        for index, m in enumerate(features):
+            x = m(x)
+            # extract layers
+            if index in self._return_features:
+                outputs.append(x)
+        return outputs
+
+
+def resnet50_for_segmentation(pretrained=False, progress=True, **kwargs):
+    """Create ResNet for segmentation task.
+
+    Parameters
+    ----------
+    pretrained
+        If True, uses ResNet50 pretrained weights.
+    progress
+        If True, shows a progress bar when downloading the pretrained weights.
+    **kwargs
+        Keyword arguments to be passed to the parent ResNet model.
+        return_features : :py:class:`list`, Optional
+            A list of integers indicating the feature layers to be returned from
+            the original module.
+
+    Returns
+    -------
+        Instance of the ResNet model for segmentation.
+    """
+    model = ResNet4Segmentation(
+        torchvision.models.resnet.Bottleneck, [3, 4, 6, 3], **kwargs
+    )
+
+    if pretrained:
+        state_dict = load_state_dict_from_url(
+            torchvision.models.resnet.ResNet50_Weights.DEFAULT.url,
+            progress=progress,
+        )
+        model.load_state_dict(state_dict)
+
+    # erase ResNet head (for classification), not used for segmentation
+    delattr(model, "avgpool")
+    delattr(model, "fc")
+
+    return model
+
+
+resnet50_for_segmentation.__doc__ = torchvision.models.resnet50.__doc__
diff --git a/src/mednet/libs/segmentation/models/backbones/vgg.py b/src/mednet/libs/segmentation/models/backbones/vgg.py
new file mode 100644
index 0000000000000000000000000000000000000000..d8ed74ae0589daca96019daee33801021d7c30e2
--- /dev/null
+++ b/src/mednet/libs/segmentation/models/backbones/vgg.py
@@ -0,0 +1,130 @@
+# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
+#
+# SPDX-License-Identifier: GPL-3.0-or-later
+
+import torchvision.models
+
+try:
+    # pytorch >= 1.12
+    from torch.hub import load_state_dict_from_url
+except ImportError:
+    # pytorch < 1.12
+    from torchvision.models.utils import load_state_dict_from_url
+
+
+class VGG4Segmentation(torchvision.models.vgg.VGG):
+    """Adaptation of base VGG functionality to U-Net style segmentation.
+
+    This version of VGG is slightly modified so it can be used through
+    torchvision's API.  It outputs intermediate features which are normally not
+    output by the base VGG implementation, but are required for segmentation
+    operations.
+
+    Parameters
+    ----------
+    *args
+        Arguments to be passed to the parent VGG model.
+    **kwargs
+        Keyword arguments to be passed to the parent VGG model.
+        return_features : :py:class:`list`, Optional
+            A list of integers indicating the feature layers to be returned from
+            the original module.
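+
+    Examples
+    --------
+    An illustrative sketch, using the layer indices employed by the U-Net
+    model in this library:
+
+    .. code-block:: python
+
+        import torch
+
+        backbone = vgg16_for_segmentation(
+            pretrained=False, return_features=[3, 8, 14, 22, 29]
+        )
+        outputs = backbone(torch.rand(1, 3, 64, 64))
+        # outputs[0]: input height/width; outputs[1:]: feature maps at
+        # resolutions 64, 32, 16, 8 and 4 with 64, 128, 256, 512 and 512
+        # channels respectively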
+    """
+
+    def __init__(self, *args, **kwargs):
+        self._return_features = kwargs.pop("return_features")
+        super().__init__(*args, **kwargs)
+
+    def forward(self, x):
+        outputs = []
+        # height/width of input, needed for DRIU and HED
+        outputs.append(x.shape[2:4])
+        for index, m in enumerate(self.features):
+            x = m(x)
+            # extract layers
+            if index in self._return_features:
+                outputs.append(x)
+        return outputs
+
+
+def _make_vgg16_type_d_for_segmentation(pretrained, batch_norm, progress, **kwargs):
+    if pretrained:
+        kwargs["init_weights"] = False
+
+    model = VGG4Segmentation(
+        torchvision.models.vgg.make_layers(
+            torchvision.models.vgg.cfgs["D"],
+            batch_norm=batch_norm,
+        ),
+        **kwargs,
+    )
+
+    if pretrained:
+        weights = (
+            torchvision.models.vgg.VGG16_Weights.DEFAULT.url
+            if not batch_norm
+            else torchvision.models.vgg.VGG16_BN_Weights.DEFAULT.url
+        )
+
+        state_dict = load_state_dict_from_url(weights, progress=progress)
+        model.load_state_dict(state_dict)
+
+    # erase VGG head (for classification), not used for segmentation
+    delattr(model, "classifier")
+    delattr(model, "avgpool")
+
+    return model
+
+
+def vgg16_for_segmentation(pretrained=False, progress=True, **kwargs):
+    """Create an instance of VGG16.
+
+    Parameters
+    ----------
+    pretrained
+        If True, uses VGG16 pretrained weights.
+    progress
+        If True, shows a progress bar when downloading weights.
+    **kwargs
+        Keyword arguments to be passed to the parent VGG model.
+        return_features : :py:class:`list`, Optional
+            A list of integers indicating the feature layers to be returned from
+            the original module.
+
+    Returns
+    -------
+        Instance of VGG16.
+    """
+    return _make_vgg16_type_d_for_segmentation(
+        pretrained=pretrained, batch_norm=False, progress=progress, **kwargs
+    )
+
+
+vgg16_for_segmentation.__doc__ = torchvision.models.vgg16.__doc__
+
+
+def vgg16_bn_for_segmentation(pretrained=False, progress=True, **kwargs):
+    """Create an instance of VGG16 with batch norm.
+
+    Parameters
+    ----------
+    pretrained
+        If True, uses VGG16 pretrained weights.
+    progress
+        If True, shows a progress bar when downloading weights.
+    **kwargs
+        Keyword arguments to be passed to the parent VGG model.
+        return_features : :py:class:`list`, Optional
+            A list of integers indicating the feature layers to be returned from
+            the original module.
+
+    Returns
+    -------
+        Instance of VGG16.
+    """
+    return _make_vgg16_type_d_for_segmentation(
+        pretrained=pretrained, batch_norm=True, progress=progress, **kwargs
+    )
+
+
+vgg16_bn_for_segmentation.__doc__ = torchvision.models.vgg16_bn.__doc__
diff --git a/src/mednet/libs/segmentation/models/make_layers.py b/src/mednet/libs/segmentation/models/make_layers.py
new file mode 100644
index 0000000000000000000000000000000000000000..3b3228b87740da102051f4b8e70973f3db5ba7b2
--- /dev/null
+++ b/src/mednet/libs/segmentation/models/make_layers.py
@@ -0,0 +1,273 @@
+# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
+#
+# SPDX-License-Identifier: GPL-3.0-or-later
+
+import torch
+import torch.nn
+from torch.nn import Conv2d, ConvTranspose2d
+
+
+def conv_with_kaiming_uniform(
+    in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1
+):
+    """Convolution layer with kaiming uniform.
+
+    Parameters
+    ----------
+    in_channels
+        Number of input channels.
+    out_channels
+        Number of output channels.
+    kernel_size
+        The kernel size.
+    stride
+        The stride.
+    padding
+        The padding.
+    dilation
+        The dilation.
+
+    Returns
+    -------
+        The convolution layer.
+    """
+    conv = Conv2d(
+        in_channels,
+        out_channels,
+        kernel_size=kernel_size,
+        stride=stride,
+        padding=padding,
+        dilation=dilation,
+        bias=True,
+    )
+    # Caffe2 implementation uses XavierFill, which in fact
+    # corresponds to kaiming_uniform_ in PyTorch
+    torch.nn.init.kaiming_uniform_(conv.weight, a=1)
+    torch.nn.init.constant_(conv.bias, 0)
+    return conv
+
+
+def convtrans_with_kaiming_uniform(
+    in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1
+):
+    """Implement convtrans layer with kaiming uniform.
+
+    Parameters
+    ----------
+    in_channels
+        Number of input channels.
+    out_channels
+        Number of output channels.
+    kernel_size
+        The kernel size.
+    stride
+        The stride.
+    padding
+        The padding.
+    dilation
+        The dilation.
+
+    Returns
+    -------
+        The convtrans layer.
+    """
+    conv = ConvTranspose2d(
+        in_channels,
+        out_channels,
+        kernel_size=kernel_size,
+        stride=stride,
+        padding=padding,
+        dilation=dilation,
+        bias=True,
+    )
+    # Caffe2 implementation uses XavierFill, which in fact
+    # corresponds to kaiming_uniform_ in PyTorch
+    torch.nn.init.kaiming_uniform_(conv.weight, a=1)
+    torch.nn.init.constant_(conv.bias, 0)
+    return conv
+
+
+class UpsampleCropBlock(torch.nn.Module):
+    """Combines Conv2d, ConvTransposed2d and Cropping. Simulates the caffe2
+    crop layer in the forward function.
+
+    Used for DRIU and HED.
+
+    Parameters
+    ----------
+    in_channels
+        Number of channels of intermediate layer.
+    out_channels
+        Number of output channels.
+    up_kernel_size
+        Kernel size for transposed convolution.
+    up_stride
+        Stride for transposed convolution.
+    up_padding
+        Padding for transposed convolution.
+    pixelshuffle
+        If True, uses PixelShuffleICNR upsampling.
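+
+    Examples
+    --------
+    An illustrative shape check (numbers chosen arbitrarily):
+
+    .. code-block:: python
+
+        import torch
+
+        block = UpsampleCropBlock(64, 16, up_kernel_size=4, up_stride=2, up_padding=0)
+        x = torch.rand(1, 64, 50, 50)
+        y = block(x, (100, 100))  # crop back to the original resolution
+        assert y.shape == (1, 16, 100, 100)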
+    """
+
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        up_kernel_size: int,
+        up_stride: int,
+        up_padding: int,
+        pixelshuffle: bool = False,
+    ):
+        super().__init__()
+        # NOTE: Kaiming init, replace with torch.nn.Conv2d and torch.nn.ConvTranspose2d to get original DRIU impl.
+        self.conv = conv_with_kaiming_uniform(in_channels, out_channels, 3, 1, 1)
+        if pixelshuffle:
+            self.upconv = PixelShuffleICNR(out_channels, out_channels, scale=up_stride)
+        else:
+            self.upconv = convtrans_with_kaiming_uniform(
+                out_channels,
+                out_channels,
+                up_kernel_size,
+                up_stride,
+                up_padding,
+            )
+
+    def forward(self, x, input_res):
+        img_h = input_res[0]
+        img_w = input_res[1]
+        x = self.conv(x)
+        x = self.upconv(x)
+        # determine center crop
+        # height
+        up_h = x.shape[2]
+        h_crop = up_h - img_h
+        h_s = h_crop // 2
+        h_e = up_h - (h_crop - h_s)
+        # width
+        up_w = x.shape[3]
+        w_crop = up_w - img_w
+        w_s = w_crop // 2
+        w_e = up_w - (w_crop - w_s)
+        # perform crop
+        # needs explicit ranges for onnx export
+        return x[:, :, h_s:h_e, w_s:w_e]  # crop to input size
+
+
+def ifnone(a, b):
+    """Return ``a`` if ``a`` is not None, otherwise ``b``.
+
+    Parameters
+    ----------
+    a
+        The first parameter.
+    b
+        The second parameter.
+
+    Returns
+    -------
+        The parameter a if it is not None, else b.
+    """
+    return b if a is None else a
+
+
+def icnr(x, scale=2, init=torch.nn.init.kaiming_normal_):
+    """ICNR init of ``x``, with ``scale`` and ``init`` function.
+
+    https://docs.fast.ai/layers.html#PixelShuffleICNR.
+
+    Parameters
+    ----------
+    x
+        Tensor.
+    scale
+        Scale of the upsample.
+    init
+        Function used to initialize.
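+
+    Examples
+    --------
+    A short sketch initializing a 1x1 convolution meant to feed a
+    ``torch.nn.PixelShuffle(2)`` layer:
+
+    .. code-block:: python
+
+        import torch
+
+        conv = torch.nn.Conv2d(64, 64 * 4, kernel_size=1)
+        icnr(conv.weight, scale=2)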
+    """
+
+    ni, nf, h, w = x.shape
+    ni2 = int(ni / (scale**2))
+    k = init(torch.zeros([ni2, nf, h, w])).transpose(0, 1)
+    k = k.contiguous().view(ni2, nf, -1)
+    k = k.repeat(1, 1, scale**2)
+    k = k.contiguous().view([nf, ni, h, w]).transpose(0, 1)
+    x.data.copy_(k)
+
+
+class PixelShuffleICNR(torch.nn.Module):
+    """Upsample by ``scale`` from ``ni`` filters to ``nf`` (default
+    ``ni``), using ``torch.nn.PixelShuffle``, ``icnr`` init, and
+    ``weight_norm``.
+
+    https://docs.fast.ai/layers.html#PixelShuffleICNR.
+
+    Parameters
+    ----------
+    ni
+        Number of initial filters.
+    nf
+        Number of final filters.
+    scale
+        Scale of the upsample.
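+
+    Examples
+    --------
+    An illustrative shape check:
+
+    .. code-block:: python
+
+        import torch
+
+        up = PixelShuffleICNR(64, 32, scale=2)
+        y = up(torch.rand(1, 64, 10, 10))
+        assert y.shape == (1, 32, 20, 20)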
+    """
+
+    def __init__(self, ni: int, nf: int | None = None, scale: int = 2):
+        super().__init__()
+        nf = ifnone(nf, ni)
+        self.conv = conv_with_kaiming_uniform(ni, nf * (scale**2), 1)
+        icnr(self.conv.weight)
+        self.shuf = torch.nn.PixelShuffle(scale)
+        # Blurring over (h*w) kernel
+        # "Super-Resolution using Convolutional Neural Networks without Any Checkerboard Artifacts"
+        # - https://arxiv.org/abs/1806.02658
+        self.pad = torch.nn.ReplicationPad2d((1, 0, 1, 0))
+        self.blur = torch.nn.AvgPool2d(2, stride=1)
+        self.relu = torch.nn.ReLU(inplace=True)
+
+    def forward(self, x):
+        x = self.shuf(self.relu(self.conv(x)))
+        return self.blur(self.pad(x))
+
+
+class UnetBlock(torch.nn.Module):
+    """Unet block implementation.
+
+    Parameters
+    ----------
+    up_in_c
+        Number of input channels.
+    x_in_c
+        Number of cat channels.
+    pixel_shuffle
+        If True, uses a PixelShuffleICNR layer for upsampling.
+    middle_block
+        If True, uses a middle block for VGG based U-Net.
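+
+    Examples
+    --------
+    An illustrative shape check for one decoder step (the channel counts
+    mirror those of the VGG16 encoder used elsewhere in this library):
+
+    .. code-block:: python
+
+        import torch
+
+        block = UnetBlock(up_in_c=512, x_in_c=256)
+        up_in = torch.rand(1, 512, 8, 8)    # coming from the deeper level
+        x_in = torch.rand(1, 256, 16, 16)   # skip connection from the encoder
+        y = block(up_in, x_in)
+        assert y.shape == (1, 256, 16, 16)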
+    """
+
+    def __init__(self, up_in_c, x_in_c, pixel_shuffle=False, middle_block=False):
+        super().__init__()
+        # middle block for VGG based U-Net
+        if middle_block:
+            up_out_c = up_in_c
+        else:
+            up_out_c = up_in_c // 2
+        cat_channels = x_in_c + up_out_c
+        inner_channels = cat_channels // 2
+
+        if pixel_shuffle:
+            self.upsample = PixelShuffleICNR(up_in_c, up_out_c)
+        else:
+            self.upsample = convtrans_with_kaiming_uniform(up_in_c, up_out_c, 2, 2)
+        self.convtrans1 = convtrans_with_kaiming_uniform(
+            cat_channels, inner_channels, 3, 1, 1
+        )
+        self.convtrans2 = convtrans_with_kaiming_uniform(
+            inner_channels, inner_channels, 3, 1, 1
+        )
+        self.relu = torch.nn.ReLU(inplace=True)
+
+    def forward(self, up_in, x_in):
+        up_out = self.upsample(up_in)
+        cat_x = torch.cat([up_out, x_in], dim=1)
+        x = self.relu(self.convtrans1(cat_x))
+        return self.relu(self.convtrans2(x))
diff --git a/src/mednet/libs/segmentation/models/unet.py b/src/mednet/libs/segmentation/models/unet.py
new file mode 100644
index 0000000000000000000000000000000000000000..23ea8878fb58f5299263d8b60dc1a78f1cdae2ac
--- /dev/null
+++ b/src/mednet/libs/segmentation/models/unet.py
@@ -0,0 +1,177 @@
+# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
+#
+# SPDX-License-Identifier: GPL-3.0-or-later
+
+
+import logging
+import typing
+
+import torch.nn
+from mednet.libs.common.data.typing import TransformSequence
+from mednet.libs.common.models.model import Model
+from mednet.libs.common.models.transforms import ResizeMaxSide, SquareCenterPad
+
+from .backbones.vgg import vgg16_for_segmentation
+from .losses import SoftJaccardBCELogitsLoss
+from .make_layers import UnetBlock, conv_with_kaiming_uniform
+
+logger = logging.getLogger("mednet")
+
+
+class UNetHead(torch.nn.Module):
+    """UNet head module.
+
+    Parameters
+    ----------
+    in_channels_list
+        Number of channels for each feature map that is returned from backbone.
+    pixel_shuffle
+        If True, upsample using PixelShuffleICNR.
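+
+    Examples
+    --------
+    An illustrative shape check, mimicking the output of the VGG16 backbone
+    used by :py:class:`Unet` on a 64x64 input:
+
+    .. code-block:: python
+
+        import torch
+
+        head = UNetHead([64, 128, 256, 512, 512])
+        x = [
+            (64, 64),                   # input height/width (unused here)
+            torch.rand(1, 64, 64, 64),
+            torch.rand(1, 128, 32, 32),
+            torch.rand(1, 256, 16, 16),
+            torch.rand(1, 512, 8, 8),
+            torch.rand(1, 512, 4, 4),
+        ]
+        y = head(x)
+        assert y.shape == (1, 1, 64, 64)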
+    """
+
+    def __init__(self, in_channels_list: list[int] | None = None, pixel_shuffle: bool = False):
+        super().__init__()
+        # number of channels
+        c_decode1, c_decode2, c_decode3, c_decode4, c_decode5 = in_channels_list
+
+        # build layers
+        self.decode4 = UnetBlock(c_decode5, c_decode4, pixel_shuffle, middle_block=True)
+        self.decode3 = UnetBlock(c_decode4, c_decode3, pixel_shuffle)
+        self.decode2 = UnetBlock(c_decode3, c_decode2, pixel_shuffle)
+        self.decode1 = UnetBlock(c_decode2, c_decode1, pixel_shuffle)
+        self.final = conv_with_kaiming_uniform(c_decode1, 1, 1)
+
+    def forward(self, x: list[torch.Tensor]):
+        """Forward pass.
+
+        Parameters
+        ----------
+        x
+            List of tensors as returned from the backbone network.
+            First element: height and width of input image.
+            Remaining elements: feature maps for each feature level.
+
+        Returns
+        -------
+            Output of the forward pass.
+        """
+        # NOTE: x[0]: height and width of input image not needed in U-Net architecture
+        decode4 = self.decode4(x[5], x[4])
+        decode3 = self.decode3(decode4, x[3])
+        decode2 = self.decode2(decode3, x[2])
+        decode1 = self.decode1(decode2, x[1])
+        return self.final(decode1)
+
+
+class Unet(Model):
+    """Implementation of the Unet model.
+
+    Parameters
+    ----------
+    loss_type
+        The loss to be used for training and evaluation.
+
+        .. warning::
+
+           The loss should be set to always return batch averages (as opposed
+           to the batch sum), as our logging system expects it so.
+    loss_arguments
+        Arguments to the loss.
+    optimizer_type
+        The type of optimizer to use for training.
+    optimizer_arguments
+        Arguments to the optimizer after ``params``.
+    augmentation_transforms
+        An optional sequence of torch modules containing transforms to be
+        applied on the input **before** it is fed into the network.
+    num_classes
+        Number of outputs (classes) for this model.
+    pretrained
+        If True, will use VGG16 pretrained weights.
+    crop_size
+        Target image size: the largest side of the input is resized to this
+        value and the result is center-padded to a square of this size.
+    """
+
+    def __init__(
+        self,
+        loss_type: torch.nn.Module = SoftJaccardBCELogitsLoss,
+        loss_arguments: dict[str, typing.Any] = {},
+        optimizer_type: type[torch.optim.Optimizer] = torch.optim.Adam,
+        optimizer_arguments: dict[str, typing.Any] = {},
+        augmentation_transforms: TransformSequence = [],
+        num_classes: int = 1,
+        pretrained: bool = False,
+        crop_size: int = 544,
+    ):
+        super().__init__(
+            loss_type,
+            loss_arguments,
+            optimizer_type,
+            optimizer_arguments,
+            augmentation_transforms,
+            num_classes,
+        )
+
+        self.name = "unet"
+
+        resize_transform = ResizeMaxSide(crop_size)
+
+        self.model_transforms = [
+            resize_transform,
+            SquareCenterPad(),
+        ]
+
+        self.pretrained = pretrained
+
+        self.backbone = vgg16_for_segmentation(
+            pretrained=self.pretrained,
+            return_features=[3, 8, 14, 22, 29],
+        )
+
+        self.head = UNetHead([64, 128, 256, 512, 512], pixel_shuffle=False)
+
+    def forward(self, x):
+        if self.normalizer is not None:
+            x = self.normalizer(x)
+        x = self.backbone(x)
+        return self.head(x)
+
+    def set_normalizer(self, dataloader: torch.utils.data.DataLoader) -> None:
+        """Initialize the normalizer for the current model.
+
+        This function is NOOP if ``pretrained = True`` (normalizer set to
+        imagenet weights, during contruction).
+
+        Parameters
+        ----------
+        dataloader
+            A torch Dataloader from which to compute the mean and std.
+            Will not be used if the model is pretrained.
+        """
+        if self.pretrained:
+            from mednet.libs.common.models.normalizer import make_imagenet_normalizer
+
+            logger.warning(
+                f"ImageNet pre-trained {self.name} model - NOT "
+                f"computing z-norm factors from train dataloader. "
+                f"Using preset factors from torchvision.",
+            )
+            self.normalizer = make_imagenet_normalizer()
+        else:
+            self.normalizer = None
+
+    def training_step(self, batch, batch_idx):
+        images = batch[0]
+        ground_truths = batch[1]["target"]
+        masks = batch[1]["mask"]
+
+        outputs = self(self._augmentation_transforms(images))
+        return self._train_loss(outputs, ground_truths, masks)
+
+    def validation_step(self, batch, batch_idx):
+        images = batch[0]
+        ground_truths = batch[1]["target"]
+        masks = batch[1]["mask"]
+
+        outputs = self(images)
+        return self._validation_loss(outputs, ground_truths, masks)