diff --git a/src/mednet/config/models/alexnet.py b/src/mednet/config/models/alexnet.py
index 9703f964a476b53d5aa076242bb0b02cedfe75ff..9a1cf54207a09ea75e862ac0548640241af8a705 100644
--- a/src/mednet/config/models/alexnet.py
+++ b/src/mednet/config/models/alexnet.py
@@ -15,8 +15,8 @@ from mednet.data.augmentations import ElasticDeformation
 from mednet.models.alexnet import Alexnet
 
 model = Alexnet(
-    train_loss=BCEWithLogitsLoss(),
-    validation_loss=BCEWithLogitsLoss(),
+    train_loss_type=BCEWithLogitsLoss,
+    validation_loss_type=BCEWithLogitsLoss,
     optimizer_type=SGD,
     optimizer_arguments=dict(lr=0.01, momentum=0.1),
     augmentation_transforms=[ElasticDeformation(p=0.8)],
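
This config now passes the loss class plus a separate argument dictionary;
the instance is only built later, by ``Model.configure_losses()``.  A minimal
sketch of a config that forwards extra constructor arguments under the new
interface (the ``pos_weight`` value is purely illustrative):

    import torch
    from torch.nn import BCEWithLogitsLoss
    from torch.optim import SGD

    from mednet.data.augmentations import ElasticDeformation
    from mednet.models.alexnet import Alexnet

    model = Alexnet(
        train_loss_type=BCEWithLogitsLoss,
        # illustrative only: arguments now travel separately and are applied
        # when configure_losses() instantiates the loss
        train_loss_arguments=dict(pos_weight=torch.tensor([2.0])),
        validation_loss_type=BCEWithLogitsLoss,
        optimizer_type=SGD,
        optimizer_arguments=dict(lr=0.01, momentum=0.1),
        augmentation_transforms=[ElasticDeformation(p=0.8)],
    )
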
diff --git a/src/mednet/config/models/alexnet_pretrained.py b/src/mednet/config/models/alexnet_pretrained.py
index 8887db8f6f006cd2580dabac44202055a5cdacab..ea9198aba28bd67c76ec814c3456c6cf367d4c03 100644
--- a/src/mednet/config/models/alexnet_pretrained.py
+++ b/src/mednet/config/models/alexnet_pretrained.py
@@ -17,8 +17,8 @@ from mednet.data.augmentations import ElasticDeformation
 from mednet.models.alexnet import Alexnet
 
 model = Alexnet(
-    train_loss=BCEWithLogitsLoss(),
-    validation_loss=BCEWithLogitsLoss(),
+    train_loss_type=BCEWithLogitsLoss,
+    validation_loss_type=BCEWithLogitsLoss,
     optimizer_type=SGD,
     optimizer_arguments=dict(lr=0.01, momentum=0.1),
     augmentation_transforms=[ElasticDeformation(p=0.8)],
diff --git a/src/mednet/config/models/densenet.py b/src/mednet/config/models/densenet.py
index f28dd23cd12c72e5fc6713e706f0e9c05158759c..7154bb740cb9db3c572f689b9fc6216e9862a041 100644
--- a/src/mednet/config/models/densenet.py
+++ b/src/mednet/config/models/densenet.py
@@ -15,8 +15,8 @@ from mednet.data.augmentations import ElasticDeformation
 from mednet.models.densenet import Densenet
 
 model = Densenet(
-    train_loss=BCEWithLogitsLoss(),
-    validation_loss=BCEWithLogitsLoss(),
+    train_loss_type=BCEWithLogitsLoss,
+    validation_loss_type=BCEWithLogitsLoss,
     optimizer_type=Adam,
     optimizer_arguments=dict(lr=0.0001),
     augmentation_transforms=[ElasticDeformation(p=0.2)],
diff --git a/src/mednet/config/models/densenet_pretrained.py b/src/mednet/config/models/densenet_pretrained.py
index 274a564601094a8ecb51e67c87f19f1f8197a30a..1025a68936a97aa63c5b84029efc9e75d16ae7ef 100644
--- a/src/mednet/config/models/densenet_pretrained.py
+++ b/src/mednet/config/models/densenet_pretrained.py
@@ -17,8 +17,8 @@ from mednet.data.augmentations import ElasticDeformation
 from mednet.models.densenet import Densenet
 
 model = Densenet(
-    train_loss=BCEWithLogitsLoss(),
-    validation_loss=BCEWithLogitsLoss(),
+    train_loss_type=BCEWithLogitsLoss,
+    validation_loss_type=BCEWithLogitsLoss,
     optimizer_type=Adam,
     optimizer_arguments=dict(lr=0.0001),
     augmentation_transforms=[ElasticDeformation(p=0.2)],
diff --git a/src/mednet/config/models/densenet_rs.py b/src/mednet/config/models/densenet_rs.py
index e7db48850d0e8d2b959b39ee93bae3b78dccfa80..23c1a06468c877350b873f064b8eb274c1ee6edc 100644
--- a/src/mednet/config/models/densenet_rs.py
+++ b/src/mednet/config/models/densenet_rs.py
@@ -16,8 +16,8 @@ from mednet.data.augmentations import ElasticDeformation
 from mednet.models.densenet import Densenet
 
 model = Densenet(
-    train_loss=BCEWithLogitsLoss(),
-    validation_loss=BCEWithLogitsLoss(),
+    train_loss_type=BCEWithLogitsLoss,
+    validation_loss_type=BCEWithLogitsLoss,
     optimizer_type=Adam,
     optimizer_arguments=dict(lr=0.0001),
     augmentation_transforms=[ElasticDeformation(p=0.2)],
diff --git a/src/mednet/config/models/pasa.py b/src/mednet/config/models/pasa.py
index 227b9b426568bebf327c1ac3f206a5fc3f2b44b6..16457d1a960ab0c1204fd272fc1fb95330fe4704 100644
--- a/src/mednet/config/models/pasa.py
+++ b/src/mednet/config/models/pasa.py
@@ -17,8 +17,8 @@ from mednet.data.augmentations import ElasticDeformation
 from mednet.models.pasa import Pasa
 
 model = Pasa(
-    train_loss=BCEWithLogitsLoss(),
-    validation_loss=BCEWithLogitsLoss(),
+    train_loss_type=BCEWithLogitsLoss,
+    validation_loss_type=BCEWithLogitsLoss,
     optimizer_type=Adam,
     optimizer_arguments=dict(lr=8e-5),
     augmentation_transforms=[ElasticDeformation(p=0.8)],
diff --git a/src/mednet/engine/trainer.py b/src/mednet/engine/trainer.py
index d3a345b71d357c732075f482954e41b905591093..0993354f8dc2d84b30f0ab18f95751621a6076fb 100644
--- a/src/mednet/engine/trainer.py
+++ b/src/mednet/engine/trainer.py
@@ -72,6 +72,8 @@ def run(
 
     output_folder.mkdir(parents=True, exist_ok=True)
 
+    model.configure_losses()
+
     from .loggers import CustomTensorboardLogger
 
     log_dir = "logs"
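
``run()`` now builds the loss modules just before fitting, so anything that
``balance_losses()`` wrote into the argument dictionaries beforehand (e.g. a
``pos_weight``) is present at construction time.  A runnable sketch of that
ordering, with a plain dictionary standing in for the model attribute:

    import torch

    arguments = {}                                  # _train_loss_arguments
    arguments["pos_weight"] = torch.tensor([4.0])   # done by balance_losses()
    loss = torch.nn.BCEWithLogitsLoss(**arguments)  # done by configure_losses()
    print(loss.pos_weight)                          # tensor([4.])
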
diff --git a/src/mednet/models/alexnet.py b/src/mednet/models/alexnet.py
index eada55b8bbb6c49fb43e1fd72bc31fb05b3444f1..d4e2586d2c101f0b0b04a2d000447398d0f21a4c 100644
--- a/src/mednet/models/alexnet.py
+++ b/src/mednet/models/alexnet.py
@@ -27,14 +27,16 @@ class Alexnet(Model):
 
     Parameters
     ----------
-    train_loss
+    train_loss_type
         The loss to be used during the training.
 
         .. warning::
 
            The loss should be set to always return batch averages (as opposed
            to the batch sum), as our logging system expects it so.
-    validation_loss
+    train_loss_arguments
+        Keyword arguments to pass to the training loss constructor.
+    validation_loss_type
         The loss to be used for validation (may be different from the training
         loss).  If extra-validation sets are provided, the same loss will be
         used throughout.
@@ -43,6 +45,8 @@ class Alexnet(Model):
 
            The loss should be set to always return batch averages (as opposed
            to the batch sum), as our logging system expects it so.
+    validation_loss_arguments
+        Keyword arguments to pass to the validation loss constructor.
     optimizer_type
         The type of optimizer to use for training.
     optimizer_arguments
@@ -59,8 +63,10 @@ class Alexnet(Model):
 
     def __init__(
         self,
-        train_loss: torch.nn.Module = torch.nn.BCEWithLogitsLoss(),
-        validation_loss: torch.nn.Module | None = None,
+        train_loss_type: type[torch.nn.Module] = torch.nn.BCEWithLogitsLoss,
+        train_loss_arguments: dict[str, typing.Any] = {},
+        validation_loss_type: type[torch.nn.Module] | None = None,
+        validation_loss_arguments: dict[str, typing.Any] = {},
         optimizer_type: type[torch.optim.Optimizer] = torch.optim.Adam,
         optimizer_arguments: dict[str, typing.Any] = {},
         augmentation_transforms: TransformSequence = [],
@@ -68,8 +74,10 @@ class Alexnet(Model):
         num_classes: int = 1,
     ):
         super().__init__(
-            train_loss,
-            validation_loss,
+            train_loss_type,
+            train_loss_arguments,
+            validation_loss_type,
+            validation_loss_arguments,
             optimizer_type,
             optimizer_arguments,
             augmentation_transforms,
@@ -166,7 +174,7 @@ class Alexnet(Model):
 
         # data forwarding on the existing network
         outputs = self(images)
-        return self._validation_loss[dataloader_idx](outputs, labels.float())
+        return self._validation_loss(outputs, labels.float())
 
     def predict_step(self, batch, batch_idx, dataloader_idx=0):
         outputs = self(batch[0])
diff --git a/src/mednet/models/densenet.py b/src/mednet/models/densenet.py
index 15da7f4eaded9618849c29a32cdaff97a9ddb9cb..c90ab10cefdb6094067b1dc00ce90c9102de4ba4 100644
--- a/src/mednet/models/densenet.py
+++ b/src/mednet/models/densenet.py
@@ -25,14 +25,16 @@ class Densenet(Model):
 
     Parameters
     ----------
-    train_loss
+    train_loss_type
         The loss to be used during the training.
 
         .. warning::
 
            The loss should be set to always return batch averages (as opposed
            to the batch sum), as our logging system expects it so.
-    validation_loss
+    train_loss_arguments
+        Keyword arguments to pass to the training loss constructor.
+    validation_loss_type
         The loss to be used for validation (may be different from the training
         loss).  If extra-validation sets are provided, the same loss will be
         used throughout.
@@ -41,6 +43,8 @@ class Densenet(Model):
 
            The loss should be set to always return batch averages (as opposed
            to the batch sum), as our logging system expects it so.
+    validation_loss_arguments
+        Keyword arguments to pass to the validation loss constructor.
     optimizer_type
         The type of optimizer to use for training.
     optimizer_arguments
@@ -59,8 +63,10 @@ class Densenet(Model):
 
     def __init__(
         self,
-        train_loss: torch.nn.Module = torch.nn.BCEWithLogitsLoss(),
-        validation_loss: torch.nn.Module | None = None,
+        train_loss_type: type[torch.nn.Module] = torch.nn.BCEWithLogitsLoss,
+        train_loss_arguments: dict[str, typing.Any] = {},
+        validation_loss_type: type[torch.nn.Module] | None = None,
+        validation_loss_arguments: dict[str, typing.Any] = {},
         optimizer_type: type[torch.optim.Optimizer] = torch.optim.Adam,
         optimizer_arguments: dict[str, typing.Any] = {},
         augmentation_transforms: TransformSequence = [],
@@ -69,8 +75,10 @@ class Densenet(Model):
         num_classes: int = 1,
     ):
         super().__init__(
-            train_loss,
-            validation_loss,
+            train_loss_type,
+            train_loss_arguments,
+            validation_loss_type,
+            validation_loss_arguments,
             optimizer_type,
             optimizer_arguments,
             augmentation_transforms,
@@ -164,7 +172,7 @@ class Densenet(Model):
         # data forwarding on the existing network
         outputs = self(images)
 
-        return self._validation_loss[dataloader_idx](outputs, labels.float())
+        return self._validation_loss(outputs, labels.float())
 
     def predict_step(self, batch, batch_idx, dataloader_idx=0):
         outputs = self(batch[0])
diff --git a/src/mednet/models/model.py b/src/mednet/models/model.py
index e5a13dc8336aa28335bba58d26e98d0814ac9ed8..27e5c32d78ae866651c79e377fbfaa2663c4c35a 100644
--- a/src/mednet/models/model.py
+++ b/src/mednet/models/model.py
@@ -24,14 +24,16 @@ class Model(pl.LightningModule):
 
     Parameters
     ----------
-    train_loss
+    train_loss_type
         The loss to be used during the training.
 
         .. warning::
 
            The loss should be set to always return batch averages (as opposed
            to the batch sum), as our logging system expects it so.
-    validation_loss
+    train_loss_arguments
+        Keyword arguments to pass to the training loss constructor.
+    validation_loss_type
         The loss to be used for validation (may be different from the training
         loss).  If extra-validation sets are provided, the same loss will be
         used throughout.
@@ -40,6 +42,8 @@ class Model(pl.LightningModule):
 
            The loss should be set to always return batch averages (as opposed
            to the batch sum), as our logging system expects it so.
+    validation_loss_arguments
+        Keyword arguments to pass to the validation loss constructor.
     optimizer_type
         The type of optimizer to use for training.
     optimizer_arguments
@@ -53,8 +57,10 @@ class Model(pl.LightningModule):
 
     def __init__(
         self,
-        train_loss: torch.nn.Module = torch.nn.BCEWithLogitsLoss(),
-        validation_loss: torch.nn.Module | None = None,
+        train_loss_type: type[torch.nn.Module] = torch.nn.BCEWithLogitsLoss,
+        train_loss_arguments: dict[str, typing.Any] = {},
+        validation_loss_type: type[torch.nn.Module] | None = None,
+        validation_loss_arguments: dict[str, typing.Any] = {},
         optimizer_type: type[torch.optim.Optimizer] = torch.optim.Adam,
         optimizer_arguments: dict[str, typing.Any] = {},
         augmentation_transforms: TransformSequence = [],
@@ -67,10 +73,17 @@
 
         self.model_transforms: TransformSequence = []
 
-        self._train_loss = train_loss
-        self._validation_loss = [
-            (validation_loss if validation_loss is not None else train_loss)
-        ]
+        self._train_loss_type = train_loss_type
+        self._train_loss_arguments = train_loss_arguments
+        self._train_loss = None
+
+        self._validation_loss_type = validation_loss_type
+        self._validation_loss_arguments = validation_loss_arguments
+        if validation_loss_type is None:
+            # fall back to the training loss when no validation loss is set
+            self._validation_loss_type = train_loss_type
+            self._validation_loss_arguments = dict(train_loss_arguments)
+        self._validation_loss = None
 
         self._optimizer_type = optimizer_type
         self._optimizer_arguments = optimizer_arguments
@@ -138,6 +151,12 @@
     def predict_step(self, batch, batch_idx, dataloader_idx=0):
         raise NotImplementedError
 
+    def configure_losses(self):
+        self._train_loss = self._train_loss_type(**self._train_loss_arguments)
+        self._validation_loss = self._validation_loss_type(
+            **self._validation_loss_arguments
+        )
+
     def configure_optimizers(self):
         return self._optimizer_type(
             self.parameters(),
@@ -153,44 +172,30 @@
             Instance of a datamodule.
         """
 
-        logger.info(f"Balancing training loss function {self._train_loss}.")
+        logger.info(
+            f"Balancing training loss function {self._train_loss_type}."
+        )
         try:
-            getattr(self._train_loss, "pos_weight")
+            getattr(self._train_loss_type(), "pos_weight")
         except AttributeError:
             logger.warning(
                 "Training loss does not posess a 'pos_weight' attribute and will not be balanced."
             )
         else:
             train_weights = get_positive_weights(datamodule.train_dataloader())
-            setattr(self._train_loss, "pos_weight", train_weights)
+            self._train_loss_arguments["pos_weight"] = train_weights
 
         logger.info(
-            f"Balancing validation loss function {self._validation_loss[0]}."
+            f"Balancing validation loss function {self._validation_loss_type}."
         )
         try:
-            getattr(self._validation_loss[0], "pos_weight")
+            getattr(self._validation_loss_type(), "pos_weight")
         except AttributeError:
             logger.warning(
                 "Validation loss does not posess a 'pos_weight' attribute and will not be balanced."
             )
         else:
-            # If multiple validation DataLoaders are used, each one will need to have a loss
-            # that is balanced for that DataLoader
-
-            new_validation_losses = []
-            loss_class = self._validation_loss[0].__class__
-
-            datamodule_validation_keys = datamodule.val_dataset_keys()
-            logger.info(
-                f"Found {len(datamodule_validation_keys)} keys in the validation datamodule. A balanced loss will be created for each key."
+            validation_weights = get_positive_weights(
+                datamodule.val_dataloader()["validation"]
             )
-
-            for val_dataset_key in datamodule_validation_keys:
-                validation_weights = get_positive_weights(
-                    datamodule.val_dataloader()[val_dataset_key]
-                )
-                new_validation_losses.append(
-                    loss_class(pos_weight=validation_weights)
-                )
-
-            self._validation_loss = new_validation_losses
+            self._validation_loss_arguments["pos_weight"] = validation_weights
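
The refactored base class therefore implements a deferred-instantiation
pattern: store a (type, arguments) pair, let balancing mutate the arguments,
and only then build the module.  A self-contained sketch of the pattern,
independent of mednet (names are illustrative):

    import typing

    import torch

    class LossHolder:
        """Minimal stand-in for the loss handling in ``Model``."""

        def __init__(
            self,
            train_loss_type: type[torch.nn.Module] = torch.nn.BCEWithLogitsLoss,
            train_loss_arguments: dict[str, typing.Any] | None = None,
        ):
            self._train_loss_type = train_loss_type
            self._train_loss_arguments = dict(train_loss_arguments or {})
            self._train_loss: torch.nn.Module | None = None

        def configure_losses(self) -> None:
            # arguments may have changed (e.g. via balancing) between
            # construction and this call; build the module only now
            self._train_loss = self._train_loss_type(
                **self._train_loss_arguments
            )

    holder = LossHolder()
    holder._train_loss_arguments["pos_weight"] = torch.tensor([3.0])
    holder.configure_losses()
    print(holder._train_loss.pos_weight)  # tensor([3.])

The ``dict(...)`` copy in the sketch avoids sharing a mutable default between
instances; the diff itself keeps the ``{}`` defaults already used elsewhere
in ``Model``.
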
diff --git a/src/mednet/models/pasa.py b/src/mednet/models/pasa.py
index 54032edad5feee198b47274174b99e9a33fc8521..5e5cd743eae3f4abab99797137da723838c682eb 100644
--- a/src/mednet/models/pasa.py
+++ b/src/mednet/models/pasa.py
@@ -30,14 +30,16 @@ class Pasa(Model):
 
     Parameters
     ----------
-    train_loss
+    train_loss_type
         The loss to be used during the training.
 
         .. warning::
 
            The loss should be set to always return batch averages (as opposed
            to the batch sum), as our logging system expects it so.
-    validation_loss
+    train_loss_arguments
+        Keyword arguments to pass to the training loss constructor.
+    validation_loss_type
         The loss to be used for validation (may be different from the training
         loss).  If extra-validation sets are provided, the same loss will be
         used throughout.
@@ -46,6 +48,8 @@ class Pasa(Model):
 
            The loss should be set to always return batch averages (as opposed
            to the batch sum), as our logging system expects it so.
+    validation_loss_arguments
+        Keyword arguments to pass to the validation loss constructor.
     optimizer_type
         The type of optimizer to use for training.
     optimizer_arguments
@@ -59,16 +63,20 @@ class Pasa(Model):
 
     def __init__(
         self,
-        train_loss: torch.nn.Module = torch.nn.BCEWithLogitsLoss(),
-        validation_loss: torch.nn.Module | None = None,
+        train_loss_type: type[torch.nn.Module] = torch.nn.BCEWithLogitsLoss,
+        train_loss_arguments: dict[str, typing.Any] = {},
+        validation_loss_type: type[torch.nn.Module] | None = None,
+        validation_loss_arguments: dict[str, typing.Any] = {},
         optimizer_type: type[torch.optim.Optimizer] = torch.optim.Adam,
         optimizer_arguments: dict[str, typing.Any] = {},
         augmentation_transforms: TransformSequence = [],
         num_classes: int = 1,
     ):
         super().__init__(
-            train_loss,
-            validation_loss,
+            train_loss_type,
+            train_loss_arguments,
+            validation_loss_type,
+            validation_loss_arguments,
             optimizer_type,
             optimizer_arguments,
             augmentation_transforms,
@@ -233,7 +241,7 @@ class Pasa(Model):
 
         # data forwarding on the existing network
         outputs = self(images)
-        return self._validation_loss[dataloader_idx](outputs, labels.float())
+        return self._validation_loss(outputs, labels.float())
 
     def predict_step(self, batch, batch_idx, dataloader_idx=0):
         outputs = self(batch[0])
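
A behavioral note on the ``validation_step()`` changes repeated in
``alexnet.py``, ``densenet.py`` and ``pasa.py``: the per-dataloader list of
balanced losses is gone, so every validation split (including
extra-validation sets) now shares the single instance built by
``configure_losses()``, balanced against the "validation" split only.  The
sketch below illustrates that sharing (shapes and values are arbitrary):

    import torch

    validation_loss = torch.nn.BCEWithLogitsLoss(
        pos_weight=torch.tensor([2.0])
    )

    # previously: one balanced loss per validation dataloader; now the same
    # instance (and the same pos_weight) scores every split
    for dataloader_idx in range(3):  # e.g. "validation" plus two extra sets
        outputs = torch.randn(8, 1)
        labels = torch.randint(0, 2, (8, 1)).float()
        print(dataloader_idx, validation_loss(outputs, labels).item())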