diff --git a/pyproject.toml b/pyproject.toml
index 3a5d57eb7f5428f65a675c9950ed5e4be87b5315..d4e1aef6e781053509e570e7f57a0808d6f51e91 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -76,7 +76,7 @@ mednet = "mednet.scripts.cli:cli"
 [tool.pixi.project]
 channels = ["conda-forge", "pytorch"]
 platforms = ["linux-64", "osx-arm64"]
-conda-pypi-map = {"https://conda.anaconda.org/pytorch" = ".pixi-pytorch-mapping.json"}
+conda-pypi-map = { "https://conda.anaconda.org/pytorch" = ".pixi-pytorch-mapping.json" }
 
 [tool.pixi.system-requirements]
 linux = "4.19.0"
@@ -181,7 +181,7 @@ cuda = "12.1"
 [tool.pixi.feature.cuda.target.linux-64.dependencies]
 #cuda = { version = "*", channel = "nvidia" }
 pytorch-cuda = { version = "12.1.*", channel = "pytorch" }
-pip = "*"  # required for docker image building
+pip = "*"                                                  # required for docker image building
 
 [tool.pixi.environments]
 default = { features = ["qa", "build", "doc", "test", "dev", "py312", "self"] }
@@ -416,6 +416,13 @@ montgomery-shenzhen-indian-padchest = "mednet.libs.classification.config.data.mo
 # VISCERAL dataset
 visceral = "mednet.config.data.visceral.default"
 
+[project.entry-points."mednet.libs.segmentation.config"]
+
+lwnet = "mednet.libs.segmentation.config.models.lwnet"
+
+# DRIVE dataset - retinal vessel segmentation
+drive = "mednet.libs.segmentation.config.data.drive.default"
+
 [tool.ruff]
 line-length = 88
 target-version = "py310"
diff --git a/src/mednet/libs/segmentation/__init__.py b/src/mednet/libs/segmentation/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/src/mednet/libs/segmentation/config/data/drive/__init__.py b/src/mednet/libs/segmentation/config/data/drive/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/src/mednet/libs/segmentation/config/data/drive/datamodule.py b/src/mednet/libs/segmentation/config/data/drive/datamodule.py
new file mode 100644
index 0000000000000000000000000000000000000000..ecab53d8e9f3bbf598fb919ff97bd55c248cd6a4
--- /dev/null
+++ b/src/mednet/libs/segmentation/config/data/drive/datamodule.py
@@ -0,0 +1,129 @@
+# SPDX-FileCopyrightText: Copyright © 2024 Idiap Research Institute <contact@idiap.ch>
+#
+# SPDX-License-Identifier: GPL-3.0-or-later
+"""COVD-DRIVE for Vessel Segmentation."""
+
+import importlib.resources
+import os
+from pathlib import Path
+
+import PIL.Image
+from mednet.libs.common.data.datamodule import CachingDataModule
+from mednet.libs.common.data.split import JSONDatabaseSplit
+from mednet.libs.common.data.typing import DatabaseSplit, Sample
+from mednet.libs.common.data.typing import RawDataLoader as _BaseRawDataLoader
+from torchvision.transforms.functional import to_tensor
+
+from ....utils.rc import load_rc
+
+CONFIGURATION_KEY_DATADIR = "datadir." + (__name__.rsplit(".", 2)[-2])
+"""Key to search for in the configuration file for the root directory of this
+database."""
+
+
+class RawDataLoader(_BaseRawDataLoader):
+    """A specialized raw-data-loader for the Montgomery dataset."""
+
+    datadir: str
+    """This variable contains the base directory where the database raw data is
+    stored."""
+
+    def __init__(self):
+        self.datadir = load_rc().get(
+            CONFIGURATION_KEY_DATADIR, os.path.realpath(os.curdir)
+        )
+
+    def sample(self, sample: tuple[str, str, str]) -> Sample:
+        """Load a single image sample from the disk.
+
+        Parameters
+        ----------
+        sample
+            A tuple containing three path suffixes, within the dataset root
+            folder, pointing respectively to the image, the vessel annotation
+            and the mask to be loaded.
+
+        Returns
+        -------
+            The sample representation: the image tensor and a dictionary with
+            the ``label``, ``mask`` and ``name`` entries.
+        """
+
+        image = PIL.Image.open(Path(self.datadir) / str(sample[0])).convert(
+            mode="RGB"
+        )
+        tensor = to_tensor(image)
+        label = PIL.Image.open(Path(self.datadir) / str(sample[1])).convert(
+            mode="1", dither=None
+        )
+        mask = PIL.Image.open(Path(self.datadir) / str(sample[2])).convert(
+            mode="1", dither=None
+        )
+        return tensor, dict(
+            label=to_tensor(label), mask=to_tensor(mask), name=sample[0]
+        )  # type: ignore[arg-type]
+
+    def label(self, sample: tuple[str, str, str]) -> int:
+        """Return the label entry of a sample tuple.
+
+        Parameters
+        ----------
+        sample
+            A tuple containing three path suffixes, within the dataset root
+            folder, pointing respectively to the image, the vessel annotation
+            and the mask to be loaded.
+
+        Returns
+        -------
+        int
+            The second tuple entry (here, the vessel annotation path suffix),
+            kept for compatibility with the base raw-data-loader interface.
+        """
+        return sample[1]
+
+
+def make_split(basename: str) -> DatabaseSplit:
+    """Return a database split for the Montgomery database.
+
+    Parameters
+    ----------
+    basename
+        Name of the .json file containing the split to load.
+
+    Returns
+    -------
+        An instance of DatabaseSplit.
+    """
+
+    return JSONDatabaseSplit(
+        importlib.resources.files(__name__.rsplit(".", 1)[0]).joinpath(basename)
+    )
+
+
+class DataModule(CachingDataModule):
+    """DRIVE dataset for Vessel Segmentation.
+
+    The DRIVE database has been established to enable comparative studies on
+    segmentation of blood vessels in retinal images.
+
+    * Reference: [DRIVE-2004]_
+    * Original resolution (height x width): 584 x 565
+    * Split reference: [DRIVE-2004]_
+    * Protocol ``default``:
+
+        * Training samples: 20 (including labels and masks)
+        * Test samples: 20 (including labels from annotator 1 and masks)
+
+    * Protocol ``second-annotator``:
+
+        * Test samples: 20 (including labels from annotator 2 and masks)
+
+    Parameters
+    ----------
+    split_filename
+        Name of the .json file containing the split to load.
+    """
+
+    def __init__(self, split_filename: str):
+        super().__init__(
+            database_split=make_split(split_filename),
+            raw_data_loader=RawDataLoader(),
+        )
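+
+
+# Minimal usage sketch (assumes the DRIVE raw data is available under the
+# configured ``datadir.drive`` directory); kept under a guard so importing the
+# module stays side-effect free:
+if __name__ == "__main__":
+    loader = RawDataLoader()
+    image, annotations = loader.sample(
+        (
+            "training/images/21_training.tif",
+            "training/1st_manual/21_manual1.gif",
+            "training/mask/21_training_mask.gif",
+        )
+    )
+    print(image.shape, annotations["label"].shape, annotations["mask"].shape)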
diff --git a/src/mednet/libs/segmentation/config/data/drive/default.json b/src/mednet/libs/segmentation/config/data/drive/default.json
new file mode 100644
index 0000000000000000000000000000000000000000..6707e6edd93546915d7f9960f8ba85d3449a8d6f
--- /dev/null
+++ b/src/mednet/libs/segmentation/config/data/drive/default.json
@@ -0,0 +1,206 @@
+{
+ "train": [
+  [
+   "training/images/21_training.tif",
+   "training/1st_manual/21_manual1.gif",
+   "training/mask/21_training_mask.gif"
+  ],
+  [
+   "training/images/22_training.tif",
+   "training/1st_manual/22_manual1.gif",
+   "training/mask/22_training_mask.gif"
+  ],
+  [
+   "training/images/23_training.tif",
+   "training/1st_manual/23_manual1.gif",
+   "training/mask/23_training_mask.gif"
+  ],
+  [
+   "training/images/24_training.tif",
+   "training/1st_manual/24_manual1.gif",
+   "training/mask/24_training_mask.gif"
+  ],
+  [
+   "training/images/25_training.tif",
+   "training/1st_manual/25_manual1.gif",
+   "training/mask/25_training_mask.gif"
+  ],
+  [
+   "training/images/26_training.tif",
+   "training/1st_manual/26_manual1.gif",
+   "training/mask/26_training_mask.gif"
+  ],
+  [
+   "training/images/27_training.tif",
+   "training/1st_manual/27_manual1.gif",
+   "training/mask/27_training_mask.gif"
+  ],
+  [
+   "training/images/28_training.tif",
+   "training/1st_manual/28_manual1.gif",
+   "training/mask/28_training_mask.gif"
+  ],
+  [
+   "training/images/29_training.tif",
+   "training/1st_manual/29_manual1.gif",
+   "training/mask/29_training_mask.gif"
+  ],
+  [
+   "training/images/30_training.tif",
+   "training/1st_manual/30_manual1.gif",
+   "training/mask/30_training_mask.gif"
+  ],
+  [
+   "training/images/31_training.tif",
+   "training/1st_manual/31_manual1.gif",
+   "training/mask/31_training_mask.gif"
+  ],
+  [
+   "training/images/32_training.tif",
+   "training/1st_manual/32_manual1.gif",
+   "training/mask/32_training_mask.gif"
+  ],
+  [
+   "training/images/33_training.tif",
+   "training/1st_manual/33_manual1.gif",
+   "training/mask/33_training_mask.gif"
+  ],
+  [
+   "training/images/34_training.tif",
+   "training/1st_manual/34_manual1.gif",
+   "training/mask/34_training_mask.gif"
+  ],
+  [
+   "training/images/35_training.tif",
+   "training/1st_manual/35_manual1.gif",
+   "training/mask/35_training_mask.gif"
+  ],
+  [
+   "training/images/36_training.tif",
+   "training/1st_manual/36_manual1.gif",
+   "training/mask/36_training_mask.gif"
+  ],
+  [
+   "training/images/37_training.tif",
+   "training/1st_manual/37_manual1.gif",
+   "training/mask/37_training_mask.gif"
+  ],
+  [
+   "training/images/38_training.tif",
+   "training/1st_manual/38_manual1.gif",
+   "training/mask/38_training_mask.gif"
+  ],
+  [
+   "training/images/39_training.tif",
+   "training/1st_manual/39_manual1.gif",
+   "training/mask/39_training_mask.gif"
+  ],
+  [
+   "training/images/40_training.tif",
+   "training/1st_manual/40_manual1.gif",
+   "training/mask/40_training_mask.gif"
+  ]
+ ],
+ "test": [
+  [
+   "test/images/01_test.tif",
+   "test/1st_manual/01_manual1.gif",
+   "test/mask/01_test_mask.gif"
+  ],
+  [
+   "test/images/02_test.tif",
+   "test/1st_manual/02_manual1.gif",
+   "test/mask/02_test_mask.gif"
+  ],
+  [
+   "test/images/03_test.tif",
+   "test/1st_manual/03_manual1.gif",
+   "test/mask/03_test_mask.gif"
+  ],
+  [
+   "test/images/04_test.tif",
+   "test/1st_manual/04_manual1.gif",
+   "test/mask/04_test_mask.gif"
+  ],
+  [
+   "test/images/05_test.tif",
+   "test/1st_manual/05_manual1.gif",
+   "test/mask/05_test_mask.gif"
+  ],
+  [
+   "test/images/06_test.tif",
+   "test/1st_manual/06_manual1.gif",
+   "test/mask/06_test_mask.gif"
+  ],
+  [
+   "test/images/07_test.tif",
+   "test/1st_manual/07_manual1.gif",
+   "test/mask/07_test_mask.gif"
+  ],
+  [
+   "test/images/08_test.tif",
+   "test/1st_manual/08_manual1.gif",
+   "test/mask/08_test_mask.gif"
+  ],
+  [
+   "test/images/09_test.tif",
+   "test/1st_manual/09_manual1.gif",
+   "test/mask/09_test_mask.gif"
+  ],
+  [
+   "test/images/10_test.tif",
+   "test/1st_manual/10_manual1.gif",
+   "test/mask/10_test_mask.gif"
+  ],
+  [
+   "test/images/11_test.tif",
+   "test/1st_manual/11_manual1.gif",
+   "test/mask/11_test_mask.gif"
+  ],
+  [
+   "test/images/12_test.tif",
+   "test/1st_manual/12_manual1.gif",
+   "test/mask/12_test_mask.gif"
+  ],
+  [
+   "test/images/13_test.tif",
+   "test/1st_manual/13_manual1.gif",
+   "test/mask/13_test_mask.gif"
+  ],
+  [
+   "test/images/14_test.tif",
+   "test/1st_manual/14_manual1.gif",
+   "test/mask/14_test_mask.gif"
+  ],
+  [
+   "test/images/15_test.tif",
+   "test/1st_manual/15_manual1.gif",
+   "test/mask/15_test_mask.gif"
+  ],
+  [
+   "test/images/16_test.tif",
+   "test/1st_manual/16_manual1.gif",
+   "test/mask/16_test_mask.gif"
+  ],
+  [
+   "test/images/17_test.tif",
+   "test/1st_manual/17_manual1.gif",
+   "test/mask/17_test_mask.gif"
+  ],
+  [
+   "test/images/18_test.tif",
+   "test/1st_manual/18_manual1.gif",
+   "test/mask/18_test_mask.gif"
+  ],
+  [
+   "test/images/19_test.tif",
+   "test/1st_manual/19_manual1.gif",
+   "test/mask/19_test_mask.gif"
+  ],
+  [
+   "test/images/20_test.tif",
+   "test/1st_manual/20_manual1.gif",
+   "test/mask/20_test_mask.gif"
+  ]
+ ]
+}
diff --git a/src/mednet/libs/segmentation/config/data/drive/default.py b/src/mednet/libs/segmentation/config/data/drive/default.py
new file mode 100644
index 0000000000000000000000000000000000000000..4fcaa0c67cd2889fedb2919802f888e04c9bdd11
--- /dev/null
+++ b/src/mednet/libs/segmentation/config/data/drive/default.py
@@ -0,0 +1,14 @@
+# SPDX-FileCopyrightText: Copyright © 2024 Idiap Research Institute <contact@idiap.ch>
+#
+# SPDX-License-Identifier: GPL-3.0-or-later
+"""DRIVE dataset for Vessel Segmentation (default protocol).
+
+* Split reference: [DRIVE-2004]_
+* This configuration resolution: 544 x 544 (center-crop)
+* See :py:mod:`mednet.libs.segmentation.config.data.drive.datamodule` for
+  dataset details
+* This dataset offers a second-annotator comparison for the test set only
+"""
+
+from mednet.libs.segmentation.config.data.drive.datamodule import DataModule
+
+datamodule = DataModule("default.json")
diff --git a/src/mednet/libs/segmentation/config/models/__init__.py b/src/mednet/libs/segmentation/config/models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/src/mednet/libs/segmentation/config/models/lwnet.py b/src/mednet/libs/segmentation/config/models/lwnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..9cb08ea3b463ed79c9b2c8c4dbc670eae7b9fa64
--- /dev/null
+++ b/src/mednet/libs/segmentation/config/models/lwnet.py
@@ -0,0 +1,26 @@
+# SPDX-FileCopyrightText: Copyright © 2024 Idiap Research Institute <contact@idiap.ch>
+#
+# SPDX-License-Identifier: GPL-3.0-or-later
+"""Little W-Net for image segmentation.
+
+The Little W-Net architecture contains roughly 70k parameters and
+closely matches (or outperforms) other more complex techniques.
+
+Reference: [GALDRAN-2020]_
+"""
+
+from mednet.libs.segmentation.models.losses import MultiWeightedBCELogitsLoss
+from mednet.libs.segmentation.models.lwnet import LittleWNet
+from torch.optim import Adam
+
+max_lr = 0.01  # start
+min_lr = 1e-08  # valley (currently unused: no learning-rate scheduler is set up)
+cycle = 50  # epochs for a complete scheduling cycle (currently unused)
+
+model = LittleWNet(
+    train_loss=MultiWeightedBCELogitsLoss(),
+    validation_loss=MultiWeightedBCELogitsLoss(),
+    optimizer_type=Adam,
+    optimizer_arguments=dict(lr=max_lr),
+    augmentation_transforms=[],
+)
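+
+# Hedged sketch of using this configuration programmatically (the command-line
+# interface resolves it under the name ``lwnet``):
+#
+#   import torch
+#   from mednet.libs.segmentation.config.models.lwnet import model
+#   coarse, fine = model(torch.rand(1, 3, 544, 544))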
diff --git a/src/mednet/libs/segmentation/data/__init__.py b/src/mednet/libs/segmentation/data/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/src/mednet/libs/segmentation/models/__init__.py b/src/mednet/libs/segmentation/models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/src/mednet/libs/segmentation/models/losses.py b/src/mednet/libs/segmentation/models/losses.py
new file mode 100644
index 0000000000000000000000000000000000000000..3af10f0dd434132c992197ec310b8e33042ffd3f
--- /dev/null
+++ b/src/mednet/libs/segmentation/models/losses.py
@@ -0,0 +1,283 @@
+# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
+#
+# SPDX-License-Identifier: GPL-3.0-or-later
+
+"""Loss implementations."""
+
+import torch
+from torch.nn.modules.loss import _Loss
+
+
+class WeightedBCELogitsLoss(_Loss):
+    """Calculates sum of weighted cross entropy loss.
+
+    Implements Equation 1 in [MANINIS-2016]_.  The weight depends on the
+    current proportion between negatives and positives in the ground-
+    truth sample being analyzed.
+    """
+
+    def __init__(self):
+        super().__init__()
+
+    def forward(
+        self, sample: torch.Tensor, target: torch.Tensor, mask: torch.Tensor
+    ) -> torch.Tensor:
+        """Forward pass.
+
+        Parameters
+        ----------
+        sample
+            Value produced by the model to be evaluated, with the shape ``[n, c,
+            h, w]``.
+
+        target
+            Ground-truth information with the shape ``[n, c, h, w]``.
+
+        mask
+            Mask to be used for specifying the region of interest where to
+            compute the loss, with the shape ``[n, c, h, w]``.
+
+        Returns
+        -------
+        loss
+            The average loss for all input data.
+        """
+
+        # calculates the proportion of negatives to the total number of pixels
+        # available in the masked region
+        valid = mask > 0.5
+        num_pos = target[valid].sum()
+        num_neg = valid.sum() - num_pos
+        pos_weight = num_neg / num_pos
+        return torch.nn.functional.binary_cross_entropy_with_logits(
+            sample[valid],
+            target[valid],
+            reduction="mean",
+            pos_weight=pos_weight,
+        )
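+
+    # Worked example of the weighting above (illustration only): for a masked
+    # region containing 90 negative and 10 positive pixels, ``pos_weight`` is
+    # 90 / 10 = 9, i.e. each positive pixel weighs nine times more than a
+    # negative pixel in the loss.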
+
+
+class SoftJaccardBCELogitsLoss(_Loss):
+    r"""Implement the generalized loss function of Equation (3) in.
+
+    [IGLOVIKOV-2018]_, with J being the Jaccard distance, and H, the Binary
+    Cross-Entropy Loss:
+
+    .. math::
+
+       L = \alpha H + (1-\alpha)(1-J)
+
+    Our implementation is based on :py:class:`torch.nn.BCEWithLogitsLoss`.
+
+    Parameters
+    ----------
+    alpha
+        Determines the weighting of J and H. Default: ``0.7``.
+    """
+
+    def __init__(self, alpha: float = 0.7):
+        super().__init__()
+        self.alpha = alpha
+
+    def forward(
+        self, tensor: torch.Tensor, target: torch.Tensor, mask: torch.Tensor
+    ) -> torch.Tensor:
+        """Forward pass.
+
+        Parameters
+        ----------
+        tensor
+            Value produced by the model to be evaluated, with the shape ``[n, c,
+            h, w]``.
+
+        target
+            Ground-truth information with the shape ``[n, c, h, w]``.
+
+        mask
+            Mask to be used for specifying the region of interest where to
+            compute the loss, with the shape ``[n, c, h, w]``.
+
+        Returns
+        -------
+        loss
+            Loss, in a single entry.
+        """
+
+        eps = 1e-8
+        valid = mask > 0.5
+        probabilities = torch.sigmoid(tensor[valid])
+        intersection = (probabilities * target[valid]).sum()
+        sums = probabilities.sum() + target[valid].sum()
+        j = intersection / (sums - intersection + eps)
+
+        # this implements the support for looking just into the RoI
+        h = torch.nn.functional.binary_cross_entropy_with_logits(
+            tensor[valid], target[valid], reduction="mean"
+        )
+        return (self.alpha * h) + ((1 - self.alpha) * (1 - j))
+
+
+class MultiWeightedBCELogitsLoss(WeightedBCELogitsLoss):
+    """Weighted Binary Cross-Entropy Loss for multi-layered inputs (e.g. for
+    Holistically-Nested Edge Detection in [XIE-2015]_).
+    """
+
+    def __init__(self):
+        super().__init__()
+
+    def forward(
+        self, tensor: torch.Tensor, target: torch.Tensor, mask: torch.Tensor
+    ) -> torch.Tensor:
+        """Forward pass.
+
+        Parameters
+        ----------
+        tensor
+            Value produced by the model to be evaluated, with the shape ``[L,
+            n, c, h, w]``.
+
+        target
+            Ground-truth information with the shape ``[n, c, h, w]``.
+
+        mask
+            Mask to be used for specifying the region of interest where to
+            compute the loss, with the shape ``[n, c, h, w]``.
+
+        Returns
+        -------
+            The average loss for all input data.
+        """
+
+        return torch.cat(
+            [
+                super(MultiWeightedBCELogitsLoss, self)
+                .forward(i, target, mask)
+                .unsqueeze(0)
+                for i in tensor
+            ]
+        ).mean()
+
+
+class MultiSoftJaccardBCELogitsLoss(SoftJaccardBCELogitsLoss):
+    """Implements  Equation 3 in [IGLOVIKOV-2018]_ for the multi-output
+    networks such as HED or Little W-Net.
+
+    Parameters
+    ----------
+    alpha : float
+        Determines the weighting of SoftJaccard and BCE. Default: ``0.3``.
+    """
+
+    def __init__(self, alpha: float = 0.7):
+        super().__init__(alpha=alpha)
+
+    def forward(
+        self, tensor: torch.Tensor, target: torch.Tensor, mask: torch.Tensor
+    ) -> torch.Tensor:
+        """Forward pass.
+
+        Parameters
+        ----------
+        tensor
+            Value produced by the model to be evaluated, with the shape ``[L,
+            n, c, h, w]``.
+
+        target
+            Ground-truth information with the shape ``[n, c, h, w]``.
+
+        mask
+            Mask to be used for specifying the region of interest where to
+            compute the loss, with the shape ``[n, c, h, w]``.
+
+        Returns
+        -------
+            The average loss for all input data.
+        """
+
+        return torch.cat(
+            [
+                super(MultiSoftJaccardBCELogitsLoss, self)
+                .forward(i, target, mask)
+                .unsqueeze(0)
+                for i in tensor
+            ]
+        ).mean()
+
+
+class MixJacLoss(_Loss):
+    """
+    Parameters
+    ----------
+    lambda_u
+        Determines the weighting of SoftJaccard and BCE.
+
+    jacalpha
+        Determines the weighting of J and H.
+
+    size_average
+        By default, the losses are averaged over each loss element in the batch. Note that for
+        some losses, there are multiple elements per sample. If the field :attr:`size_average`
+        is set to ``False``, the losses are instead summed for each minibatch. Ignored
+        when :attr:`reduce` is ``False``. Default: ``True``.
+
+    reduce
+        By default, the losses are averaged or summed over observations for each minibatch depending
+        on :attr:`size_average`. When :attr:`reduce` is ``False``, returns a loss per
+        batch element instead and ignores :attr:`size_average`. Default: ``True``.
+
+    reduction
+        Specifies the reduction to apply to the output:
+        ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied,
+        ``'mean'``: the sum of the output will be divided by the number of
+        elements in the output, ``'sum'``: the output will be summed. Note: :attr:`size_average`
+        and :attr:`reduce` are in the process of being deprecated, and in the meantime,
+        specifying either of those two args will override :attr:`reduction`. Default: ``'mean'``.
+    """
+
+    def __init__(
+        self,
+        lambda_u: int = 100,
+        jacalpha=0.7,
+        size_average=None,
+        reduce=None,
+        reduction="mean",
+    ):
+        super().__init__(size_average, reduce, reduction)
+        self.lambda_u = lambda_u
+        self.labeled_loss = SoftJaccardBCELogitsLoss(alpha=jacalpha)
+        self.unlabeled_loss = torch.nn.BCEWithLogitsLoss()
+
+    def forward(
+        self,
+        tensor: torch.Tensor,
+        target: torch.Tensor,
+        unlabeled_tensor: torch.Tensor,
+        unlabeled_target: torch.Tensor,
+        ramp_up_factor: float,
+    ) -> tuple:
+        """Forward pass.
+
+        Parameters
+        ----------
+        tensor
+            Value produced by the model to be evaluated, with the shape ``[L,
+            n, c, h, w]``.
+
+        target
+            Ground-truth information with the shape ``[n, c, h, w]``.
+
+        unlabeled_tensor
+            Value produced by the model for the unlabeled samples.
+
+        unlabeled_target
+            Target for the unlabeled samples.
+
+        ramp_up_factor
+            Factor multiplying the unlabeled loss contribution (typically
+            ramped up over training).
+
+        Returns
+        -------
+        tuple
+            The combined loss, the labeled loss and the unlabeled loss.
+        """
+        # the labeled loss requires a mask; a full mask is used here
+        # (assumption: no region-of-interest restriction for labeled samples)
+        ll = self.labeled_loss(tensor, target, torch.ones_like(target))
+        ul = self.unlabeled_loss(unlabeled_tensor, unlabeled_target)
+
+        loss = ll + self.lambda_u * ramp_up_factor * ul
+        return loss, ll, ul
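+
+
+# Minimal sketch of how the multi-output losses are used (illustration only;
+# shapes follow the Little W-Net model, which produces two segmentation maps):
+#
+#   import torch
+#   loss = MultiWeightedBCELogitsLoss()
+#   outputs = torch.randn(2, 4, 1, 64, 64)  # [L, n, c, h, w], with L=2 maps
+#   target = (torch.rand(4, 1, 64, 64) > 0.5).float()
+#   mask = torch.ones_like(target)
+#   value = loss(outputs, target, mask)  # scalar, averaged over the L maps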
diff --git a/src/mednet/libs/segmentation/models/lwnet.py b/src/mednet/libs/segmentation/models/lwnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..6cad67c925c5e9ff0a38527ba328b36242b62ab3
--- /dev/null
+++ b/src/mednet/libs/segmentation/models/lwnet.py
@@ -0,0 +1,351 @@
+# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
+#
+# SPDX-License-Identifier: GPL-3.0-or-later
+
+"""Little W-Net.
+
+The code was originally developed by Adrian Galdran
+(https://github.com/agaldran/lwnet), loosely inspired by
+https://github.com/jvanvugt/pytorch-unet.
+
+It is based on two simple 3-layer U-Nets chained together: the first U-Net
+produces a segmentation map that the second one uses to better guide its own
+segmentation.
+
+Reference: [GALDRAN-2020]_
+"""
+
+import typing
+
+import lightning.pytorch as pl
+import torch
+import torch.nn
+import torchvision.transforms
+from mednet.libs.common.data.typing import TransformSequence
+from mednet.libs.segmentation.models.losses import MultiWeightedBCELogitsLoss
+from torchvision.transforms.v2 import CenterCrop
+
+from .separate import separate
+
+
+def _conv1x1(in_planes, out_planes, stride=1):
+    return torch.nn.Conv2d(
+        in_planes, out_planes, kernel_size=1, stride=stride, bias=False
+    )
+
+
+class ConvBlock(torch.nn.Module):
+    def __init__(self, in_c, out_c, k_sz=3, shortcut=False, pool=True):
+        super().__init__()
+        if shortcut is True:
+            self.shortcut = torch.nn.Sequential(
+                _conv1x1(in_c, out_c), torch.nn.BatchNorm2d(out_c)
+            )
+        else:
+            self.shortcut = False
+        pad = (k_sz - 1) // 2
+
+        block = []
+        if pool:
+            self.pool = torch.nn.MaxPool2d(kernel_size=2)
+        else:
+            self.pool = False
+
+        block.append(
+            torch.nn.Conv2d(in_c, out_c, kernel_size=k_sz, padding=pad)
+        )
+        block.append(torch.nn.ReLU())
+        block.append(torch.nn.BatchNorm2d(out_c))
+
+        block.append(
+            torch.nn.Conv2d(out_c, out_c, kernel_size=k_sz, padding=pad)
+        )
+        block.append(torch.nn.ReLU())
+        block.append(torch.nn.BatchNorm2d(out_c))
+
+        self.block = torch.nn.Sequential(*block)
+
+    def forward(self, x):
+        if self.pool:
+            x = self.pool(x)
+        out = self.block(x)
+        if self.shortcut:
+            return out + self.shortcut(x)
+        return out
+
+
+class UpsampleBlock(torch.nn.Module):
+    def __init__(self, in_c, out_c, up_mode="transp_conv"):
+        super().__init__()
+        block = []
+        if up_mode == "transp_conv":
+            block.append(
+                torch.nn.ConvTranspose2d(in_c, out_c, kernel_size=2, stride=2)
+            )
+        elif up_mode == "up_conv":
+            block.append(
+                torch.nn.Upsample(
+                    mode="bilinear", scale_factor=2, align_corners=False
+                )
+            )
+            block.append(torch.nn.Conv2d(in_c, out_c, kernel_size=1))
+        else:
+            raise ValueError(f"Unsupported upsampling mode: {up_mode}")
+
+        self.block = torch.nn.Sequential(*block)
+
+    def forward(self, x):
+        return self.block(x)
+
+
+class ConvBridgeBlock(torch.nn.Module):
+    def __init__(self, channels, k_sz=3):
+        super().__init__()
+        pad = (k_sz - 1) // 2
+        block = []
+
+        block.append(
+            torch.nn.Conv2d(channels, channels, kernel_size=k_sz, padding=pad)
+        )
+        block.append(torch.nn.ReLU())
+        block.append(torch.nn.BatchNorm2d(channels))
+
+        self.block = torch.nn.Sequential(*block)
+
+    def forward(self, x):
+        return self.block(x)
+
+
+class UpConvBlock(torch.nn.Module):
+    def __init__(
+        self,
+        in_c,
+        out_c,
+        k_sz=3,
+        up_mode="up_conv",
+        conv_bridge=False,
+        shortcut=False,
+    ):
+        super().__init__()
+        self.conv_bridge = conv_bridge
+
+        self.up_layer = UpsampleBlock(in_c, out_c, up_mode=up_mode)
+        self.conv_layer = ConvBlock(
+            2 * out_c, out_c, k_sz=k_sz, shortcut=shortcut, pool=False
+        )
+        if self.conv_bridge:
+            self.conv_bridge_layer = ConvBridgeBlock(out_c, k_sz=k_sz)
+
+    def forward(self, x, skip):
+        up = self.up_layer(x)
+        if self.conv_bridge:
+            out = torch.cat([up, self.conv_bridge_layer(skip)], dim=1)
+        else:
+            out = torch.cat([up, skip], dim=1)
+        return self.conv_layer(out)
+
+
+class LittleUNet(torch.nn.Module):
+    """Little U-Net model.
+
+    Parameters
+    ----------
+    in_c
+        Number of input channels.
+
+    n_classes
+        Number of outputs (classes) for this model.
+
+    layers
+        Number of channels for each encoder/decoder level (e.g. ``(8, 16, 32)``).
+
+    k_sz
+        Kernel size of the convolutional layers.
+
+    up_mode
+        Upsampling mode: ``"transp_conv"`` (transposed convolution) or
+        ``"up_conv"`` (bilinear upsampling followed by a 1x1 convolution).
+
+    conv_bridge
+        If set, applies an extra convolutional block to each skip connection
+        before concatenation.
+
+    shortcut
+        If set, adds a 1x1-convolution shortcut to each convolutional block.
+    """
+
+    def __init__(
+        self,
+        in_c,
+        n_classes,
+        layers,
+        k_sz=3,
+        up_mode="transp_conv",
+        conv_bridge=True,
+        shortcut=True,
+    ):
+        super().__init__()
+        self.n_classes = n_classes
+        self.first = ConvBlock(
+            in_c=in_c, out_c=layers[0], k_sz=k_sz, shortcut=shortcut, pool=False
+        )
+
+        self.down_path = torch.nn.ModuleList()
+        for i in range(len(layers) - 1):
+            block = ConvBlock(
+                in_c=layers[i],
+                out_c=layers[i + 1],
+                k_sz=k_sz,
+                shortcut=shortcut,
+                pool=True,
+            )
+            self.down_path.append(block)
+
+        self.up_path = torch.nn.ModuleList()
+        reversed_layers = list(reversed(layers))
+        for i in range(len(layers) - 1):
+            block = UpConvBlock(
+                in_c=reversed_layers[i],
+                out_c=reversed_layers[i + 1],
+                k_sz=k_sz,
+                up_mode=up_mode,
+                conv_bridge=conv_bridge,
+                shortcut=shortcut,
+            )
+            self.up_path.append(block)
+
+        # init, shamelessly lifted from torchvision/models/resnet.py
+        for m in self.modules():
+            if isinstance(m, torch.nn.Conv2d):
+                torch.nn.init.kaiming_normal_(
+                    m.weight, mode="fan_out", nonlinearity="relu"
+                )
+            elif isinstance(m, torch.nn.BatchNorm2d | torch.nn.GroupNorm):
+                torch.nn.init.constant_(m.weight, 1)
+                torch.nn.init.constant_(m.bias, 0)
+
+        self.final = torch.nn.Conv2d(layers[0], n_classes, kernel_size=1)
+
+    def forward(self, x):
+        x = self.first(x)
+        down_activations = []
+        for i, down in enumerate(self.down_path):
+            down_activations.append(x)
+            x = down(x)
+        down_activations.reverse()
+        for i, up in enumerate(self.up_path):
+            x = up(x, down_activations[i])
+        return self.final(x)
+
+
+class LittleWNet(pl.LightningModule):
+    """Little W-Net model, concatenating two Little U-Net models.
+
+    Parameters
+    ----------
+    train_loss
+        The loss to be used during the training.
+
+        .. warning::
+
+           The loss should be set to always return batch averages (as opposed
+           to the batch sum), as our logging system expects it so.
+    validation_loss
+        The loss to be used for validation (may be different from the training
+        loss).  If extra-validation sets are provided, the same loss will be
+        used throughout.
+
+        .. warning::
+
+           The loss should be set to always return batch averages (as opposed
+           to the batch sum), as our logging system expects it so.
+    optimizer_type
+        The type of optimizer to use for training.
+    optimizer_arguments
+        Arguments to the optimizer after ``params``.
+    augmentation_transforms
+        An optional sequence of torch modules containing transforms to be
+        applied on the input **before** it is fed into the network.
+    num_classes
+        Number of outputs (classes) for this model.
+    """
+
+    def __init__(
+        self,
+        train_loss=MultiWeightedBCELogitsLoss(),
+        validation_loss=MultiWeightedBCELogitsLoss(),
+        optimizer_type: type[torch.optim.Optimizer] = torch.optim.Adam,
+        optimizer_arguments: dict[str, typing.Any] = {},
+        augmentation_transforms: TransformSequence = [],
+        num_classes: int = 1,
+    ):
+        super().__init__()
+
+        self.name = "lwnet"
+        self.num_classes = num_classes
+
+        self.model_transforms = [CenterCrop(size=(544, 544))]
+
+        self._train_loss = train_loss
+        self._validation_loss = (
+            validation_loss if validation_loss is not None else train_loss
+        )
+        self._optimizer_type = optimizer_type
+        self._optimizer_arguments = optimizer_arguments
+
+        self._augmentation_transforms = torchvision.transforms.Compose(
+            augmentation_transforms
+        )
+
+        self.unet1 = LittleUNet(
+            in_c=3,
+            n_classes=self.num_classes,
+            layers=(8, 16, 32),
+            conv_bridge=True,
+            shortcut=True,
+        )
+        self.unet2 = LittleUNet(
+            in_c=3 + self.num_classes,
+            n_classes=self.num_classes,
+            layers=(8, 16, 32),
+            conv_bridge=True,
+            shortcut=True,
+        )
+
+    def forward(self, x):
+        x1 = self.unet1(x)
+        x2 = self.unet2(torch.cat([x, x1], dim=1))
+
+        return x1, x2
+
+    def training_step(self, batch, batch_idx):
+        images = batch[0]
+        ground_truths = batch[1]["label"]
+        # the datamodule ships the region-of-interest mask with the label
+        # dictionary; fall back to a full mask when it is absent
+        masks = batch[1].get("mask", torch.ones_like(ground_truths))
+
+        outputs = self(self._augmentation_transforms(images))
+
+        return self._train_loss(outputs, ground_truths, masks)
+
+    def validation_step(self, batch, batch_idx):
+        images = batch[0]
+        ground_truths = batch[1]["label"]
+        masks = batch[1].get("mask", torch.ones_like(ground_truths))
+
+        outputs = self(images)
+        return self._validation_loss(outputs, ground_truths, masks)
+
+    def predict_step(self, batch, batch_idx, dataloader_idx=0):
+        # forward() returns the coarse and the fine segmentation maps; keep
+        # only the second (final) one for prediction
+        outputs = self(batch[0])[1]
+        probabilities = torch.sigmoid(outputs)
+        return separate((probabilities, batch[1]))
+
+    def configure_optimizers(self):
+        return self._optimizer_type(
+            self.parameters(), **self._optimizer_arguments
+        )
+
+    """def configure_optimizers(self):
+        optimizer = getattr(
+            self, 'optimizer', Adam(self.parameters(), lr=1e-3)
+        )
+        if optimizer is None:
+            raise ValueError("Optimizer not found. Please provide an optimizer.")
+
+        scheduler = getattr(self, 'scheduler', None)
+        if scheduler is None:
+            return {'optimizer': optimizer}
+        else:
+            return {'optimizer': optimizer, 'lr_scheduler': scheduler}"""
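+
+# Minimal smoke-test sketch (assumptions: CPU execution and the 544x544 input
+# size matching the model's center-crop transform); not part of the test
+# suite:
+if __name__ == "__main__":
+    net = LittleWNet()
+    with torch.no_grad():
+        coarse, fine = net(torch.rand(1, 3, 544, 544))
+    print(coarse.shape, fine.shape)  # both torch.Size([1, 1, 544, 544])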
diff --git a/src/mednet/libs/segmentation/models/separate.py b/src/mednet/libs/segmentation/models/separate.py
new file mode 100644
index 0000000000000000000000000000000000000000..8abd2329030c055606056741c5642e5ad29256dd
--- /dev/null
+++ b/src/mednet/libs/segmentation/models/separate.py
@@ -0,0 +1,61 @@
+# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
+#
+# SPDX-License-Identifier: GPL-3.0-or-later
+"""Contains the inverse :py:func:`torch.utils.data.default_collate`."""
+
+import typing
+
+import torch
+from mednet.libs.common.data.typing import Sample
+
+from .typing import BinaryPrediction, MultiClassPrediction
+
+
+def _as_predictions(
+    samples: typing.Iterable[Sample],
+) -> list[BinaryPrediction | MultiClassPrediction]:
+    """Take a list of separated batch predictions and transforms it into a list
+    of formal predictions.
+
+    Parameters
+    ----------
+    samples
+        A sequence of samples as returned by :py:func:`separate`.
+
+    Returns
+    -------
+    list[BinaryPrediction | MultiClassPrediction]
+        A list of typed predictions that can be saved to disk.
+    """
+    return [(v[1]["name"], v[1]["label"].item(), v[0].item()) for v in samples]
+
+
+def separate(batch: Sample) -> list[BinaryPrediction | MultiClassPrediction]:
+    """Separate a collated batch, reconstituting its samples.
+
+    This function implements the inverse of
+    :py:func:`torch.utils.data.default_collate`, and can separate, into
+    samples, batches of data with different attributes.  It follows the inverse
+    path of that function, and implements the following separation algorithms:
+
+    * :class:`torch.Tensor` -> :class:`torch.Tensor` (with a removed outer
+      dimension, via :py:func:`torch.flatten`)
+    * ``typing.Mapping[K, V[]]`` -> ``[dict[K, V_1], dict[K, V_2], ...]``
+
+    Parameters
+    ----------
+    batch
+        A batch, as output by torch model forwarding.
+
+    Returns
+    -------
+        A list of predictions that contains the predictions and associated metadata
+        for each processed sample.
+    """
+
+    # as of now, this is really simple - to be made more complex upon need.
+    metadata = [
+        {key: value[i] for key, value in batch[1].items()}
+        for i in range(len(batch[0]))
+    ]
+    return _as_predictions(zip(torch.flatten(batch[0]), metadata))
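+
+
+# Toy example (illustration only; note this implementation assumes one scalar
+# value per sample, since :py:func:`torch.flatten` collapses all dimensions):
+#
+#   batch = (
+#       torch.tensor([0.2, 0.8]),
+#       {"name": ["a.png", "b.png"], "label": torch.tensor([0, 1])},
+#   )
+#   separate(batch)
+#   # -> [("a.png", 0, 0.2), ("b.png", 1, 0.8)]  (values approximate)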
diff --git a/src/mednet/libs/segmentation/models/typing.py b/src/mednet/libs/segmentation/models/typing.py
new file mode 100644
index 0000000000000000000000000000000000000000..3eb9017c1e7bbbf860cf7be67622cdfe5db513df
--- /dev/null
+++ b/src/mednet/libs/segmentation/models/typing.py
@@ -0,0 +1,27 @@
+# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
+#
+# SPDX-License-Identifier: GPL-3.0-or-later
+"""Defines most common types used in code."""
+
+import typing
+
+Checkpoint: typing.TypeAlias = typing.MutableMapping[str, typing.Any]
+"""Definition of a lightning checkpoint."""
+
+BinaryPrediction: typing.TypeAlias = tuple[str, int, float]
+"""The sample name, the target, and the predicted value."""
+
+MultiClassPrediction: typing.TypeAlias = tuple[
+    str, typing.Sequence[int], typing.Sequence[float]
+]
+"""The sample name, the target, and the predicted value."""
+
+BinaryPredictionSplit: typing.TypeAlias = typing.Mapping[
+    str, typing.Sequence[BinaryPrediction]
+]
+"""A series of predictions for different database splits."""
+
+MultiClassPredictionSplit: typing.TypeAlias = typing.Mapping[
+    str, typing.Sequence[MultiClassPrediction]
+]
+"""A series of predictions for different database splits."""
diff --git a/src/mednet/libs/segmentation/scripts/__init__.py b/src/mednet/libs/segmentation/scripts/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/src/mednet/libs/segmentation/scripts/cli.py b/src/mednet/libs/segmentation/scripts/cli.py
new file mode 100644
index 0000000000000000000000000000000000000000..add5fdac1b11a770b04c21c6f2ea20ae13abf881
--- /dev/null
+++ b/src/mednet/libs/segmentation/scripts/cli.py
@@ -0,0 +1,41 @@
+# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
+#
+# SPDX-License-Identifier: GPL-3.0-or-later
+
+import click
+from clapper.click import AliasedGroup
+
+from . import (
+    # analyze,
+    # compare,
+    # config,
+    # dataset,
+    # evaluate,
+    # experiment,
+    # mkmask,
+    # predict,
+    # significance,
+    train,
+)
+
+
+@click.group(
+    cls=AliasedGroup,
+    context_settings=dict(help_option_names=["-?", "-h", "--help"]),
+)
+def segmentation():
+    """Binary Segmentation Benchmark."""
+    pass
+
+
+# segmentation.add_command(analyze.analyze)
+# segmentation.add_command(compare.compare)
+# segmentation.add_command(config.config)
+# segmentation.add_command(dataset.dataset)
+# segmentation.add_command(evaluate.evaluate)
+# segmentation.add_command(experiment.experiment)
+# segmentation.add_command(mkmask.mkmask)
+# segmentation.add_command(predict.predict)
+# segmentation.add_command(significance.significance)
+segmentation.add_command(train.train)
+# segmentation.add_command(train_analysis.train_analysis)
diff --git a/src/mednet/libs/segmentation/scripts/click.py b/src/mednet/libs/segmentation/scripts/click.py
new file mode 100644
index 0000000000000000000000000000000000000000..b028e6a737345fbfc1c61474e180e84e50679be9
--- /dev/null
+++ b/src/mednet/libs/segmentation/scripts/click.py
@@ -0,0 +1,28 @@
+# Copyright © 2022 Idiap Research Institute <contact@idiap.ch>
+#
+# SPDX-License-Identifier: GPL-3.0-or-later
+
+import click
+from clapper.click import ConfigCommand as _BaseConfigCommand
+
+
+class ConfigCommand(_BaseConfigCommand):
+    """A click command-class that has the properties of :py:class:`clapper.click.ConfigCommand` and adds verbatim epilog formatting."""
+
+    def format_epilog(
+        self, _: click.core.Context, formatter: click.formatting.HelpFormatter
+    ) -> None:
+        """Format the command epilog during --help.
+
+        Parameters
+        ----------
+        _
+            The current parsing context.
+        formatter
+            The formatter to use for printing text.
+        """
+
+        if self.epilog:
+            formatter.write_paragraph()
+            for line in self.epilog.split("\n"):
+                formatter.write(line + "\n")
diff --git a/src/mednet/libs/segmentation/scripts/train.py b/src/mednet/libs/segmentation/scripts/train.py
new file mode 100644
index 0000000000000000000000000000000000000000..0796c3d77df3d28fc491de60545a056ebde9cf9e
--- /dev/null
+++ b/src/mednet/libs/segmentation/scripts/train.py
@@ -0,0 +1,376 @@
+# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
+#
+# SPDX-License-Identifier: GPL-3.0-or-later
+
+import functools
+import pathlib
+import typing
+from pathlib import Path
+
+import click
+from clapper.click import ResourceOption, verbosity_option
+from clapper.logging import setup
+
+from .click import ConfigCommand
+
+logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
+
+
+def reusable_options(f):
+    """The options that can be re-used by top-level scripts (i.e.
+    ``experiment``).
+
+    This decorator equips the target function ``f`` with all (reusable)
+    ``train`` script options.
+
+    Parameters
+    ----------
+    f
+        The target function to equip with options.  This function must have
+        parameters that accept such options.
+
+    Returns
+    -------
+        The decorated version of function ``f``
+    """  # noqa D401
+
+    @click.option(
+        "--output-folder",
+        "-o",
+        help="Directory in which to store results (created if does not exist)",
+        required=True,
+        type=click.Path(
+            file_okay=False,
+            dir_okay=True,
+            writable=True,
+            path_type=pathlib.Path,
+        ),
+        default="results",
+        cls=ResourceOption,
+    )
+    @click.option(
+        "--model",
+        "-m",
+        help="A lightning module instance implementing the network to be trained",
+        required=True,
+        cls=ResourceOption,
+    )
+    @click.option(
+        "--datamodule",
+        "-d",
+        help="A lightning DataModule containing the training and validation sets.",
+        required=True,
+        cls=ResourceOption,
+    )
+    @click.option(
+        "--batch-size",
+        "-b",
+        help="Number of samples in every batch (this parameter affects "
+        "memory requirements for the network).  If the number of samples in "
+        "the batch is larger than the total number of samples available for "
+        "training, this value is truncated.  If this number is smaller, then "
+        "batches of the specified size are created and fed to the network "
+        "until there are no more new samples to feed (epoch is finished).  "
+        "If the total number of training samples is not a multiple of the "
+        "batch-size, the last batch will be smaller than the first, unless "
+        "--drop-incomplete-batch is set, in which case this batch is not used.",
+        required=True,
+        show_default=True,
+        default=1,
+        type=click.IntRange(min=1),
+        cls=ResourceOption,
+    )
+    @click.option(
+        "--batch-chunk-count",
+        "-c",
+        help="Number of chunks in every batch (this parameter affects "
+        "memory requirements for the network). The number of samples "
+        "loaded for every iteration will be batch-size/batch-chunk-count. "
+        "batch-size needs to be divisible by batch-chunk-count, otherwise an "
+        "error will be raised. This parameter is used to reduce the number of "
+        "samples loaded in each iteration, in order to reduce the memory usage "
+        "in exchange for processing time (more iterations).  This is especially "
+        "interesting when one is training on GPUs with limited RAM. The "
+        "default of 1 forces the whole batch to be processed at once.  Otherwise "
+        "the batch is broken into batch-chunk-count pieces, and gradients are "
+        "accumulated to complete each batch.",
+        required=True,
+        show_default=True,
+        default=1,
+        type=click.IntRange(min=1),
+        cls=ResourceOption,
+    )
+    @click.option(
+        "--drop-incomplete-batch/--no-drop-incomplete-batch",
+        "-D",
+        help="If set, the last batch in an epoch will be dropped if "
+        "incomplete.  If you set this option, you should also consider "
+        "increasing the total number of epochs of training, as the total number "
+        "of training steps may be reduced.",
+        required=True,
+        show_default=True,
+        default=False,
+        cls=ResourceOption,
+    )
+    @click.option(
+        "--epochs",
+        "-e",
+        help="""Number of epochs (complete training set passes) to train for.
+        If continuing from a saved checkpoint, ensure to provide a greater
+        number of epochs than was saved in the checkpoint to be loaded.""",
+        show_default=True,
+        required=True,
+        default=1000,
+        type=click.IntRange(min=1),
+        cls=ResourceOption,
+    )
+    @click.option(
+        "--validation-period",
+        "-p",
+        help="""Number of epochs after which validation happens.  By default,
+        we run validation after every training epoch (period=1).  You can
+        change this to make validation more sparse, by increasing the
+        validation period. Notice that this affects checkpoint saving.  While
+        checkpoints are created after every training step (the last training
+        step always triggers the overriding of latest checkpoint), and
+        this process is independent of validation runs, evaluation of the
+        'best' model obtained so far based on those will be influenced by this
+        setting.""",
+        show_default=True,
+        required=True,
+        default=1,
+        type=click.IntRange(min=1),
+        cls=ResourceOption,
+    )
+    @click.option(
+        "--device",
+        "-x",
+        help='A string indicating the device to use (e.g. "cpu" or "cuda:0")',
+        show_default=True,
+        required=True,
+        default="cpu",
+        cls=ResourceOption,
+    )
+    @click.option(
+        "--cache-samples/--no-cache-samples",
+        help="If set to True, loads the sample into memory, "
+        "otherwise loads them at runtime.",
+        required=True,
+        show_default=True,
+        default=False,
+        cls=ResourceOption,
+    )
+    @click.option(
+        "--seed",
+        "-s",
+        help="Seed to use for the random number generator",
+        show_default=True,
+        required=False,
+        default=42,
+        type=click.IntRange(min=0),
+        cls=ResourceOption,
+    )
+    @click.option(
+        "--parallel",
+        "-P",
+        help="""Use multiprocessing for data loading: if set to -1 (default),
+        disables multiprocessing data loading.  Set to 0 to enable as many data
+        loading instances as processing cores available in the system.  Set to
+        >= 1 to enable that many multiprocessing instances for data
+        loading.""",
+        type=click.IntRange(min=-1),
+        show_default=True,
+        required=True,
+        default=-1,
+        cls=ResourceOption,
+    )
+    @click.option(
+        "--monitoring-interval",
+        "-I",
+        help="""Time between checks for the use of resources during each training
+        epoch, in seconds.  An interval of 5 seconds, for example, will lead to
+        CPU and GPU resources being probed every 5 seconds during each training
+        epoch. Values registered in the training logs correspond to averages
+        (or maxima) observed through possibly many probes in each epoch.
+        Notice that setting a very small value may cause the probing process to
+        become extremely busy, potentially biasing the overall perception of
+        resource usage.""",
+        type=click.FloatRange(min=0.1),
+        show_default=True,
+        required=True,
+        default=5.0,
+        cls=ResourceOption,
+    )
+    @click.option(
+        "--balance-classes/--no-balance-classes",
+        "-B/-N",
+        help="""If set, balances weights of the random sampler during
+        training so that samples from all sample classes are picked
+        equitably.""",
+        required=True,
+        show_default=True,
+        default=True,
+        cls=ResourceOption,
+    )
+    @functools.wraps(f)
+    def wrapper_reusable_options(*args, **kwargs):
+        return f(*args, **kwargs)
+
+    return wrapper_reusable_options
+
+
+@click.command(
+    entry_point_group="mednet.libs.segmentation.config",
+    cls=ConfigCommand,
+    epilog="""Examples:
+
+1. Train a little W-Net model with the DRIVE dataset, on a GPU (``cuda:0``):
+
+   .. code:: sh
+
+      mednet segmentation train -vv lwnet drive --batch-size=4 --device="cuda:0"
+""",
+)
+@reusable_options
+@verbosity_option(logger=logger, cls=ResourceOption, expose_value=False)
+def train(
+    model,
+    output_folder,
+    epochs,
+    batch_size,
+    batch_chunk_count,
+    drop_incomplete_batch,
+    datamodule,
+    validation_period,
+    device,
+    cache_samples,
+    seed,
+    parallel,
+    monitoring_interval,
+    balance_classes,
+    **_,
+) -> None:  # numpydoc ignore=PR01
+    """Train an CNN to perform image classification.
+
+    Training is performed for a configurable number of epochs, and
+    generates checkpoints.  Checkpoints are model files with a .ckpt
+    extension that are used in subsequent tasks or from which training
+    can be resumed.
+    """
+
+    import torch
+    from lightning.pytorch import seed_everything
+    from mednet.libs.common.engine.device import DeviceManager
+    from mednet.libs.common.engine.trainer import run
+    from mednet.libs.common.utils.checkpointer import (
+        get_checkpoint_to_resume_training,
+    )
+
+    from .utils import (
+        device_properties,
+        execution_metadata,
+        model_summary,
+        save_json_with_backup,
+    )
+
+    checkpoint_file = None
+    if Path.is_dir(output_folder):
+        try:
+            checkpoint_file = get_checkpoint_to_resume_training(output_folder)
+        except FileNotFoundError:
+            logger.info(
+                f"Folder {output_folder} already exists, but I did not"
+                f" find any usable checkpoint file to resume training"
+                f" from. Starting from scratch..."
+            )
+
+    seed_everything(seed)
+
+    # reset datamodule with user configurable options
+    datamodule.set_chunk_size(batch_size, batch_chunk_count)
+    datamodule.drop_incomplete_batch = drop_incomplete_batch
+    datamodule.cache_samples = cache_samples
+    datamodule.parallel = parallel
+    datamodule.model_transforms = model.model_transforms
+
+    datamodule.prepare_data()
+    datamodule.setup(stage="fit")
+
+    # If asked, rebalances the loss criterion based on the relative proportion
+    # of class examples available in the training set.  Also affects the
+    # validation loss if a validation set is available on the DataModule.
+    if balance_classes:
+        logger.info("Applying DataModule train sampler balancing...")
+        datamodule.balance_sampler_by_class = True
+        # logger.info("Applying train/valid loss balancing...")
+        # model.balance_losses_by_class(datamodule)
+    else:
+        logger.info(
+            "Skipping sample class/dataset ownership balancing on user request"
+        )
+
+    logger.info(f"Training for at most {epochs} epochs.")
+
+    arguments = {}
+    arguments["max_epoch"] = epochs
+    arguments["epoch"] = 0
+
+    if checkpoint_file is None or not hasattr(model, "on_load_checkpoint"):
+        # Sets the model normalizer with the unaugmented-train-subset if we are
+        # starting from scratch and/or the model does not contain its own
+        # checkpoint loading strategy (e.g. a pytorch stock checkpoint). This
+        # call may be a NOOP, if the model comes from outside this framework,
+        # and expects different weights for the normalisation layer.
+        if hasattr(model, "set_normalizer"):
+            model.set_normalizer(datamodule.unshuffled_train_dataloader())
+        else:
+            logger.warning(
+                f"Model {model.name} has no `set_normalizer` method. "
+                "Skipping normalization setup (unsupported external model)."
+            )
+    else:
+        # Normalizer will be loaded during model.on_load_checkpoint
+        checkpoint = torch.load(checkpoint_file)
+        start_epoch = checkpoint["epoch"]
+        logger.info(
+            f"Resuming from epoch {start_epoch} "
+            f"(checkpoint file: `{str(checkpoint_file)}`)..."
+        )
+
+    device_manager = DeviceManager(device)
+
+    # stores all information we can think of, to reproduce this later
+    json_data: dict[str, typing.Any] = execution_metadata()
+    json_data.update(device_properties(device_manager.device_type))
+    json_data.update(
+        dict(
+            database_name=datamodule.database_name,
+            split_name=datamodule.split_name,
+            epochs=epochs,
+            batch_size=batch_size,
+            batch_chunk_count=batch_chunk_count,
+            drop_incomplete_batch=drop_incomplete_batch,
+            validation_period=validation_period,
+            cache_samples=cache_samples,
+            seed=seed,
+            parallel=parallel,
+            monitoring_interval=monitoring_interval,
+            balance_classes=balance_classes,
+            model_name=model.name,
+        ),
+    )
+    json_data.update(model_summary(model))
+    json_data = {k.replace("_", "-"): v for k, v in json_data.items()}
+    save_json_with_backup(output_folder / "meta.json", json_data)
+
+    run(
+        model=model,
+        datamodule=datamodule,
+        validation_period=validation_period,
+        device_manager=device_manager,
+        max_epochs=epochs,
+        output_folder=output_folder,
+        monitoring_interval=monitoring_interval,
+        batch_chunk_count=batch_chunk_count,
+        checkpoint=checkpoint_file,
+    )
diff --git a/src/mednet/libs/segmentation/scripts/utils.py b/src/mednet/libs/segmentation/scripts/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..890d1c5437d80c71a9a6956fa9fdaf7e259f4259
--- /dev/null
+++ b/src/mednet/libs/segmentation/scripts/utils.py
@@ -0,0 +1,192 @@
+# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
+#
+# SPDX-License-Identifier: GPL-3.0-or-later
+"""Utilities for command-line scripts."""
+
+import json
+import logging
+import pathlib
+import re
+import shutil
+
+import lightning.pytorch
+import lightning.pytorch.callbacks
+import torch.nn
+from mednet.libs.common.engine.device import SupportedPytorchDevice
+
+logger = logging.getLogger("mednet")
+
+
+def model_summary(
+    model: torch.nn.Module,
+) -> dict[str, int | list[tuple[str, str, int]]]:
+    """Save a little summary of the model in a txt file.
+
+    Parameters
+    ----------
+    model
+        Instance of the model for which to save the summary.
+
+    Returns
+    -------
+    tuple[lightning.pytorch.callbacks.ModelSummary, int]
+        A tuple with the model summary in a text format and number of parameters of the model.
+    """
+
+    s = lightning.pytorch.utilities.model_summary.ModelSummary(  # type: ignore
+        model,
+    )
+
+    return dict(
+        model_summary=list(zip(s.layer_names, s.layer_types, s.param_nums)),
+        model_size=s.total_parameters,
+    )
+
+
+def device_properties(
+    device_type: SupportedPytorchDevice,
+) -> dict[str, int | float | str]:
+    """Generate information concerning hardware properties.
+
+    Parameters
+    ----------
+    device_type
+        The type of compute device we are using.
+
+    Returns
+    -------
+        Static properties of the current machine.
+    """
+
+    from mednet.libs.common.utils.resources import (
+        cpu_constants,
+        cuda_constants,
+        mps_constants,
+    )
+
+    retval: dict[str, int | float | str] = {}
+    retval.update(cpu_constants())
+
+    match device_type:
+        case "cpu":
+            pass
+        case "cuda":
+            results = cuda_constants()
+            if results is not None:
+                retval.update(results)
+        case "mps":
+            results = mps_constants()
+            if results is not None:
+                retval.update(results)
+        case _:
+            pass
+
+    return retval
+
+
+def execution_metadata() -> dict[str, int | float | str | dict[str, str]]:
+    """Produce metadata concerning the running script, in the form of a
+    dictionary.
+
+    This function returns potentially useful metadata concerning program
+    execution.  It contains a certain number of preset variables.
+
+    Returns
+    -------
+        A dictionary that contains the following fields:
+
+        * ``package-name``: current package name (e.g. ``mednet``)
+        * ``package-version``: current package version (e.g. ``1.0.0b0``)
+        * ``datetime``: date and time in ISO8601 format (e.g. ``2024-02-23T18:38:09+01:00``)
+        * ``user``: username (e.g. ``johndoe``)
+        * ``conda-env``: if set, the name of the current conda environment
+        * ``path``: current path when executing the command
+        * ``command-line``: the command-line that is being run
+        * ``hostname``: machine hostname (e.g. ``localhost``)
+        * ``platform``: machine platform (e.g. ``darwin``)
+    """
+
+    import importlib.metadata
+    import importlib.util
+    import os
+    import sys
+
+    args = []
+    for k in sys.argv:
+        if " " in k:
+            args.append(f"'{k}'")
+        else:
+            args.append(k)
+
+    # current date time, in ISO8610 format
+    datetime = __import__("datetime").datetime.now().astimezone().isoformat()
+
+    # collects dependence information
+    package_name = __package__.split(".")[0]
+    requires = importlib.metadata.requires(package_name) or []
+    dependence_names = [re.split(r"(\=|~|!|>|<|;|\s)+", k)[0] for k in requires]
+    dependencies = {
+        k: importlib.metadata.version(k)  # version number as str
+        for k in dependence_names
+        if importlib.util.find_spec(k) is not None  # if is installed
+    }
+
+    # checks if the current version corresponds to a dirty (uncommitted) change
+    # set, issues a warning to the user
+    current_version = importlib.metadata.version(package_name)
+    try:
+        import versioningit
+
+        actual_version = versioningit.get_version(".", config={})
+        if current_version != actual_version:
+            logger.warning(
+                f"Version mismatch between current version set "
+                f"({current_version}) and actual version returned by "
+                f"versioningit ({actual_version}).  This typically happens "
+                f"when you commit changes locally and do not re-install the "
+                f"package. Run `pip install -e .` or equivalent to fix this.",
+            )
+    except Exception as e:
+        # not in a git repo?
+        logger.debug(f"Error {e}")
+        pass
+
+    return {
+        "datetime": datetime,
+        "package-name": __package__.split(".")[0],
+        "package-version": current_version,
+        "dependencies": dependencies,
+        "user": __import__("getpass").getuser(),
+        "conda-env": os.environ.get("CONDA_DEFAULT_ENV", ""),
+        "path": os.path.realpath(os.curdir),
+        "command-line": " ".join(args),
+        "hostname": __import__("platform").node(),
+        "platform": sys.platform,
+    }
+
+
+def save_json_with_backup(path: pathlib.Path, data: dict | list) -> None:
+    """Save a dictionary into a JSON file with path checking and backup.
+
+    This function will save a dictionary into a JSON file.  It will check to
+    the existence of the directory leading to the file and create it if
+    necessary.  If the file already exists on the destination folder, it is
+    backed-up before a new file is created with the new contents.
+
+    Parameters
+    ----------
+    path
+        The full path where to save the JSON data.
+    data
+        The data to save on the JSON file.
+    """
+
+    logger.info(f"Writing run metadata at `{path}`...")
+
+    path.parent.mkdir(parents=True, exist_ok=True)
+    if path.exists():
+        backup = path.parent / (path.name + "~")
+        shutil.copy(path, backup)
+
+    with path.open("w") as f:
+        json.dump(data, f, indent=2)
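+
+
+# Usage sketch (illustration only):
+#
+#   save_json_with_backup(pathlib.Path("results/meta.json"), {"epochs": 10})
+#   # creates results/ if needed; a pre-existing meta.json is kept as meta.json~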
diff --git a/src/mednet/libs/segmentation/utils/rc.py b/src/mednet/libs/segmentation/utils/rc.py
new file mode 100644
index 0000000000000000000000000000000000000000..d7b3c48d83a9882f28d9392a36058e35b6405df9
--- /dev/null
+++ b/src/mednet/libs/segmentation/utils/rc.py
@@ -0,0 +1,15 @@
+# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
+#
+# SPDX-License-Identifier: GPL-3.0-or-later
+
+from clapper.rc import UserDefaults
+
+
+def load_rc() -> UserDefaults:
+    """Return global configuration variables.
+
+    Returns
+    -------
+        The user defaults read from the user .toml configuration file.
+    """
+    return UserDefaults("deepdraw.toml")
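+
+
+# Hedged example of a matching user configuration entry (the key layout is an
+# assumption, based on the ``datadir.<dataset>`` lookups performed by the
+# DataModules):
+#
+#   # ~/.config/deepdraw.toml
+#   datadir.drive = "/path/to/DRIVE"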
diff --git a/src/mednet/scripts/cli.py b/src/mednet/scripts/cli.py
index 322b8844ea9345b8ceca181f4aebf3b64b590da2..25b08c05e39ac1117bf731105ab792a572048191 100644
--- a/src/mednet/scripts/cli.py
+++ b/src/mednet/scripts/cli.py
@@ -1,6 +1,7 @@
 import click
 from clapper.click import AliasedGroup
 from mednet.libs.classification.scripts.cli import classification
+from mednet.libs.segmentation.scripts.cli import segmentation
 
 
 @click.group(
@@ -13,3 +14,4 @@ def cli():
 
 
 cli.add_command(classification)
+cli.add_command(segmentation)