diff --git a/src/mednet/libs/classification/models/alexnet.py b/src/mednet/libs/classification/models/alexnet.py
index b400a728d2219f19c8148bff1f39d1de88117833..320e850cd5c50bcf71a0671ef38a25f0d2faac5b 100644
--- a/src/mednet/libs/classification/models/alexnet.py
+++ b/src/mednet/libs/classification/models/alexnet.py
@@ -56,11 +56,11 @@ class Alexnet(ClassificationModel):
 
     def __init__(
         self,
-        loss_type: torch.nn.Module = torch.nn.BCEWithLogitsLoss,
+        loss_type: type[torch.nn.Module] = torch.nn.BCEWithLogitsLoss,
         loss_arguments: dict[str, typing.Any] = {},
         optimizer_type: type[torch.optim.Optimizer] = torch.optim.Adam,
         optimizer_arguments: dict[str, typing.Any] = {},
-        scheduler_type: type[torch.optim.lr_scheduler] = None,
+        scheduler_type: type[torch.optim.lr_scheduler.LRScheduler] | None = None,
         scheduler_arguments: dict[str, typing.Any] = {},
         model_transforms: TransformSequence = [],
         augmentation_transforms: TransformSequence = [],
@@ -68,20 +68,18 @@ class Alexnet(ClassificationModel):
         num_classes: int = 1,
     ):
         super().__init__(
-            loss_type,
-            loss_arguments,
-            optimizer_type,
-            optimizer_arguments,
-            scheduler_type,
-            scheduler_arguments,
-            model_transforms,
-            augmentation_transforms,
-            num_classes,
+            name="alexnet",
+            loss_type=loss_type,
+            loss_arguments=loss_arguments,
+            optimizer_type=optimizer_type,
+            optimizer_arguments=optimizer_arguments,
+            scheduler_type=scheduler_type,
+            scheduler_arguments=scheduler_arguments,
+            model_transforms=model_transforms,
+            augmentation_transforms=augmentation_transforms,
+            num_classes=num_classes,
         )
 
-        self.name = "alexnet"
-        self.num_classes = num_classes
-
         self.pretrained = pretrained
 
         # Load pretrained model
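
To make the corrected annotations above concrete: `loss_type` and `scheduler_type` now take *classes* (instantiated later by the base `Model`), not instances. A minimal sketch, assuming only the constructor shown in this hunk:

```python
import torch

from mednet.libs.classification.models.alexnet import Alexnet

model = Alexnet(
    loss_type=torch.nn.BCEWithLogitsLoss,           # a class, not an instance
    optimizer_type=torch.optim.Adam,
    optimizer_arguments={"lr": 1e-4},
    scheduler_type=torch.optim.lr_scheduler.StepLR,  # an LRScheduler subclass
    scheduler_arguments={"step_size": 10},
)
```
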
diff --git a/src/mednet/libs/classification/models/classification_model.py b/src/mednet/libs/classification/models/classification_model.py
index e2d4b3bebf7899522159e8dd1ae984e809ea58e6..51f6fa18a07b8fa886a2db0eeb0bb692e163cc3e 100644
--- a/src/mednet/libs/classification/models/classification_model.py
+++ b/src/mednet/libs/classification/models/classification_model.py
@@ -20,6 +20,8 @@ class ClassificationModel(Model):
 
     Parameters
     ----------
+    name
+        Common name to give to models of this type.
     loss_type
         The loss to be used for training and evaluation.
 
@@ -49,6 +51,7 @@ class ClassificationModel(Model):
 
     def __init__(
         self,
+        name: str,
         loss_type: type[torch.nn.Module] = torch.nn.BCEWithLogitsLoss,
         loss_arguments: dict[str, typing.Any] = {},
         optimizer_type: type[torch.optim.Optimizer] = torch.optim.Adam,
@@ -60,6 +63,7 @@ class ClassificationModel(Model):
         num_classes: int = 1,
     ):
         super().__init__(
+            name,
             loss_type,
             loss_arguments,
             optimizer_type,
diff --git a/src/mednet/libs/classification/models/cnn3d.py b/src/mednet/libs/classification/models/cnn3d.py
index 9a77092df9113f7083a7db31b191384a71546d19..64bf4893f16c89d6f3e1eae7d1deea5d330ac09c 100644
--- a/src/mednet/libs/classification/models/cnn3d.py
+++ b/src/mednet/libs/classification/models/cnn3d.py
@@ -65,6 +65,7 @@ class Conv3DNet(ClassificationModel):
         num_classes: int = 1,
     ):
         super().__init__(
+            name="cnn3d",
             loss_type=loss_type,
             loss_arguments=loss_arguments,
             optimizer_type=optimizer_type,
diff --git a/src/mednet/libs/classification/models/densenet.py b/src/mednet/libs/classification/models/densenet.py
index e58446d1e04794fa4ad329050176385f21a3e798..e2e92793ad6a101d2fab498484a01d98e11c1aa9 100644
--- a/src/mednet/libs/classification/models/densenet.py
+++ b/src/mednet/libs/classification/models/densenet.py
@@ -56,11 +56,11 @@ class Densenet(ClassificationModel):
 
     def __init__(
         self,
-        loss_type: torch.nn.Module = torch.nn.BCEWithLogitsLoss,
+        loss_type: type[torch.nn.Module] = torch.nn.BCEWithLogitsLoss,
         loss_arguments: dict[str, typing.Any] = {},
         optimizer_type: type[torch.optim.Optimizer] = torch.optim.Adam,
         optimizer_arguments: dict[str, typing.Any] = {},
-        scheduler_type: type[torch.optim.lr_scheduler] = None,
+        scheduler_type: type[torch.optim.lr_scheduler.LRScheduler] | None = None,
         scheduler_arguments: dict[str, typing.Any] = {},
         model_transforms: TransformSequence = [],
         augmentation_transforms: TransformSequence = [],
@@ -69,20 +69,18 @@ class Densenet(ClassificationModel):
         num_classes: int = 1,
     ):
         super().__init__(
-            loss_type,
-            loss_arguments,
-            optimizer_type,
-            optimizer_arguments,
-            scheduler_type,
-            scheduler_arguments,
-            model_transforms,
-            augmentation_transforms,
-            num_classes,
+            name="densenet-121",
+            loss_type=loss_type,
+            loss_arguments=loss_arguments,
+            optimizer_type=optimizer_type,
+            optimizer_arguments=optimizer_arguments,
+            scheduler_type=scheduler_type,
+            scheduler_arguments=scheduler_arguments,
+            model_transforms=model_transforms,
+            augmentation_transforms=augmentation_transforms,
+            num_classes=num_classes,
         )
 
-        self.name = "densenet-121"
-        self.num_classes = num_classes
-
         self.pretrained = pretrained
 
         # Load pretrained model
diff --git a/src/mednet/libs/classification/models/logistic_regression.py b/src/mednet/libs/classification/models/logistic_regression.py
index f9fe1847be16de65dc3dfbd88223b89db6eeec97..d95bce0698be0b51ce86346a0e69f7426089c590 100644
--- a/src/mednet/libs/classification/models/logistic_regression.py
+++ b/src/mednet/libs/classification/models/logistic_regression.py
@@ -4,109 +4,59 @@
 
 import typing
 
-import lightning.pytorch as pl
 import torch
-import torch.nn as nn
+import torch.nn
 
+from .classification_model import ClassificationModel
 
-class LogisticRegression(pl.LightningModule):
+
+class LogisticRegression(ClassificationModel):
     """Logistic regression classifier with a single output.
 
     Parameters
     ----------
-    train_loss
-        The loss to be used during the training.
-
-        .. warning::
-
-           The loss should be set to always return batch averages (as opposed
-           to the batch sum), as our logging system expects it so.
-    validation_loss
-        The loss to be used for validation (may be different from the training
-        loss).  If extra-validation sets are provided, the same loss will be
-        used throughout.
+    loss_type
+        The loss to be used for training and evaluation.
 
         .. warning::
 
            The loss should be set to always return batch averages (as opposed
            to the batch sum), as our logging system expects.
+    loss_arguments
+        Arguments to the loss.
     optimizer_type
         The type of optimizer to use for training.
     optimizer_arguments
         Arguments to the optimizer after ``params``.
+    num_classes
+        Number of outputs (classes) for this model.
     input_size
        The number of inputs this classifier shall process.
     """
 
     def __init__(
         self,
-        train_loss: torch.nn.Module = torch.nn.BCEWithLogitsLoss(),
-        validation_loss: torch.nn.Module | None = None,
+        loss_type: type[torch.nn.Module] = torch.nn.BCEWithLogitsLoss,
+        loss_arguments: dict[str, typing.Any] = {},
         optimizer_type: type[torch.optim.Optimizer] = torch.optim.Adam,
         optimizer_arguments: dict[str, typing.Any] = {"lr": 1e-2},
+        num_classes: int = 1,
         input_size: int = 14,
     ):
-        super().__init__()
-
-        self._train_loss = train_loss.to(self.device)
-        self._validation_loss = (
-            validation_loss if validation_loss is not None else train_loss
-        ).to(self.device)
-        self._optimizer_type = optimizer_type
-        self._optimizer_arguments = optimizer_arguments
-
-        self.name = "logistic-regression"
+        super().__init__(
+            name="logistic-regression",
+            loss_type=loss_type,
+            loss_arguments=loss_arguments,
+            optimizer_type=optimizer_type,
+            optimizer_arguments=optimizer_arguments,
+            scheduler_type=None,
+            scheduler_arguments={},
+            model_transforms=[],
+            augmentation_transforms=[],
+            num_classes=num_classes,
+        )
 
-        self.linear = nn.Linear(input_size, 1)
+        self.linear = torch.nn.Linear(input_size, self.num_classes)
 
     def forward(self, x):
         return self.linear(self.normalizer(x))
-
-    def training_step(self, batch, batch_idx):
-        _input = batch[1]
-        labels = batch[2]
-
-        # Increase label dimension if too low
-        # Allows single and multiclass usage
-        if labels.ndim == 1:
-            labels = torch.reshape(labels, (labels.shape[0], 1))
-
-        # Forward pass on the network
-        output = self(_input)
-
-        # Manually move criterion to selected device, since not part of the model.
-        self._train_loss = self._train_loss.to(self.device)
-        training_loss = self._train_loss(output, labels.float())
-
-        return {"loss": training_loss}
-
-    def validation_step(self, batch, batch_idx, dataloader_idx=0):
-        _input = batch[1]
-        labels = batch[2]
-
-        # Increase label dimension if too low
-        # Allows single and multiclass usage
-        if labels.ndim == 1:
-            labels = torch.reshape(labels, (labels.shape[0], 1))
-
-        # data forwarding on the existing network
-        output = self(_input)
-
-        # Manually move criterion to selected device, since not part of the model.
-        self._validation_loss = self._validation_loss.to(self.device)
-        validation_loss = self._validation_loss(output, labels.float())
-
-        if dataloader_idx == 0:
-            return {"validation_loss": validation_loss}
-
-        return {f"extra_validation_loss_{dataloader_idx}": validation_loss}
-
-    def predict_step(self, batch, batch_idx, dataloader_idx=0):
-        outputs = self(batch[0])
-        return torch.sigmoid(outputs)
-
-    def configure_optimizers(self):
-        return self._optimizer_type(
-            self.parameters(),
-            **self._optimizer_arguments,
-        )
diff --git a/src/mednet/libs/classification/models/mlp.py b/src/mednet/libs/classification/models/mlp.py
index bd928410a3133141a0048de8537df4bf65cd8e49..b35ca1036607961a1cff3b4d68f170e58819833f 100644
--- a/src/mednet/libs/classification/models/mlp.py
+++ b/src/mednet/libs/classification/models/mlp.py
@@ -4,35 +4,32 @@
 
 import typing
 
-import lightning.pytorch as pl
 import torch
+import torch.nn
 
+from .classification_model import ClassificationModel
 
-class MultiLayerPerceptron(pl.LightningModule):
+
+class MultiLayerPerceptron(ClassificationModel):
     """MLP with a variable number of inputs and hidden neurons (single layer).
 
     Parameters
     ----------
-    train_loss
-        The loss to be used during the training.
-
-        .. warning::
-
-           The loss should be set to always return batch averages (as opposed
-           to the batch sum), as our logging system expects it so.
-    validation_loss
-        The loss to be used for validation (may be different from the training
-        loss).  If extra-validation sets are provided, the same loss will be
-        used throughout.
+    loss_type
+        The loss to be used for training and evaluation.
 
         .. warning::
 
            The loss should be set to always return batch averages (as opposed
            to the batch sum), as our logging system expects.
+    loss_arguments
+        Arguments to the loss.
     optimizer_type
         The type of optimizer to use for training.
     optimizer_arguments
         Arguments to the optimizer after ``params``.
+    num_classes
+        Number of outputs (classes) for this model.
     input_size
        The number of inputs this classifier shall process.
     hidden_size
@@ -41,76 +38,30 @@ class MultiLayerPerceptron(pl.LightningModule):
 
     def __init__(
         self,
-        train_loss: torch.nn.Module = torch.nn.BCEWithLogitsLoss(),
-        validation_loss: torch.nn.Module | None = None,
+        loss_type: type[torch.nn.Module] = torch.nn.BCEWithLogitsLoss,
+        loss_arguments: dict[str, typing.Any] = {},
         optimizer_type: type[torch.optim.Optimizer] = torch.optim.Adam,
         optimizer_arguments: dict[str, typing.Any] = {"lr": 1e-2},
+        num_classes: int = 1,
         input_size: int = 14,
         hidden_size: int = 10,
     ):
-        super().__init__()
-
-        self._train_loss = train_loss.to(self.device)
-        self._validation_loss = (
-            validation_loss if validation_loss is not None else train_loss
-        ).to(self.device)
-        self._optimizer_type = optimizer_type
-        self._optimizer_arguments = optimizer_arguments
-
-        self.name = "mlp"
+        super().__init__(
+            name="mlp",
+            loss_type=loss_type,
+            loss_arguments=loss_arguments,
+            optimizer_type=optimizer_type,
+            optimizer_arguments=optimizer_arguments,
+            scheduler_type=None,
+            scheduler_arguments={},
+            model_transforms=[],
+            augmentation_transforms=[],
+            num_classes=num_classes,
+        )
 
         self.fc1 = torch.nn.Linear(input_size, hidden_size)
         self.relu = torch.nn.ReLU()
-        self.fc2 = torch.nn.Linear(hidden_size, 1)
+        self.fc2 = torch.nn.Linear(hidden_size, self.num_classes)
 
     def forward(self, x):
         return self.fc2(self.relu(self.fc1(self.normalizer(x))))
-
-    def training_step(self, batch, batch_idx):
-        _input = batch[1]
-        labels = batch[2]
-
-        # Increase label dimension if too low
-        # Allows single and multiclass usage
-        if labels.ndim == 1:
-            labels = torch.reshape(labels, (labels.shape[0], 1))
-
-        # Forward pass on the network
-        output = self(_input)
-
-        # Manually move criterion to selected device, since not part of the model.
-        self._train_loss = self._train_loss.to(self.device)
-        training_loss = self._train_loss(output, labels.float())
-
-        return {"loss": training_loss}
-
-    def validation_step(self, batch, batch_idx, dataloader_idx=0):
-        _input = batch[1]
-        labels = batch[2]
-
-        # Increase label dimension if too low
-        # Allows single and multiclass usage
-        if labels.ndim == 1:
-            labels = torch.reshape(labels, (labels.shape[0], 1))
-
-        # data forwarding on the existing network
-        output = self(_input)
-
-        # Manually move criterion to selected device, since not part of the model.
-        self._validation_loss = self._validation_loss.to(self.device)
-        validation_loss = self._validation_loss(output, labels.float())
-
-        if dataloader_idx == 0:
-            return {"validation_loss": validation_loss}
-
-        return {f"extra_validation_loss_{dataloader_idx}": validation_loss}
-
-    def predict_step(self, batch, batch_idx, dataloader_idx=0):
-        outputs = self(batch[0])
-        return torch.sigmoid(outputs)
-
-    def configure_optimizers(self):
-        return self._optimizer_type(
-            self.parameters(),
-            **self._optimizer_arguments,
-        )
diff --git a/src/mednet/libs/classification/models/pasa.py b/src/mednet/libs/classification/models/pasa.py
index 8eea2c38e3bff899ac9686b6aba695df04c687b9..6f0c7c75763034dd462c749c9b958d790fcc4d2f 100644
--- a/src/mednet/libs/classification/models/pasa.py
+++ b/src/mednet/libs/classification/models/pasa.py
@@ -56,31 +56,29 @@ class Pasa(ClassificationModel):
 
     def __init__(
         self,
-        loss_type: torch.nn.Module = torch.nn.BCEWithLogitsLoss,
+        loss_type: type[torch.nn.Module] = torch.nn.BCEWithLogitsLoss,
         loss_arguments: dict[str, typing.Any] = {},
         optimizer_type: type[torch.optim.Optimizer] = torch.optim.Adam,
         optimizer_arguments: dict[str, typing.Any] = {},
-        scheduler_type: type[torch.optim.lr_scheduler] = None,
+        scheduler_type: type[torch.optim.lr_scheduler.LRScheduler] | None = None,
         scheduler_arguments: dict[str, typing.Any] = {},
         model_transforms: TransformSequence = [],
         augmentation_transforms: TransformSequence = [],
         num_classes: int = 1,
     ):
         super().__init__(
-            loss_type,
-            loss_arguments,
-            optimizer_type,
-            optimizer_arguments,
-            scheduler_type,
-            scheduler_arguments,
-            model_transforms,
-            augmentation_transforms,
-            num_classes,
+            name="pasa",
+            loss_type=loss_type,
+            loss_arguments=loss_arguments,
+            optimizer_type=optimizer_type,
+            optimizer_arguments=optimizer_arguments,
+            scheduler_type=scheduler_type,
+            scheduler_arguments=scheduler_arguments,
+            model_transforms=model_transforms,
+            augmentation_transforms=augmentation_transforms,
+            num_classes=num_classes,
         )
 
-        self.name = "pasa"
-        self.num_classes = num_classes
-
         # First convolution block
         self.fc1 = torch.nn.Conv2d(1, 4, (3, 3), (2, 2), (1, 1))
         self.fc2 = torch.nn.Conv2d(4, 16, (3, 3), (2, 2), (1, 1))
diff --git a/src/mednet/libs/common/engine/trainer.py b/src/mednet/libs/common/engine/trainer.py
index 5d3f2a0de703aa97faa055b2c269d1e3b816eba1..63d916bc6043f76ac27f52b99309953c463d0fe1 100644
--- a/src/mednet/libs/common/engine/trainer.py
+++ b/src/mednet/libs/common/engine/trainer.py
@@ -73,8 +73,6 @@ def run(
 
     output_folder.mkdir(parents=True, exist_ok=True)
 
-    model.configure_losses()
-
     from .loggers import CustomTensorboardLogger
 
     log_dir = "logs"
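
The explicit `configure_losses()` call disappears because losses are now instantiated inside `Model.__init__` (see the model.py hunk below) and re-instantiated by the balancing hook when class weights change. A sketch of the resulting ordering, under two assumptions: `balance_losses` is taken to be the enclosing method of that hunk, and the trainer entry point keeps the `run(...)` signature from this file:

```python
from mednet.libs.classification.models.pasa import Pasa

model = Pasa()  # _configure_losses() already ran inside __init__

# Optionally, before training (recomputes pos_weight, then rebuilds both
# losses by calling _configure_losses() again):
# model.balance_losses(datamodule)

# run(model=model, datamodule=datamodule, ...)  # no configure_losses() call
```
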
diff --git a/src/mednet/libs/common/models/model.py b/src/mednet/libs/common/models/model.py
index 669ba6c18b35faf30876780c14fa30f6368846ef..cad520975ce784f6183c30b427c0db08ebfaeaab 100644
--- a/src/mednet/libs/common/models/model.py
+++ b/src/mednet/libs/common/models/model.py
@@ -2,6 +2,7 @@
 #
 # SPDX-License-Identifier: GPL-3.0-or-later
 
+import copy
 import logging
 import typing
 
@@ -25,6 +26,8 @@ class Model(pl.LightningModule):
 
     Parameters
     ----------
+    name
+        Common name to give to models of this type.
     loss_type
         The loss to be used for training and evaluation.
 
@@ -54,6 +57,7 @@ class Model(pl.LightningModule):
 
     def __init__(
         self,
+        name: str,
         loss_type: type[torch.nn.Module] = torch.nn.BCEWithLogitsLoss,
         loss_arguments: dict[str, typing.Any] = {},
         optimizer_type: type[torch.optim.Optimizer] = torch.optim.Adam,
@@ -66,25 +70,21 @@ class Model(pl.LightningModule):
     ):
         super().__init__()
 
-        self.name = "model"
+        self.name = name
         self.num_classes = num_classes
-
         self.model_transforms = model_transforms
-
         self._loss_type = loss_type
-
-        self._train_loss_arguments = loss_arguments
-
-        self._validation_loss_arguments = loss_arguments
-
+        self._train_loss_arguments = copy.deepcopy(loss_arguments)
+        self._validation_loss_arguments = copy.deepcopy(loss_arguments)
         self._optimizer_type = optimizer_type
         self._optimizer_arguments = optimizer_arguments
-
         self._scheduler_type = scheduler_type
         self._scheduler_arguments = scheduler_arguments
-
         self.augmentation_transforms = augmentation_transforms
 
+        # initializes losses from input arguments
+        self._configure_losses()
+
     @property
     def augmentation_transforms(self) -> torchvision.transforms.Compose:
         return self._augmentation_transforms
@@ -149,8 +149,14 @@ class Model(pl.LightningModule):
     def predict_step(self, batch, batch_idx, dataloader_idx=0):
         raise NotImplementedError
 
-    def configure_losses(self):
+    def _configure_losses(self):
+        """Create loss objects for train and validation."""
+
+        logger.info(f"Configuring train loss ({self._train_loss_arguments})...")
         self._train_loss = self._loss_type(**self._train_loss_arguments)
+        logger.info(
+            f"Configuring validation loss ({self._validation_loss_arguments})..."
+        )
         self._validation_loss = self._loss_type(**self._validation_loss_arguments)
 
     def configure_optimizers(self):
@@ -195,10 +201,8 @@ class Model(pl.LightningModule):
             ]
         )
 
-        if self._train_loss is not None:
-            self._train_loss.to(*args, **kwargs)
-        if self._validation_loss is not None:
-            self._validation_loss.to(*args, **kwargs)
+        self._train_loss.to(*args, **kwargs)
+        self._validation_loss.to(*args, **kwargs)
 
         return self
 
@@ -213,23 +217,36 @@ class Model(pl.LightningModule):
 
         try:
             getattr(self._loss_type(), "pos_weight")
+
         except AttributeError:
             logger.warning(
                 f"Loss {self._loss_type} does not posess a 'pos_weight' attribute and will not be balanced."
             )
+
         else:
-            logger.info(f"Balancing training loss {self._loss_type}.")
             train_weights = get_positive_weights(datamodule.train_dataloader())
             self._train_loss_arguments["pos_weight"] = train_weights
+            logger.info(
+                f"Balanced training loss {self._loss_type}: "
+                f"`pos_weight={train_weights.item():.3f}`."
+            )
 
-            logger.info(f"Balancing validation loss {self._loss_type}.")
             if "validation" in datamodule.val_dataloader().keys():
                 validation_weights = get_positive_weights(
                     datamodule.val_dataloader()["validation"]
                 )
             else:
                 logger.warning(
-                    "Datamodule does not contain a validation dataloader. The training dataloader will be used instead."
+                    "Datamodule does not contain a validation dataloader. "
+                    "The training dataloader will be used instead."
                 )
                 validation_weights = get_positive_weights(datamodule.train_dataloader())
+
             self._validation_loss_arguments["pos_weight"] = validation_weights
+            logger.info(
+                f"Balanced validation loss {self._loss_type}: "
+                f"`pos_weight={validation_weights.item():.3f}`."
+            )
+
+        # re-instantiates losses for train and validation
+        self._configure_losses()
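
In isolation, the store-arguments-then-instantiate pattern that `_configure_losses` and the balancing code implement reduces to the following self-contained illustration (simplified, not mednet code):

```python
import torch

loss_type = torch.nn.BCEWithLogitsLoss
train_loss_arguments: dict = {}

# The balancing step injects a positive-class weight into the stored
# arguments; the loss is then (re-)instantiated so the weight takes effect.
train_loss_arguments["pos_weight"] = torch.tensor([2.5])
train_loss = loss_type(**train_loss_arguments)

logits = torch.tensor([[0.3]])
targets = torch.tensor([[1.0]])
print(train_loss(logits, targets))  # positive samples now weigh 2.5x
```
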
diff --git a/src/mednet/libs/segmentation/models/driu.py b/src/mednet/libs/segmentation/models/driu.py
index 295713cb5ec4c22ae4ab391a88d21af41cf6d592..40d097096179005e3cd2608ee2b6d52197ced218 100644
--- a/src/mednet/libs/segmentation/models/driu.py
+++ b/src/mednet/libs/segmentation/models/driu.py
@@ -8,6 +8,7 @@ import typing
 
 import torch
 import torch.nn
+import torch.utils.data
 from mednet.libs.common.data.typing import TransformSequence
 
 from .backbones.vgg import vgg16_for_segmentation
@@ -43,7 +44,7 @@ class DRIUHead(torch.nn.Module):
         Number of channels for each feature map that is returned from backbone.
     """
 
-    def __init__(self, in_channels_list=None):
+    def __init__(self, in_channels_list):
         super().__init__()
         (
             in_conv_1_2_16,
@@ -106,11 +107,11 @@ class DRIU(SegmentationModel):
 
     def __init__(
         self,
-        loss_type: torch.nn.Module = SoftJaccardBCELogitsLoss,
+        loss_type: type[torch.nn.Module] = SoftJaccardBCELogitsLoss,
         loss_arguments: dict[str, typing.Any] = {},
         optimizer_type: type[torch.optim.Optimizer] = torch.optim.Adam,
         optimizer_arguments: dict[str, typing.Any] = {},
-        scheduler_type: type[torch.optim.lr_scheduler] = None,
+        scheduler_type: type[torch.optim.lr_scheduler.LRScheduler] | None = None,
         scheduler_arguments: dict[str, typing.Any] = {},
         model_transforms: TransformSequence = [],
         augmentation_transforms: TransformSequence = [],
@@ -118,17 +119,17 @@ class DRIU(SegmentationModel):
         pretrained: bool = False,
     ):
         super().__init__(
-            loss_type,
-            loss_arguments,
-            optimizer_type,
-            optimizer_arguments,
-            scheduler_type,
-            scheduler_arguments,
-            model_transforms,
-            augmentation_transforms,
-            num_classes,
+            name="driu",
+            loss_type=loss_type,
+            loss_arguments=loss_arguments,
+            optimizer_type=optimizer_type,
+            optimizer_arguments=optimizer_arguments,
+            scheduler_type=scheduler_type,
+            scheduler_arguments=scheduler_arguments,
+            model_transforms=model_transforms,
+            augmentation_transforms=augmentation_transforms,
+            num_classes=num_classes,
         )
-        self.name = "driu"
 
         self.pretrained = pretrained
 
diff --git a/src/mednet/libs/segmentation/models/driu_bn.py b/src/mednet/libs/segmentation/models/driu_bn.py
index bacbf2913f02b54f5b00b31e35600416e6a59e9b..9e32925e1db8bf824dd7359b1c2ceac2ac8ea5a8 100644
--- a/src/mednet/libs/segmentation/models/driu_bn.py
+++ b/src/mednet/libs/segmentation/models/driu_bn.py
@@ -8,6 +8,7 @@ import typing
 
 import torch
 import torch.nn
+import torch.utils.data
 from mednet.libs.common.data.typing import TransformSequence
 
 from .backbones.vgg import vgg16_for_segmentation
@@ -46,7 +47,7 @@ class DRIUBNHead(torch.nn.Module):
         Number of channels for each feature map that is returned from backbone.
     """
 
-    def __init__(self, in_channels_list=None):
+    def __init__(self, in_channels_list):
         super().__init__()
         (
             in_conv_1_2_16,
@@ -109,11 +110,11 @@ class DRIUBN(SegmentationModel):
 
     def __init__(
         self,
-        loss_type: torch.nn.Module = SoftJaccardBCELogitsLoss,
+        loss_type: type[torch.nn.Module] = SoftJaccardBCELogitsLoss,
         loss_arguments: dict[str, typing.Any] = {},
         optimizer_type: type[torch.optim.Optimizer] = torch.optim.Adam,
         optimizer_arguments: dict[str, typing.Any] = {},
-        scheduler_type: type[torch.optim.lr_scheduler] = None,
+        scheduler_type: type[torch.optim.lr_scheduler.LRScheduler] | None = None,
         scheduler_arguments: dict[str, typing.Any] = {},
         model_transforms: TransformSequence = [],
         augmentation_transforms: TransformSequence = [],
@@ -121,17 +122,17 @@ class DRIUBN(SegmentationModel):
         pretrained: bool = False,
     ):
         super().__init__(
-            loss_type,
-            loss_arguments,
-            optimizer_type,
-            optimizer_arguments,
-            scheduler_type,
-            scheduler_arguments,
-            model_transforms,
-            augmentation_transforms,
-            num_classes,
+            name="driu-bn",
+            loss_type=loss_type,
+            loss_arguments=loss_arguments,
+            optimizer_type=optimizer_type,
+            optimizer_arguments=optimizer_arguments,
+            scheduler_type=scheduler_type,
+            scheduler_arguments=scheduler_arguments,
+            model_transforms=model_transforms,
+            augmentation_transforms=augmentation_transforms,
+            num_classes=num_classes,
         )
-        self.name = "driu-bn"
 
         self.pretrained = pretrained
 
diff --git a/src/mednet/libs/segmentation/models/driu_od.py b/src/mednet/libs/segmentation/models/driu_od.py
index af4727c556b112020739a37337f724afe3ed1b24..2336869e5f5e881affecb7ecb981224831fe0447 100644
--- a/src/mednet/libs/segmentation/models/driu_od.py
+++ b/src/mednet/libs/segmentation/models/driu_od.py
@@ -28,7 +28,7 @@ class DRIUODHead(torch.nn.Module):
         Number of channels for each feature map that is returned from backbone.
     """
 
-    def __init__(self, in_channels_list=None):
+    def __init__(self, in_channels_list):
         super().__init__()
         (
             in_upsample2,
@@ -91,11 +91,11 @@ class DRIUOD(SegmentationModel):
 
     def __init__(
         self,
-        loss_type: torch.nn.Module = SoftJaccardBCELogitsLoss,
+        loss_type: type[torch.nn.Module] = SoftJaccardBCELogitsLoss,
         loss_arguments: dict[str, typing.Any] = {},
         optimizer_type: type[torch.optim.Optimizer] = torch.optim.Adam,
         optimizer_arguments: dict[str, typing.Any] = {},
-        scheduler_type: type[torch.optim.lr_scheduler] = None,
+        scheduler_type: type[torch.optim.lr_scheduler.LRScheduler] | None = None,
         scheduler_arguments: dict[str, typing.Any] = {},
         model_transforms: TransformSequence = [],
         augmentation_transforms: TransformSequence = [],
@@ -103,17 +103,17 @@ class DRIUOD(SegmentationModel):
         pretrained: bool = False,
     ):
         super().__init__(
-            loss_type,
-            loss_arguments,
-            optimizer_type,
-            optimizer_arguments,
-            scheduler_type,
-            scheduler_arguments,
-            model_transforms,
-            augmentation_transforms,
-            num_classes,
+            name="driu-od",
+            loss_type=loss_type,
+            loss_arguments=loss_arguments,
+            optimizer_type=optimizer_type,
+            optimizer_arguments=optimizer_arguments,
+            scheduler_type=scheduler_type,
+            scheduler_arguments=scheduler_arguments,
+            model_transforms=model_transforms,
+            augmentation_transforms=augmentation_transforms,
+            num_classes=num_classes,
         )
-        self.name = "driu-od"
 
         self.pretrained = pretrained
 
diff --git a/src/mednet/libs/segmentation/models/driu_pix.py b/src/mednet/libs/segmentation/models/driu_pix.py
index e9800984a884106e36a2aa96238bad48b784dc49..42436fcc1d5eb72c7b5d7b7511561207a5a3c68f 100644
--- a/src/mednet/libs/segmentation/models/driu_pix.py
+++ b/src/mednet/libs/segmentation/models/driu_pix.py
@@ -8,6 +8,7 @@ import typing
 
 import torch
 import torch.nn
+import torch.utils.data
 from mednet.libs.common.data.typing import TransformSequence
 
 from .backbones.vgg import vgg16_for_segmentation
@@ -28,7 +29,7 @@ class DRIUPIXHead(torch.nn.Module):
         Number of channels for each feature map that is returned from backbone.
     """
 
-    def __init__(self, in_channels_list=None):
+    def __init__(self, in_channels_list):
         super().__init__()
         (
             in_conv_1_2_16,
@@ -95,11 +96,11 @@ class DRIUPix(SegmentationModel):
 
     def __init__(
         self,
-        loss_type: torch.nn.Module = SoftJaccardBCELogitsLoss,
+        loss_type: type[torch.nn.Module] = SoftJaccardBCELogitsLoss,
         loss_arguments: dict[str, typing.Any] = {},
         optimizer_type: type[torch.optim.Optimizer] = torch.optim.Adam,
         optimizer_arguments: dict[str, typing.Any] = {},
-        scheduler_type: type[torch.optim.lr_scheduler] = None,
+        scheduler_type: type[torch.optim.lr_scheduler.LRScheduler] | None = None,
         scheduler_arguments: dict[str, typing.Any] = {},
         model_transforms: TransformSequence = [],
         augmentation_transforms: TransformSequence = [],
@@ -107,17 +108,17 @@ class DRIUPix(SegmentationModel):
         pretrained: bool = False,
     ):
         super().__init__(
-            loss_type,
-            loss_arguments,
-            optimizer_type,
-            optimizer_arguments,
-            scheduler_type,
-            scheduler_arguments,
-            model_transforms,
-            augmentation_transforms,
-            num_classes,
+            name="driu-pix",
+            loss_type=loss_type,
+            loss_arguments=loss_arguments,
+            optimizer_type=optimizer_type,
+            optimizer_arguments=optimizer_arguments,
+            scheduler_type=scheduler_type,
+            scheduler_arguments=scheduler_arguments,
+            model_transforms=model_transforms,
+            augmentation_transforms=augmentation_transforms,
+            num_classes=num_classes,
         )
-        self.name = "driu-pix"
 
         self.pretrained = pretrained
 
diff --git a/src/mednet/libs/segmentation/models/hed.py b/src/mednet/libs/segmentation/models/hed.py
index 12a7c591aec5d75fb971b9265a27ade8cf51806f..cd99707da2194659b2f5c5621bac4ea05015b970 100644
--- a/src/mednet/libs/segmentation/models/hed.py
+++ b/src/mednet/libs/segmentation/models/hed.py
@@ -7,6 +7,7 @@ import typing
 
 import torch
 import torch.nn
+import torch.utils.data
 from mednet.libs.common.data.typing import TransformSequence
 
 from .backbones.vgg import vgg16_for_segmentation
@@ -40,7 +41,7 @@ class HEDHead(torch.nn.Module):
         Number of channels for each feature map that is returned from backbone.
     """
 
-    def __init__(self, in_channels_list=None):
+    def __init__(self, in_channels_list):
         super().__init__()
         (
             in_conv_1_2_16,
@@ -109,11 +110,11 @@ class HED(SegmentationModel):
 
     def __init__(
         self,
-        loss_type: torch.nn.Module = MultiSoftJaccardBCELogitsLoss,
+        loss_type: type[torch.nn.Module] = MultiSoftJaccardBCELogitsLoss,
         loss_arguments: dict[str, typing.Any] = {},
         optimizer_type: type[torch.optim.Optimizer] = torch.optim.Adam,
         optimizer_arguments: dict[str, typing.Any] = {},
-        scheduler_type: type[torch.optim.lr_scheduler] = None,
+        scheduler_type: type[torch.optim.lr_scheduler.LRScheduler] | None = None,
         scheduler_arguments: dict[str, typing.Any] = {},
         model_transforms: TransformSequence = [],
         augmentation_transforms: TransformSequence = [],
@@ -121,19 +122,18 @@ class HED(SegmentationModel):
         pretrained: bool = False,
     ):
         super().__init__(
-            loss_type,
-            loss_arguments,
-            optimizer_type,
-            optimizer_arguments,
-            scheduler_type,
-            scheduler_arguments,
-            model_transforms,
-            augmentation_transforms,
-            num_classes,
+            name="hed",
+            loss_type=loss_type,
+            loss_arguments=loss_arguments,
+            optimizer_type=optimizer_type,
+            optimizer_arguments=optimizer_arguments,
+            scheduler_type=scheduler_type,
+            scheduler_arguments=scheduler_arguments,
+            model_transforms=model_transforms,
+            augmentation_transforms=augmentation_transforms,
+            num_classes=num_classes,
         )
 
-        self.name = "hed"
-
         self.pretrained = pretrained
 
         self.backbone = vgg16_for_segmentation(
diff --git a/src/mednet/libs/segmentation/models/lwnet.py b/src/mednet/libs/segmentation/models/lwnet.py
index 9442c2c92c1e1d1f32d2e02406a9101e2688796a..5cd491b04d392bbf8579bbca3525e4ea6f46bd45 100644
--- a/src/mednet/libs/segmentation/models/lwnet.py
+++ b/src/mednet/libs/segmentation/models/lwnet.py
@@ -310,31 +310,29 @@ class LittleWNet(SegmentationModel):
 
     def __init__(
         self,
-        loss_type: torch.nn.Module = MultiWeightedBCELogitsLoss,
+        loss_type: type[torch.nn.Module] = MultiWeightedBCELogitsLoss,
         loss_arguments: dict[str, typing.Any] = {},
         optimizer_type: type[torch.optim.Optimizer] = torch.optim.Adam,
         optimizer_arguments: dict[str, typing.Any] = {},
-        scheduler_type: type[torch.optim.lr_scheduler] = None,
+        scheduler_type: type[torch.optim.lr_scheduler.LRScheduler] | None = None,
         scheduler_arguments: dict[str, typing.Any] = {},
         model_transforms: TransformSequence = [],
         augmentation_transforms: TransformSequence = [],
         num_classes: int = 1,
     ):
         super().__init__(
-            loss_type,
-            loss_arguments,
-            optimizer_type,
-            optimizer_arguments,
-            scheduler_type,
-            scheduler_arguments,
-            model_transforms,
-            augmentation_transforms,
-            num_classes,
+            name="lwnet",
+            loss_type=loss_type,
+            loss_arguments=loss_arguments,
+            optimizer_type=optimizer_type,
+            optimizer_arguments=optimizer_arguments,
+            scheduler_type=scheduler_type,
+            scheduler_arguments=scheduler_arguments,
+            model_transforms=model_transforms,
+            augmentation_transforms=augmentation_transforms,
+            num_classes=num_classes,
         )
 
-        self.name = "lwnet"
-        self.num_classes = num_classes
-
         self.unet1 = LittleUNet(
             in_c=3,
             n_classes=self.num_classes,
diff --git a/src/mednet/libs/segmentation/models/m2unet.py b/src/mednet/libs/segmentation/models/m2unet.py
index fcca7d6072a88ad6bc6bccd4932eef2a1b0a153f..8735956595b3737b7320a7f76ac26db37fe35c0b 100644
--- a/src/mednet/libs/segmentation/models/m2unet.py
+++ b/src/mednet/libs/segmentation/models/m2unet.py
@@ -7,6 +7,7 @@ import typing
 
 import torch
 import torch.nn
+import torch.utils.data
 from mednet.libs.common.data.typing import TransformSequence
 from mednet.libs.segmentation.models.losses import SoftJaccardBCELogitsLoss
 from torchvision.models.mobilenetv2 import InvertedResidual
@@ -157,11 +158,11 @@ class M2UNET(SegmentationModel):
 
     def __init__(
         self,
-        loss_type: torch.nn.Module = SoftJaccardBCELogitsLoss,
+        loss_type: type[torch.nn.Module] = SoftJaccardBCELogitsLoss,
         loss_arguments: dict[str, typing.Any] = {},
         optimizer_type: type[torch.optim.Optimizer] = torch.optim.Adam,
         optimizer_arguments: dict[str, typing.Any] = {},
-        scheduler_type: type[torch.optim.lr_scheduler] = None,
+        scheduler_type: type[torch.optim.lr_scheduler.LRScheduler] | None = None,
         scheduler_arguments: dict[str, typing.Any] = {},
         model_transforms: TransformSequence = [],
         augmentation_transforms: TransformSequence = [],
@@ -169,19 +170,18 @@ class M2UNET(SegmentationModel):
         pretrained: bool = False,
     ):
         super().__init__(
-            loss_type,
-            loss_arguments,
-            optimizer_type,
-            optimizer_arguments,
-            scheduler_type,
-            scheduler_arguments,
-            model_transforms,
-            augmentation_transforms,
-            num_classes,
+            name="m2unet",
+            loss_type=loss_type,
+            loss_arguments=loss_arguments,
+            optimizer_type=optimizer_type,
+            optimizer_arguments=optimizer_arguments,
+            scheduler_type=scheduler_type,
+            scheduler_arguments=scheduler_arguments,
+            model_transforms=model_transforms,
+            augmentation_transforms=augmentation_transforms,
+            num_classes=num_classes,
         )
 
-        self.name = "m2unet"
-
         self.pretrained = pretrained
 
         self.backbone = mobilenet_v2_for_segmentation(
diff --git a/src/mednet/libs/segmentation/models/segmentation_model.py b/src/mednet/libs/segmentation/models/segmentation_model.py
index 8ffb75ac7fed41340991deef8636acaa9f2fcbae..dae5cde663ad0aebef3302b8a68210efc396beac 100644
--- a/src/mednet/libs/segmentation/models/segmentation_model.py
+++ b/src/mednet/libs/segmentation/models/segmentation_model.py
@@ -21,6 +21,8 @@ class SegmentationModel(Model):
 
     Parameters
     ----------
+    name
+        Common name to give to models of this type.
     loss_type
         The loss to be used for training and evaluation.
 
@@ -50,6 +52,7 @@ class SegmentationModel(Model):
 
     def __init__(
         self,
+        name: str,
         loss_type: type[torch.nn.Module] = MultiWeightedBCELogitsLoss,
         loss_arguments: dict[str, typing.Any] = {},
         optimizer_type: type[torch.optim.Optimizer] = torch.optim.Adam,
@@ -61,6 +64,7 @@ class SegmentationModel(Model):
         num_classes: int = 1,
     ):
         super().__init__(
+            name,
             loss_type,
             loss_arguments,
             optimizer_type,
diff --git a/src/mednet/libs/segmentation/models/unet.py b/src/mednet/libs/segmentation/models/unet.py
index cb1e34cfa7c07fcbafa43a59269412703f4be3ef..f4d64f01658c9303ad39a71bcc1065d1e0ecd2fc 100644
--- a/src/mednet/libs/segmentation/models/unet.py
+++ b/src/mednet/libs/segmentation/models/unet.py
@@ -7,6 +7,7 @@ import logging
 import typing
 
 import torch.nn
+import torch.utils.data
 from mednet.libs.common.data.typing import TransformSequence
 
 from .backbones.vgg import vgg16_for_segmentation
@@ -28,7 +29,7 @@ class UNetHead(torch.nn.Module):
         If True, upsample using PixelShuffleICNR.
     """
 
-    def __init__(self, in_channels_list: list[int] = None, pixel_shuffle=False):
+    def __init__(self, in_channels_list: list[int], pixel_shuffle=False):
         super().__init__()
         # number of channels
         c_decode1, c_decode2, c_decode3, c_decode4, c_decode5 = in_channels_list
@@ -98,11 +99,11 @@ class Unet(SegmentationModel):
 
     def __init__(
         self,
-        loss_type: torch.nn.Module = SoftJaccardBCELogitsLoss,
+        loss_type: type[torch.nn.Module] = SoftJaccardBCELogitsLoss,
         loss_arguments: dict[str, typing.Any] = {},
         optimizer_type: type[torch.optim.Optimizer] = torch.optim.Adam,
         optimizer_arguments: dict[str, typing.Any] = {},
-        scheduler_type: type[torch.optim.lr_scheduler] = None,
+        scheduler_type: type[torch.optim.lr_scheduler.LRScheduler] | None = None,
         scheduler_arguments: dict[str, typing.Any] = {},
         model_transforms: TransformSequence = [],
         augmentation_transforms: TransformSequence = [],
@@ -110,19 +111,18 @@ class Unet(SegmentationModel):
         pretrained: bool = False,
     ):
         super().__init__(
-            loss_type,
-            loss_arguments,
-            optimizer_type,
-            optimizer_arguments,
-            scheduler_type,
-            scheduler_arguments,
-            model_transforms,
-            augmentation_transforms,
-            num_classes,
+            name="unet",
+            loss_type=loss_type,
+            loss_arguments=loss_arguments,
+            optimizer_type=optimizer_type,
+            optimizer_arguments=optimizer_arguments,
+            scheduler_type=scheduler_type,
+            scheduler_arguments=scheduler_arguments,
+            model_transforms=model_transforms,
+            augmentation_transforms=augmentation_transforms,
+            num_classes=num_classes,
         )
 
-        self.name = "unet"
-
         self.pretrained = pretrained
 
         self.backbone = vgg16_for_segmentation(
@@ -160,26 +160,3 @@ class Unet(SegmentationModel):
             self.normalizer = make_imagenet_normalizer()
         else:
             super().set_normalizer(dataloader)
-
-    def training_step(self, batch, batch_idx):
-        images = batch[0]["image"]
-        ground_truths = batch[0]["target"]
-        masks = batch[0]["mask"]
-
-        outputs = self(self._augmentation_transforms(images))
-        return self._train_loss(outputs, ground_truths, masks)
-
-    def validation_step(self, batch, batch_idx):
-        images = batch[0]["image"]
-        ground_truths = batch[0]["target"]
-        masks = batch[0]["mask"]
-
-        outputs = self(images)
-        return self._validation_loss(outputs, ground_truths, masks)
-
-    def predict_step(self, batch, batch_idx, dataloader_idx=0):
-        output = self(batch[0]["image"])[1]
-        return torch.sigmoid(output)
-
-    def configure_optimizers(self):
-        return self._optimizer_type(self.parameters(), **self._optimizer_arguments)
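
The removed methods also documented the batch layout the inherited `SegmentationModel` steps must consume; reconstructed below as a self-contained sketch (tensor shapes are assumptions, the dictionary keys come from the deleted code):

```python
import torch

# One dataloader item: a tuple whose first element maps "image", "target"
# and "mask" to tensors -- exactly what the deleted steps unpacked.
batch = (
    {
        "image": torch.randn(2, 3, 64, 64),
        "target": torch.randint(0, 2, (2, 1, 64, 64)).float(),
        "mask": torch.ones(2, 1, 64, 64),
    },
)

images = batch[0]["image"]
ground_truths = batch[0]["target"]
masks = batch[0]["mask"]
```
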