diff --git a/src/mednet/config/models/alexnet.py b/src/mednet/config/models/alexnet.py
index 9a1cf54207a09ea75e862ac0548640241af8a705..7f28186750e1e21d1f801cbb7d21d17da5ca2011 100644
--- a/src/mednet/config/models/alexnet.py
+++ b/src/mednet/config/models/alexnet.py
@@ -15,8 +15,7 @@ from mednet.data.augmentations import ElasticDeformation
 from mednet.models.alexnet import Alexnet
 
 model = Alexnet(
-    train_loss_type=BCEWithLogitsLoss,
-    validation_loss_type=BCEWithLogitsLoss,
+    loss_type=BCEWithLogitsLoss,
     optimizer_type=SGD,
     optimizer_arguments=dict(lr=0.01, momentum=0.1),
     augmentation_transforms=[ElasticDeformation(p=0.8)],
diff --git a/src/mednet/config/models/alexnet_pretrained.py b/src/mednet/config/models/alexnet_pretrained.py
index ea9198aba28bd67c76ec814c3456c6cf367d4c03..a935655555a004cbe3c5b8e2b19f77458e952e40 100644
--- a/src/mednet/config/models/alexnet_pretrained.py
+++ b/src/mednet/config/models/alexnet_pretrained.py
@@ -17,8 +17,7 @@ from mednet.data.augmentations import ElasticDeformation
 from mednet.models.alexnet import Alexnet
 
 model = Alexnet(
-    train_loss_type=BCEWithLogitsLoss,
-    validation_loss_type=BCEWithLogitsLoss,
+    loss_type=BCEWithLogitsLoss,
     optimizer_type=SGD,
     optimizer_arguments=dict(lr=0.01, momentum=0.1),
     augmentation_transforms=[ElasticDeformation(p=0.8)],
diff --git a/src/mednet/config/models/densenet.py b/src/mednet/config/models/densenet.py
index 7154bb740cb9db3c572f689b9fc6216e9862a041..9ee510ac8df93713b995f857bf5afe2cb68b89a6 100644
--- a/src/mednet/config/models/densenet.py
+++ b/src/mednet/config/models/densenet.py
@@ -15,8 +15,7 @@ from mednet.data.augmentations import ElasticDeformation
 from mednet.models.densenet import Densenet
 
 model = Densenet(
-    train_loss_type=BCEWithLogitsLoss,
-    validation_loss_type=BCEWithLogitsLoss,
+    loss_type=BCEWithLogitsLoss,
     optimizer_type=Adam,
     optimizer_arguments=dict(lr=0.0001),
     augmentation_transforms=[ElasticDeformation(p=0.2)],
diff --git a/src/mednet/config/models/densenet_pretrained.py b/src/mednet/config/models/densenet_pretrained.py
index 1025a68936a97aa63c5b84029efc9e75d16ae7ef..b7e2efcdfa83e1b70a466dbca0ddca02cf4695dc 100644
--- a/src/mednet/config/models/densenet_pretrained.py
+++ b/src/mednet/config/models/densenet_pretrained.py
@@ -17,8 +17,7 @@ from mednet.data.augmentations import ElasticDeformation
 from mednet.models.densenet import Densenet
 
 model = Densenet(
-    train_loss_type=BCEWithLogitsLoss,
-    validation_loss_type=BCEWithLogitsLoss,
+    loss_type=BCEWithLogitsLoss,
     optimizer_type=Adam,
     optimizer_arguments=dict(lr=0.0001),
     augmentation_transforms=[ElasticDeformation(p=0.2)],
diff --git a/src/mednet/config/models/densenet_rs.py b/src/mednet/config/models/densenet_rs.py
index 23c1a06468c877350b873f064b8eb274c1ee6edc..813bb76cf92b3abe105e7095085d7e01de4fbecd 100644
--- a/src/mednet/config/models/densenet_rs.py
+++ b/src/mednet/config/models/densenet_rs.py
@@ -16,8 +16,7 @@ from mednet.data.augmentations import ElasticDeformation
 from mednet.models.densenet import Densenet
 
 model = Densenet(
-    train_loss_type=BCEWithLogitsLoss,
-    validation_loss_type=BCEWithLogitsLoss,
+    loss_type=BCEWithLogitsLoss,
     optimizer_type=Adam,
     optimizer_arguments=dict(lr=0.0001),
     augmentation_transforms=[ElasticDeformation(p=0.2)],
diff --git a/src/mednet/config/models/pasa.py b/src/mednet/config/models/pasa.py
index 16457d1a960ab0c1204fd272fc1fb95330fe4704..7787d10e32cfad9ece6d42cee8be6bc0bb86124f 100644
--- a/src/mednet/config/models/pasa.py
+++ b/src/mednet/config/models/pasa.py
@@ -17,8 +17,7 @@ from mednet.data.augmentations import ElasticDeformation
 from mednet.models.pasa import Pasa
 
 model = Pasa(
-    train_loss_type=BCEWithLogitsLoss,
-    validation_loss_type=BCEWithLogitsLoss,
+    loss_type=BCEWithLogitsLoss,
     optimizer_type=Adam,
     optimizer_arguments=dict(lr=8e-5),
     augmentation_transforms=[ElasticDeformation(p=0.8)],
diff --git a/src/mednet/models/alexnet.py b/src/mednet/models/alexnet.py
index d4e2586d2c101f0b0b04a2d000447398d0f21a4c..75223c9a4b78e196d0c81dab20b566c4b5d32d4b 100644
--- a/src/mednet/models/alexnet.py
+++ b/src/mednet/models/alexnet.py
@@ -27,26 +27,15 @@ class Alexnet(Model):
 
     Parameters
     ----------
-    train_loss_type
-        The loss to be used during the training.
+    loss_type
+        The loss to be used for training and evaluation.
 
         .. warning::
 
            The loss should be set to always return batch averages (as opposed
            to the batch sum), as our logging system expects it so.
-    train_loss_arguments
-        Arguments to the training loss.
-    validation_loss_type
-        The loss to be used for validation (may be different from the training
-        loss).  If extra-validation sets are provided, the same loss will be
-        used throughout.
-
-        .. warning::
-
-           The loss should be set to always return batch averages (as opposed
-           to the batch sum), as our logging system expects it so.
-    validation_loss_arguments
-        Arguments to the validation loss.
+    loss_arguments
+        Arguments to the loss.
     optimizer_type
         The type of optimizer to use for training.
     optimizer_arguments
@@ -63,10 +52,8 @@ class Alexnet(Model):
 
     def __init__(
         self,
-        train_loss_type: torch.nn.Module = torch.nn.BCEWithLogitsLoss,
-        train_loss_arguments: dict[str, typing.Any] = {},
-        validation_loss_type: torch.nn.Module | None = None,
-        validation_loss_arguments: dict[str, typing.Any] = {},
+        loss_type: type[torch.nn.Module] = torch.nn.BCEWithLogitsLoss,
+        loss_arguments: dict[str, typing.Any] = {},
         optimizer_type: type[torch.optim.Optimizer] = torch.optim.Adam,
         optimizer_arguments: dict[str, typing.Any] = {},
         augmentation_transforms: TransformSequence = [],
@@ -74,10 +61,8 @@ class Alexnet(Model):
         num_classes: int = 1,
     ):
         super().__init__(
-            train_loss_type,
-            train_loss_arguments,
-            validation_loss_type,
-            validation_loss_arguments,
+            loss_type,
+            loss_arguments,
             optimizer_type,
             optimizer_arguments,
             augmentation_transforms,
diff --git a/src/mednet/models/densenet.py b/src/mednet/models/densenet.py
index c90ab10cefdb6094067b1dc00ce90c9102de4ba4..76df1ed64a7a73e99b86d18145169dd600601044 100644
--- a/src/mednet/models/densenet.py
+++ b/src/mednet/models/densenet.py
@@ -25,26 +25,15 @@ class Densenet(Model):
 
     Parameters
     ----------
-    train_loss_type
-        The loss to be used during the training.
+    loss_type
+        The loss to be used for training and evaluation.
 
         .. warning::
 
            The loss should be set to always return batch averages (as opposed
            to the batch sum), as our logging system expects it so.
-    train_loss_arguments
-        Arguments to the training loss.
-    validation_loss_type
-        The loss to be used for validation (may be different from the training
-        loss).  If extra-validation sets are provided, the same loss will be
-        used throughout.
-
-        .. warning::
-
-           The loss should be set to always return batch averages (as opposed
-           to the batch sum), as our logging system expects it so.
-    validation_loss_arguments
-        Arguments to the validation loss.
+    loss_arguments
+        Arguments to the loss.
     optimizer_type
         The type of optimizer to use for training.
     optimizer_arguments
@@ -63,10 +52,8 @@ class Densenet(Model):
 
     def __init__(
         self,
-        train_loss_type: torch.nn.Module = torch.nn.BCEWithLogitsLoss,
-        train_loss_arguments: dict[str, typing.Any] = {},
-        validation_loss_type: torch.nn.Module | None = None,
-        validation_loss_arguments: dict[str, typing.Any] = {},
+        loss_type: type[torch.nn.Module] = torch.nn.BCEWithLogitsLoss,
+        loss_arguments: dict[str, typing.Any] = {},
         optimizer_type: type[torch.optim.Optimizer] = torch.optim.Adam,
         optimizer_arguments: dict[str, typing.Any] = {},
         augmentation_transforms: TransformSequence = [],
@@ -75,10 +62,8 @@ class Densenet(Model):
         num_classes: int = 1,
     ):
         super().__init__(
-            train_loss_type,
-            train_loss_arguments,
-            validation_loss_type,
-            validation_loss_arguments,
+            loss_type,
+            loss_arguments,
             optimizer_type,
             optimizer_arguments,
             augmentation_transforms,
diff --git a/src/mednet/models/model.py b/src/mednet/models/model.py
index 27e5c32d78ae866651c79e377fbfaa2663c4c35a..57c1b618b1c0943496a4e799552989c03f2e3532 100644
--- a/src/mednet/models/model.py
+++ b/src/mednet/models/model.py
@@ -24,26 +24,15 @@ class Model(pl.LightningModule):
 
     Parameters
     ----------
-    train_loss_type
-        The loss to be used during the training.
+    loss_type
+        The loss to be used for training and evaluation.
 
         .. warning::
 
            The loss should be set to always return batch averages (as opposed
            to the batch sum), as our logging system expects it so.
-    train_loss_arguments
-        Arguments to the training loss.
-    validation_loss_type
-        The loss to be used for validation (may be different from the training
-        loss).  If extra-validation sets are provided, the same loss will be
-        used throughout.
-
-        .. warning::
-
-           The loss should be set to always return batch averages (as opposed
-           to the batch sum), as our logging system expects it so.
-    validation_loss_arguments
-        Arguments to the validation loss.
+    loss_arguments
+        Arguments to the loss.
     optimizer_type
         The type of optimizer to use for training.
     optimizer_arguments
@@ -57,10 +46,8 @@ class Model(pl.LightningModule):
 
     def __init__(
         self,
-        train_loss_type: torch.nn.Module = torch.nn.BCEWithLogitsLoss,
-        train_loss_arguments: dict[str, typing.Any] = {},
-        validation_loss_type: torch.nn.Module | None = None,
-        validation_loss_arguments: dict[str, typing.Any] = {},
+        loss_type: type[torch.nn.Module] = torch.nn.BCEWithLogitsLoss,
+        loss_arguments: dict[str, typing.Any] = {},
         optimizer_type: type[torch.optim.Optimizer] = torch.optim.Adam,
         optimizer_arguments: dict[str, typing.Any] = {},
         augmentation_transforms: TransformSequence = [],
@@ -73,13 +60,13 @@ class Model(pl.LightningModule):
 
         self.model_transforms: TransformSequence = []
 
-        self._train_loss_type = train_loss_type
-        self._train_loss_arguments = train_loss_arguments
+        self._loss_type = loss_type
+
         self._train_loss = None
+        self._train_loss_arguments = dict(loss_arguments)
 
-        self._validation_loss_type = validation_loss_type
-        self._validation_loss_arguments = validation_loss_arguments
         self.validation_loss = None
+        self._validation_loss_arguments = dict(loss_arguments)
 
         self._optimizer_type = optimizer_type
         self._optimizer_arguments = optimizer_arguments
@@ -148,8 +135,8 @@ class Model(pl.LightningModule):
         raise NotImplementedError
 
     def configure_losses(self):
-        self._train_loss = self._train_loss_type(**self._train_loss_arguments)
-        self._validation_loss = self._validation_loss_type(
+        self._train_loss = self._loss_type(**self._train_loss_arguments)
+        self._validation_loss = self._loss_type(
             **self._validation_loss_arguments
         )
 
@@ -160,7 +147,7 @@ class Model(pl.LightningModule):
         )
 
     def balance_losses(self, datamodule) -> None:
-        """Balance the loss based on the distribution of targets in the datamodule, if the loss function supports it.
+        """Balance the loss based on the distribution of targets in the datamodule, if the loss supports it (contains a 'pos_weight' attribute).
 
         Parameters
         ----------
@@ -168,29 +155,18 @@ class Model(pl.LightningModule):
             Instance of a datamodule.
         """
 
-        logger.info(
-            f"Balancing training loss function {self._train_loss_type}."
-        )
         try:
-            getattr(self._train_loss_type(), "pos_weight")
+            getattr(self._loss_type(), "pos_weight")
         except AttributeError:
             logger.warning(
-                "Training loss does not posess a 'pos_weight' attribute and will not be balanced."
+                f"Loss {self._loss_type} does not possess a 'pos_weight' attribute and will not be balanced."
             )
         else:
+            logger.info(f"Balancing training loss {self._loss_type}.")
             train_weights = get_positive_weights(datamodule.train_dataloader())
             self._train_loss_arguments["pos_weight"] = train_weights
 
-        logger.info(
-            f"Balancing validation loss function {self._validation_loss_type}."
-        )
-        try:
-            getattr(self._validation_loss_type(), "pos_weight")
-        except AttributeError:
-            logger.warning(
-                "Validation loss does not posess a 'pos_weight' attribute and will not be balanced."
-            )
-        else:
+            logger.info(f"Balancing validation loss {self._loss_type}.")
             validation_weights = get_positive_weights(
                 datamodule.val_dataloader()["validation"]
             )
diff --git a/src/mednet/models/pasa.py b/src/mednet/models/pasa.py
index 5e5cd743eae3f4abab99797137da723838c682eb..e9147683b08f8d8b532396544eece7829c23fd92 100644
--- a/src/mednet/models/pasa.py
+++ b/src/mednet/models/pasa.py
@@ -30,26 +30,15 @@ class Pasa(Model):
 
     Parameters
     ----------
-    train_loss_type
-        The loss to be used during the training.
+    loss_type
+        The loss to be used for training and evaluation.
 
         .. warning::
 
            The loss should be set to always return batch averages (as opposed
            to the batch sum), as our logging system expects it so.
-    train_loss_arguments
-        Arguments to the training loss.
-    validation_loss_type
-        The loss to be used for validation (may be different from the training
-        loss).  If extra-validation sets are provided, the same loss will be
-        used throughout.
-
-        .. warning::
-
-           The loss should be set to always return batch averages (as opposed
-           to the batch sum), as our logging system expects it so.
-    validation_loss_arguments
-        Arguments to the validation loss.
+    loss_arguments
+        Arguments to the loss.
     optimizer_type
         The type of optimizer to use for training.
     optimizer_arguments
@@ -63,20 +52,16 @@ class Pasa(Model):
 
     def __init__(
         self,
-        train_loss_type: torch.nn.Module = torch.nn.BCEWithLogitsLoss,
-        train_loss_arguments: dict[str, typing.Any] = {},
-        validation_loss_type: torch.nn.Module | None = None,
-        validation_loss_arguments: dict[str, typing.Any] = {},
+        loss_type: type[torch.nn.Module] = torch.nn.BCEWithLogitsLoss,
+        loss_arguments: dict[str, typing.Any] = {},
         optimizer_type: type[torch.optim.Optimizer] = torch.optim.Adam,
         optimizer_arguments: dict[str, typing.Any] = {},
         augmentation_transforms: TransformSequence = [],
         num_classes: int = 1,
     ):
         super().__init__(
-            train_loss_type,
-            train_loss_arguments,
-            validation_loss_type,
-            validation_loss_arguments,
+            loss_type,
+            loss_arguments,
             optimizer_type,
             optimizer_arguments,
             augmentation_transforms,