diff --git a/doc/api.rst b/doc/api.rst
index 186f9eeab1aebde1081615f24fabd9971a70777d..f493d80b319435caabdd2e9f409c6bdcf0c2102f 100644
--- a/doc/api.rst
+++ b/doc/api.rst
@@ -45,6 +45,7 @@ CNN and other models implemented.
    mednet.models.logistic_regression
    mednet.models.loss_weights
    mednet.models.mlp
+   mednet.models.model
    mednet.models.normalizer
    mednet.models.separate
    mednet.models.transforms
diff --git a/src/mednet/config/models/alexnet.py b/src/mednet/config/models/alexnet.py
index 9703f964a476b53d5aa076242bb0b02cedfe75ff..7f28186750e1e21d1f801cbb7d21d17da5ca2011 100644
--- a/src/mednet/config/models/alexnet.py
+++ b/src/mednet/config/models/alexnet.py
@@ -15,8 +15,7 @@ from mednet.data.augmentations import ElasticDeformation
 from mednet.models.alexnet import Alexnet
 
 model = Alexnet(
-    train_loss=BCEWithLogitsLoss(),
-    validation_loss=BCEWithLogitsLoss(),
+    loss_type=BCEWithLogitsLoss,
     optimizer_type=SGD,
     optimizer_arguments=dict(lr=0.01, momentum=0.1),
     augmentation_transforms=[ElasticDeformation(p=0.8)],
diff --git a/src/mednet/config/models/alexnet_pretrained.py b/src/mednet/config/models/alexnet_pretrained.py
index 8887db8f6f006cd2580dabac44202055a5cdacab..a935655555a004cbe3c5b8e2b19f77458e952e40 100644
--- a/src/mednet/config/models/alexnet_pretrained.py
+++ b/src/mednet/config/models/alexnet_pretrained.py
@@ -17,8 +17,7 @@ from mednet.data.augmentations import ElasticDeformation
 from mednet.models.alexnet import Alexnet
 
 model = Alexnet(
-    train_loss=BCEWithLogitsLoss(),
-    validation_loss=BCEWithLogitsLoss(),
+    loss_type=BCEWithLogitsLoss,
     optimizer_type=SGD,
     optimizer_arguments=dict(lr=0.01, momentum=0.1),
     augmentation_transforms=[ElasticDeformation(p=0.8)],
diff --git a/src/mednet/config/models/densenet.py b/src/mednet/config/models/densenet.py
index f28dd23cd12c72e5fc6713e706f0e9c05158759c..9ee510ac8df93713b995f857bf5afe2cb68b89a6 100644
--- a/src/mednet/config/models/densenet.py
+++ b/src/mednet/config/models/densenet.py
@@ -15,8 +15,7 @@ from mednet.data.augmentations import ElasticDeformation
 from mednet.models.densenet import Densenet
 
 model = Densenet(
-    train_loss=BCEWithLogitsLoss(),
-    validation_loss=BCEWithLogitsLoss(),
+    loss_type=BCEWithLogitsLoss,
     optimizer_type=Adam,
     optimizer_arguments=dict(lr=0.0001),
     augmentation_transforms=[ElasticDeformation(p=0.2)],
diff --git a/src/mednet/config/models/densenet_pretrained.py b/src/mednet/config/models/densenet_pretrained.py
index 274a564601094a8ecb51e67c87f19f1f8197a30a..b7e2efcdfa83e1b70a466dbca0ddca02cf4695dc 100644
--- a/src/mednet/config/models/densenet_pretrained.py
+++ b/src/mednet/config/models/densenet_pretrained.py
@@ -17,8 +17,7 @@ from mednet.data.augmentations import ElasticDeformation
 from mednet.models.densenet import Densenet
 
 model = Densenet(
-    train_loss=BCEWithLogitsLoss(),
-    validation_loss=BCEWithLogitsLoss(),
+    loss_type=BCEWithLogitsLoss,
     optimizer_type=Adam,
     optimizer_arguments=dict(lr=0.0001),
     augmentation_transforms=[ElasticDeformation(p=0.2)],
diff --git a/src/mednet/config/models/densenet_rs.py b/src/mednet/config/models/densenet_rs.py
index e7db48850d0e8d2b959b39ee93bae3b78dccfa80..813bb76cf92b3abe105e7095085d7e01de4fbecd 100644
--- a/src/mednet/config/models/densenet_rs.py
+++ b/src/mednet/config/models/densenet_rs.py
@@ -16,8 +16,7 @@ from mednet.data.augmentations import ElasticDeformation
 from mednet.models.densenet import Densenet
 
 model = Densenet(
-    train_loss=BCEWithLogitsLoss(),
-    validation_loss=BCEWithLogitsLoss(),
+    loss_type=BCEWithLogitsLoss,
     optimizer_type=Adam,
     optimizer_arguments=dict(lr=0.0001),
     augmentation_transforms=[ElasticDeformation(p=0.2)],
diff --git a/src/mednet/config/models/pasa.py b/src/mednet/config/models/pasa.py
index 227b9b426568bebf327c1ac3f206a5fc3f2b44b6..7787d10e32cfad9ece6d42cee8be6bc0bb86124f 100644
--- a/src/mednet/config/models/pasa.py
+++ b/src/mednet/config/models/pasa.py
@@ -17,8 +17,7 @@ from mednet.data.augmentations import ElasticDeformation
 from mednet.models.pasa import Pasa
 
 model = Pasa(
-    train_loss=BCEWithLogitsLoss(),
-    validation_loss=BCEWithLogitsLoss(),
+    loss_type=BCEWithLogitsLoss,
     optimizer_type=Adam,
     optimizer_arguments=dict(lr=8e-5),
     augmentation_transforms=[ElasticDeformation(p=0.8)],
diff --git a/src/mednet/data/datamodule.py b/src/mednet/data/datamodule.py
index c73bb27932c52e2061134a7f007256c7f6292161..6c7d759f4a8890f576af327450ac89e144f6fbfa 100644
--- a/src/mednet/data/datamodule.py
+++ b/src/mednet/data/datamodule.py
@@ -481,10 +481,6 @@ class ConcatDataModule(lightning.LightningDataModule):
         for CPU memory.  Sufficient CPU memory must be available before you set
         this attribute to ``True``.  It is typically useful for relatively small
         datasets.
-    balance_sampler_by_class
-        If set, then modifies the random sampler used during training and
-        validation to balance sample picking probability, making sample
-        across classes **and** datasets equitable.
     batch_size
         Number of samples in every **training** batch (this parameter affects
         memory requirements for the network).  If the number of samples in the
@@ -529,7 +525,6 @@ class ConcatDataModule(lightning.LightningDataModule):
         database_name: str = "",
         split_name: str = "",
         cache_samples: bool = False,
-        balance_sampler_by_class: bool = False,
         batch_size: int = 1,
         batch_chunk_count: int = 1,
         drop_incomplete_batch: bool = False,
@@ -552,7 +547,6 @@ class ConcatDataModule(lightning.LightningDataModule):
 
         self.cache_samples = cache_samples
         self._train_sampler = None
-        self.balance_sampler_by_class = balance_sampler_by_class
 
         self._model_transforms: list[Transform] | None = None
 
@@ -667,40 +661,6 @@ class ConcatDataModule(lightning.LightningDataModule):
             )
             self._datasets = {}
 
-    @property
-    def balance_sampler_by_class(self) -> bool:
-        """Whether to balance samples across labels/datasets.
-
-        If set, then modifies the random sampler used during training
-        and validation to balance sample picking probability, making
-        sample across classes **and** datasets equitable.
-
-        .. warning::
-
-           This method does **NOT** balance the sampler per dataset, in case
-           multiple datasets compose the same training set. It only balances
-           samples acording to their ground-truth (labels).  If you'd like to
-           have samples balanced per dataset, then implement your own data
-           module inheriting from this one.
-
-        Returns
-        -------
-        bool
-            True if self._train_sample is set, else False.
-        """
-        return self._train_sampler is not None
-
-    @balance_sampler_by_class.setter
-    def balance_sampler_by_class(self, value: bool):
-        if value:
-            if "train" not in self._datasets:
-                self._setup_dataset("train")
-            self._train_sampler = _make_balanced_random_sampler(
-                self._datasets["train"],
-            )
-        else:
-            self._train_sampler = None
-
     def set_chunk_size(self, batch_size: int, batch_chunk_count: int) -> None:
         """Coherently set the batch-chunk-size after validation.
 
@@ -798,7 +758,7 @@ class ConcatDataModule(lightning.LightningDataModule):
         else:
             self._datasets[name] = _ConcatDataset(datasets)
 
-    def _val_dataset_keys(self) -> list[str]:
+    def val_dataset_keys(self) -> list[str]:
         """Return list of validation dataset names.
 
         Returns
@@ -836,11 +796,11 @@ class ConcatDataModule(lightning.LightningDataModule):
         """
 
         if stage == "fit":
-            for k in ["train"] + self._val_dataset_keys():
+            for k in ["train"] + self.val_dataset_keys():
                 self._setup_dataset(k)
 
         elif stage == "validate":
-            for k in self._val_dataset_keys():
+            for k in self.val_dataset_keys():
                 self._setup_dataset(k)
 
         elif stage == "test":
@@ -929,7 +889,7 @@ class ConcatDataModule(lightning.LightningDataModule):
                 self._datasets[k],
                 **validation_loader_opts,
             )
-            for k in self._val_dataset_keys()
+            for k in self.val_dataset_keys()
         }
 
     def test_dataloader(self) -> dict[str, DataLoader]:
diff --git a/src/mednet/engine/callbacks.py b/src/mednet/engine/callbacks.py
index 4c19ac55e5d325c6506dff393f70f0503e0d5682..b1b16f65ddca03d1d4fdd5a980999fca6aa18349 100644
--- a/src/mednet/engine/callbacks.py
+++ b/src/mednet/engine/callbacks.py
@@ -374,4 +374,5 @@ class LoggingCallback(lightning.pytorch.Callback):
             on_step=False,
             on_epoch=True,
             batch_size=batch[0].shape[0],
+            add_dataloader_idx=False,
         )
diff --git a/src/mednet/engine/trainer.py b/src/mednet/engine/trainer.py
index 23df024bcf5efbc73c2c2b6bc207df3d5889a77e..5ea8ccae4cfecb6987fb7affca1274e6bb474a4e 100644
--- a/src/mednet/engine/trainer.py
+++ b/src/mednet/engine/trainer.py
@@ -72,6 +72,8 @@ def run(
 
     output_folder.mkdir(parents=True, exist_ok=True)
 
+    model.configure_losses()
+
     from .loggers import CustomTensorboardLogger
 
     log_dir = "logs"
diff --git a/src/mednet/models/alexnet.py b/src/mednet/models/alexnet.py
index b4b9e723ad2a327f56e3051522654c3b6166ac89..75223c9a4b78e196d0c81dab20b566c4b5d32d4b 100644
--- a/src/mednet/models/alexnet.py
+++ b/src/mednet/models/alexnet.py
@@ -5,7 +5,6 @@
 import logging
 import typing
 
-import lightning.pytorch as pl
 import torch
 import torch.nn
 import torch.optim.optimizer
@@ -14,36 +13,29 @@ import torchvision.models as models
 import torchvision.transforms
 
 from ..data.typing import TransformSequence
+from .model import Model
 from .separate import separate
 from .transforms import RGB, SquareCenterPad
-from .typing import Checkpoint
 
 logger = logging.getLogger(__name__)
 
 
-class Alexnet(pl.LightningModule):
+class Alexnet(Model):
     """Alexnet module.
 
     Note: only usable with a normalized dataset
 
     Parameters
     ----------
-    train_loss
-        The loss to be used during the training.
-
-        .. warning::
-
-           The loss should be set to always return batch averages (as opposed
-           to the batch sum), as our logging system expects it so.
-    validation_loss
-        The loss to be used for validation (may be different from the training
-        loss).  If extra-validation sets are provided, the same loss will be
-        used throughout.
+    loss_type
+        The loss type to be instantiated and used for training and validation.
 
         .. warning::
 
            The loss should be set to always return batch averages (as opposed
            to the batch sum), as our logging system expects it so.
+    loss_arguments
+        Arguments to the loss.
     optimizer_type
         The type of optimizer to use for training.
     optimizer_arguments
@@ -60,15 +52,22 @@ class Alexnet(pl.LightningModule):
 
     def __init__(
         self,
-        train_loss: torch.nn.Module = torch.nn.BCEWithLogitsLoss(),
-        validation_loss: torch.nn.Module | None = None,
+        loss_type: type[torch.nn.Module] = torch.nn.BCEWithLogitsLoss,
+        loss_arguments: dict[str, typing.Any] = {},
         optimizer_type: type[torch.optim.Optimizer] = torch.optim.Adam,
         optimizer_arguments: dict[str, typing.Any] = {},
         augmentation_transforms: TransformSequence = [],
         pretrained: bool = False,
         num_classes: int = 1,
     ):
-        super().__init__()
+        super().__init__(
+            loss_type,
+            loss_arguments,
+            optimizer_type,
+            optimizer_arguments,
+            augmentation_transforms,
+            num_classes,
+        )
 
         self.name = "alexnet"
         self.num_classes = num_classes
@@ -79,17 +78,6 @@ class Alexnet(pl.LightningModule):
             RGB(),
         ]
 
-        self._train_loss = train_loss
-        self._validation_loss = (
-            validation_loss if validation_loss is not None else train_loss
-        )
-        self._optimizer_type = optimizer_type
-        self._optimizer_arguments = optimizer_arguments
-
-        self._augmentation_transforms = torchvision.transforms.Compose(
-            augmentation_transforms,
-        )
-
         self.pretrained = pretrained
 
         # Load pretrained model
@@ -109,36 +97,6 @@ class Alexnet(pl.LightningModule):
         x = self.normalizer(x)  # type: ignore
         return self.model_ft(x)
 
-    def on_save_checkpoint(self, checkpoint: Checkpoint) -> None:
-        """Perform actions during checkpoint saving (called by lightning).
-
-        Called by Lightning when saving a checkpoint to give you a chance to
-        store anything else you might want to save. Use on_load_checkpoint() to
-        restore what additional data is saved here.
-
-        Parameters
-        ----------
-        checkpoint
-            The checkpoint to save.
-        """
-
-        checkpoint["normalizer"] = self.normalizer
-
-    def on_load_checkpoint(self, checkpoint: Checkpoint) -> None:
-        """Perform actions during model loading (called by lightning).
-
-        If you saved something with on_save_checkpoint() this is your chance to
-        restore this.
-
-        Parameters
-        ----------
-        checkpoint
-            The loaded checkpoint.
-        """
-
-        logger.info("Restoring normalizer from checkpoint.")
-        self.normalizer = checkpoint["normalizer"]
-
     def set_normalizer(self, dataloader: torch.utils.data.DataLoader) -> None:
         """Initialize the normalizer for the current model.
 
@@ -201,16 +159,9 @@ class Alexnet(pl.LightningModule):
 
         # data forwarding on the existing network
         outputs = self(images)
-
         return self._validation_loss(outputs, labels.float())
 
     def predict_step(self, batch, batch_idx, dataloader_idx=0):
         outputs = self(batch[0])
         probabilities = torch.sigmoid(outputs)
         return separate((probabilities, batch[1]))
-
-    def configure_optimizers(self):
-        return self._optimizer_type(
-            self.parameters(),
-            **self._optimizer_arguments,
-        )
diff --git a/src/mednet/models/densenet.py b/src/mednet/models/densenet.py
index f7d1544164cf17c6f33f123d6d4bd02435722eb5..76df1ed64a7a73e99b86d18145169dd600601044 100644
--- a/src/mednet/models/densenet.py
+++ b/src/mednet/models/densenet.py
@@ -5,7 +5,6 @@
 import logging
 import typing
 
-import lightning.pytorch as pl
 import torch
 import torch.nn
 import torch.optim.optimizer
@@ -14,34 +13,27 @@ import torchvision.models as models
 import torchvision.transforms
 
 from ..data.typing import TransformSequence
+from .model import Model
 from .separate import separate
 from .transforms import RGB, SquareCenterPad
-from .typing import Checkpoint
 
 logger = logging.getLogger(__name__)
 
 
-class Densenet(pl.LightningModule):
+class Densenet(Model):
     """Densenet-121 module.
 
     Parameters
     ----------
-    train_loss
-        The loss to be used during the training.
-
-        .. warning::
-
-           The loss should be set to always return batch averages (as opposed
-           to the batch sum), as our logging system expects it so.
-    validation_loss
-        The loss to be used for validation (may be different from the training
-        loss).  If extra-validation sets are provided, the same loss will be
-        used throughout.
+    loss_type
+        The loss type to be instantiated and used for training and validation.
 
         .. warning::
 
            The loss should be set to always return batch averages (as opposed
            to the batch sum), as our logging system expects it so.
+    loss_arguments
+        Arguments to the loss.
     optimizer_type
         The type of optimizer to use for training.
     optimizer_arguments
@@ -60,8 +52,8 @@ class Densenet(pl.LightningModule):
 
     def __init__(
         self,
-        train_loss: torch.nn.Module = torch.nn.BCEWithLogitsLoss(),
-        validation_loss: torch.nn.Module | None = None,
+        loss_type: type[torch.nn.Module] = torch.nn.BCEWithLogitsLoss,
+        loss_arguments: dict[str, typing.Any] = {},
         optimizer_type: type[torch.optim.Optimizer] = torch.optim.Adam,
         optimizer_arguments: dict[str, typing.Any] = {},
         augmentation_transforms: TransformSequence = [],
@@ -69,7 +61,14 @@ class Densenet(pl.LightningModule):
         dropout: float = 0.1,
         num_classes: int = 1,
     ):
-        super().__init__()
+        super().__init__(
+            loss_type,
+            loss_arguments,
+            optimizer_type,
+            optimizer_arguments,
+            augmentation_transforms,
+            num_classes,
+        )
 
         self.name = "densenet-121"
         self.num_classes = num_classes
@@ -80,17 +79,6 @@ class Densenet(pl.LightningModule):
             RGB(),
         ]
 
-        self._train_loss = train_loss
-        self._validation_loss = (
-            validation_loss if validation_loss is not None else train_loss
-        )
-        self._optimizer_type = optimizer_type
-        self._optimizer_arguments = optimizer_arguments
-
-        self._augmentation_transforms = torchvision.transforms.Compose(
-            augmentation_transforms,
-        )
-
         self.pretrained = pretrained
 
         # Load pretrained model
@@ -112,36 +100,6 @@ class Densenet(pl.LightningModule):
         x = self.normalizer(x)  # type: ignore
         return self.model_ft(x)
 
-    def on_save_checkpoint(self, checkpoint: Checkpoint) -> None:
-        """Perform actions during checkpoint saving (called by lightning).
-
-        Called by Lightning when saving a checkpoint to give you a chance to
-        store anything else you might want to save. Use on_load_checkpoint() to
-        restore what additional data is saved here.
-
-        Parameters
-        ----------
-        checkpoint
-            The checkpoint to save.
-        """
-
-        checkpoint["normalizer"] = self.normalizer
-
-    def on_load_checkpoint(self, checkpoint: Checkpoint) -> None:
-        """Perform actions during model loading (called by lightning).
-
-        If you saved something with on_save_checkpoint() this is your chance to
-        restore this.
-
-        Parameters
-        ----------
-        checkpoint
-            The loaded checkpoint.
-        """
-
-        logger.info("Restoring normalizer from checkpoint.")
-        self.normalizer = checkpoint["normalizer"]
-
     def set_normalizer(self, dataloader: torch.utils.data.DataLoader) -> None:
         """Initialize the normalizer for the current model.
 
@@ -205,9 +163,3 @@ class Densenet(pl.LightningModule):
         outputs = self(batch[0])
         probabilities = torch.sigmoid(outputs)
         return separate((probabilities, batch[1]))
-
-    def configure_optimizers(self):
-        return self._optimizer_type(
-            self.parameters(),
-            **self._optimizer_arguments,
-        )
diff --git a/src/mednet/models/loss_weights.py b/src/mednet/models/loss_weights.py
index bf965790cade10d68e19c0bf372c9fa7bf4d5409..d04bdfea67ea391b7a18a066b6aa2b563a345aea 100644
--- a/src/mednet/models/loss_weights.py
+++ b/src/mednet/models/loss_weights.py
@@ -3,87 +3,180 @@
 # SPDX-License-Identifier: GPL-3.0-or-later
 
 import logging
+import typing
+from collections import Counter, defaultdict
 
 import torch
 import torch.utils.data
 
-from ..data.typing import DataLoader
-
 logger = logging.getLogger(__name__)
 
 
-def _get_label_weights(
-    dataloader: torch.utils.data.DataLoader,
-) -> torch.Tensor:
-    """Compute the weights of each class of a DataLoader.
+def compute_binary_weights(targets: torch.Tensor) -> torch.Tensor:
+    """Compute the positive weights when using binary targets.
 
-    This function inputs a pytorch DataLoader and computes the ratio between
-    number of negative and positive samples (scalar).  The weight can be used
-    to adjust minimisation criteria to in cases there is a huge data imbalance.
+    Parameters
+    ----------
+    targets
+        A tensor of integer values of length n.
 
-    It returns a vector with weights (inverse counts) for each label.
+    Returns
+    -------
+        The positive weights per class.
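+
+    Examples
+    --------
+    An illustrative sketch with three negative samples and one positive sample:
+
+    >>> compute_binary_weights(torch.tensor([0, 0, 0, 1]))
+    tensor([3.])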
+    """
+    class_sample_count = [
+        float((targets == t).sum().item())
+        for t in torch.unique(targets, sorted=True)
+    ]
+
+    # Divide negatives by positives
+    return torch.tensor(
+        [class_sample_count[0] / class_sample_count[1]],
+    ).reshape(-1)
+
+
+def compute_multiclass_weights(targets: torch.Tensor) -> torch.Tensor:
+    """Compute the positive weights when using exclusive, multiclass targets.
 
     Parameters
     ----------
-    dataloader
-        A DataLoader from which to compute the positive weights.  Entries must
-        be a dictionary which must contain a ``label`` key.
+    targets
+        A [C x n] tensor of integer values, where ``C`` is the number of
+        target classes and ``n`` the number of samples.
 
     Returns
     -------
-    torch.Tensor
-        The positive weight of each class in the dataset given as input.
+        The positive weights per class.
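+
+    Examples
+    --------
+    An illustrative sketch with three classes and four mutually-exclusive samples:
+
+    >>> compute_multiclass_weights(
+    ...     torch.tensor([[1, 0, 0, 1], [0, 1, 0, 0], [0, 0, 1, 0]])
+    ... )
+    tensor([0.5000, 0.7500, 0.7500])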
     """
 
-    targets = torch.tensor(
-        [sample for batch in dataloader for sample in batch[1]["label"]],
+    class_sample_count = torch.sum(targets, dim=1)
+    negative_class_sample_count = (
+        torch.full((targets.size()[0],), float(targets.size()[1]))
+        - class_sample_count
     )
 
-    # Binary labels
-    if len(list(targets.shape)) == 1:
-        class_sample_count = [
-            float((targets == t).sum().item())
-            for t in torch.unique(targets, sorted=True)
-        ]
+    return negative_class_sample_count / (
+        class_sample_count + negative_class_sample_count
+    )
 
-        # Divide negatives by positives
-        positive_weights = torch.tensor(
-            [class_sample_count[0] / class_sample_count[1]],
-        ).reshape(-1)
 
-    # Multiclass labels
-    else:
-        class_sample_count = torch.sum(targets, dim=0)
-        negative_class_sample_count = (
-            torch.full((targets.size()[1],), float(targets.size()[0]))
-            - class_sample_count
-        )
+def compute_non_exclusive_multiclass_weights(targets: torch.Tensor):
+    """Compute the positive weights when using non-exclusive, multiclass targets.
 
-        positive_weights = negative_class_sample_count / (
-            class_sample_count + negative_class_sample_count
-        )
+    Parameters
+    ----------
+    targets
+        A [C x n] tensor of integer values, where ``C`` is the number of
+        target classes and ``n`` the number of samples.
 
-    return positive_weights
+    Returns
+    -------
+        The positive weights per class.
+    """
+    raise ValueError(
+        "Computing weights of multi-class, non-exclusive labels is not yet supported."
+    )
+
+
+def is_multiclass_exclusive(targets: torch.Tensor) -> bool:
+    """Check whether each sample in a [C x n] integer target tensor belongs to exactly one class.
+
+    Parameters
+    ----------
+    targets
+        A [C x n] tensor of integer values, where ``C`` is the number of
+        target classes and ``n`` the number of samples.
+
+    Returns
+    -------
+        True if every sample belongs to exactly one class, False otherwise
+        (i.e. a sample may belong to more than one class).
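+
+    Examples
+    --------
+    Illustrative sketch with two classes and two samples:
+
+    >>> is_multiclass_exclusive(torch.tensor([[1, 0], [0, 1]]))
+    True
+    >>> is_multiclass_exclusive(torch.tensor([[1, 1], [0, 1]]))
+    False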
+    """
+    max_counts = []
+    transposed_targets = torch.transpose(targets, 0, 1)
+    for t in transposed_targets:
+        filtered_list = [i for i in t.tolist() if i != 2]
+        counts = Counter(filtered_list)
+        max_counts.append(max(counts.values()))
+
+    if set(max_counts) == {1}:
+        return True
+
+    return False
 
 
-def make_balanced_bcewithlogitsloss(
-    dataloader: DataLoader,
-) -> torch.nn.BCEWithLogitsLoss:
-    """Return a balanced binary-cross-entropy loss.
+def tensor_to_list(tensor) -> list[typing.Any]:
+    """Convert a torch.Tensor to a list.
 
-    The loss is weighted using the ratio between positives and total examples
-    available.
+    This is necessary, as ``torch.Tensor.tolist`` returns an int when the tensor contains a single value.
+
+    Parameters
+    ----------
+    tensor
+        The tensor to convert to a list.
+
+    Returns
+    -------
+        The tensor converted to a list.
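+
+    Examples
+    --------
+    Illustrative sketch of the scalar and the one-dimensional cases:
+
+    >>> tensor_to_list(torch.tensor(1))
+    [1]
+    >>> tensor_to_list(torch.tensor([0, 1]))
+    [0, 1]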
+    """
+
+    tensor = tensor.tolist()
+    if isinstance(tensor, int):
+        return [tensor]
+    return tensor
+
+
+def get_positive_weights(
+    dataloader: torch.utils.data.DataLoader,
+) -> torch.Tensor:
+    """Compute the weights of each class of a DataLoader.
+
+    This function inputs a pytorch DataLoader and computes the ratio between the
+    number of negative and positive samples for each class.  The weights can be
+    used to adjust the minimisation criterion in cases where there is a large
+    data imbalance.
+
+    It returns a vector with one weight (inverse count) per label.
 
     Parameters
     ----------
     dataloader
-        The DataLoader to use to compute the BCE weights.
+        A DataLoader from which to compute the positive weights.  Entries must
+        be dictionaries containing a ``label`` key.
 
     Returns
     -------
-    torch.nn.BCEWithLogitsLoss
-        An instance of the weighted loss.
+        The positive weight of each class in the dataset given as input.
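+
+    Examples
+    --------
+    A hedged sketch of the batch structure this function expects (the image
+    tensor is illustrative; only the ``label`` entry is read)::
+
+        batch = (images, {"label": torch.tensor([0, 1, 1, 0])})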
     """
 
-    weights = _get_label_weights(dataloader)
-    return torch.nn.BCEWithLogitsLoss(pos_weight=weights)
+    targets = defaultdict(list)
+
+    for batch in dataloader:
+        for class_idx, class_targets in enumerate(batch[1]["label"]):
+            # Targets are either a single tensor (binary case) or a list of tensors (multilabel)
+            if isinstance(batch[1]["label"], list):
+                targets[class_idx].extend(tensor_to_list(class_targets))
+            else:
+                targets[0].extend(tensor_to_list(class_targets))
+
+    targets_list = []
+    for k in sorted(list(targets.keys())):
+        targets_list.append(targets[k])
+
+    targets_tensor = torch.tensor(targets_list)
+
+    if targets_tensor.shape[0] == 1:
+        logger.info("Computing positive weights assuming binary labels.")
+        positive_weights = compute_binary_weights(targets_tensor)
+    else:
+        if is_multiclass_exclusive(targets_tensor):
+            logger.info(
+                "Computing positive weights assuming multiclass, exclusive labels."
+            )
+            positive_weights = compute_multiclass_weights(targets_tensor)
+        else:
+            logger.info(
+                "Computing positive weights assuming multiclass, non-exclusive labels."
+            )
+            positive_weights = compute_non_exclusive_multiclass_weights(
+                targets_tensor
+            )
+
+    return positive_weights
diff --git a/src/mednet/models/model.py b/src/mednet/models/model.py
new file mode 100644
index 0000000000000000000000000000000000000000..57c1b618b1c0943496a4e799552989c03f2e3532
--- /dev/null
+++ b/src/mednet/models/model.py
@@ -0,0 +1,173 @@
+# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
+#
+# SPDX-License-Identifier: GPL-3.0-or-later
+
+import logging
+import typing
+
+import lightning.pytorch as pl
+import torch
+import torch.nn
+import torch.optim.optimizer
+import torch.utils.data
+import torchvision.transforms
+
+from ..data.typing import TransformSequence
+from .loss_weights import get_positive_weights
+from .typing import Checkpoint
+
+logger = logging.getLogger(__name__)
+
+
+class Model(pl.LightningModule):
+    """Base class for models.
+
+    Parameters
+    ----------
+    loss_type
+        The loss type to be instantiated and used for training and validation.
+
+        .. warning::
+
+           The loss should be set to always return batch averages (as opposed
+           to the batch sum), as our logging system expects it so.
+    loss_arguments
+        Arguments to the loss.
+    optimizer_type
+        The type of optimizer to use for training.
+    optimizer_arguments
+        Arguments to the optimizer after ``params``.
+    augmentation_transforms
+        An optional sequence of torch modules containing transforms to be
+        applied on the input **before** it is fed into the network.
+    num_classes
+        Number of outputs (classes) for this model.
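+
+    Examples
+    --------
+    A hedged sketch of how a concrete subclass is typically configured (the
+    values mirror the shipped ``pasa`` model config)::
+
+        from torch.nn import BCEWithLogitsLoss
+        from torch.optim import Adam
+
+        from mednet.models.pasa import Pasa
+
+        model = Pasa(
+            loss_type=BCEWithLogitsLoss,
+            optimizer_type=Adam,
+            optimizer_arguments=dict(lr=8e-5),
+        )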
+    """
+
+    def __init__(
+        self,
+        loss_type: type[torch.nn.Module] = torch.nn.BCEWithLogitsLoss,
+        loss_arguments: dict[str, typing.Any] = {},
+        optimizer_type: type[torch.optim.Optimizer] = torch.optim.Adam,
+        optimizer_arguments: dict[str, typing.Any] = {},
+        augmentation_transforms: TransformSequence = [],
+        num_classes: int = 1,
+    ):
+        super().__init__()
+
+        self.name = "model"
+        self.num_classes = num_classes
+
+        self.model_transforms: TransformSequence = []
+
+        self._loss_type = loss_type
+
+        self._train_loss = None
+        # keep independent copies so balancing one loss does not mutate the other
+        self._train_loss_arguments = dict(loss_arguments)
+
+        self._validation_loss = None
+        self._validation_loss_arguments = dict(loss_arguments)
+
+        self._optimizer_type = optimizer_type
+        self._optimizer_arguments = optimizer_arguments
+
+        self._augmentation_transforms = torchvision.transforms.Compose(
+            augmentation_transforms,
+        )
+
+    def forward(self, x):
+        raise NotImplementedError
+
+    def on_save_checkpoint(self, checkpoint: Checkpoint) -> None:
+        """Perform actions during checkpoint saving (called by lightning).
+
+        Called by Lightning when saving a checkpoint to give you a chance to
+        store anything else you might want to save. Use on_load_checkpoint() to
+        restore what additional data is saved here.
+
+        Parameters
+        ----------
+        checkpoint
+            The checkpoint to save.
+        """
+
+        checkpoint["normalizer"] = self.normalizer
+
+    def on_load_checkpoint(self, checkpoint: Checkpoint) -> None:
+        """Perform actions during model loading (called by lightning).
+
+        If you saved something with on_save_checkpoint() this is your chance to
+        restore this.
+
+        Parameters
+        ----------
+        checkpoint
+            The loaded checkpoint.
+        """
+
+        logger.info("Restoring normalizer from checkpoint.")
+        self.normalizer = checkpoint["normalizer"]
+
+    def set_normalizer(self, dataloader: torch.utils.data.DataLoader) -> None:
+        """Initialize the input normalizer for the current model.
+
+        Parameters
+        ----------
+        dataloader
+            A torch Dataloader from which to compute the mean and std.
+        """
+
+        from .normalizer import make_z_normalizer
+
+        logger.info(
+            f"Uninitialised {self.name} model - "
+            f"computing z-norm factors from train dataloader.",
+        )
+        self.normalizer = make_z_normalizer(dataloader)
+
+    def training_step(self, batch, _):
+        raise NotImplementedError
+
+    def validation_step(self, batch, batch_idx, dataloader_idx=0):
+        raise NotImplementedError
+
+    def predict_step(self, batch, batch_idx, dataloader_idx=0):
+        raise NotImplementedError
+
+    def configure_losses(self):
+        """Create the training and validation losses from the configured type and arguments."""
+        self._train_loss = self._loss_type(**self._train_loss_arguments)
+        self._validation_loss = self._loss_type(
+            **self._validation_loss_arguments
+        )
+
+    def configure_optimizers(self):
+        return self._optimizer_type(
+            self.parameters(),
+            **self._optimizer_arguments,
+        )
+
+    def balance_losses(self, datamodule) -> None:
+        """Balance the loss based on the distribution of targets in the datamodule, if the loss supports it (contains a 'pos_weight' attribute).
+
+        Parameters
+        ----------
+        datamodule
+            Instance of a datamodule.
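+
+        Examples
+        --------
+        A hedged usage sketch mirroring the training script: balance first,
+        then instantiate the (now weighted) losses::
+
+            model.balance_losses(datamodule)
+            model.configure_losses()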
+        """
+
+        try:
+            getattr(self._loss_type(), "pos_weight")
+        except AttributeError:
+            logger.warning(
+                f"Loss {self._loss_type} does not posess a 'pos_weight' attribute and will not be balanced."
+            )
+        else:
+            logger.info(f"Balancing training loss {self._loss_type}.")
+            train_weights = get_positive_weights(datamodule.train_dataloader())
+            self._train_loss_arguments["pos_weight"] = train_weights
+
+            logger.info(f"Balancing validation loss {self._loss_type}.")
+            validation_weights = get_positive_weights(
+                datamodule.val_dataloader()["validation"]
+            )
+            self._validation_loss_arguments["pos_weight"] = validation_weights
diff --git a/src/mednet/models/pasa.py b/src/mednet/models/pasa.py
index 16a71f73c93bc9ab4dd8a65d2c9da7d96d27a36e..e9147683b08f8d8b532396544eece7829c23fd92 100644
--- a/src/mednet/models/pasa.py
+++ b/src/mednet/models/pasa.py
@@ -5,7 +5,6 @@
 import logging
 import typing
 
-import lightning.pytorch as pl
 import torch
 import torch.nn
 import torch.nn.functional as F  # noqa: N812
@@ -14,14 +13,14 @@ import torch.utils.data
 import torchvision.transforms
 
 from ..data.typing import TransformSequence
+from .model import Model
 from .separate import separate
 from .transforms import Grayscale, SquareCenterPad
-from .typing import Checkpoint
 
 logger = logging.getLogger(__name__)
 
 
-class Pasa(pl.LightningModule):
+class Pasa(Model):
     """Implementation of CNN by Pasa and others.
 
     Simple CNN for classification based on paper by [PASA-2019]_.
@@ -31,22 +30,15 @@ class Pasa(pl.LightningModule):
 
     Parameters
     ----------
-    train_loss
-        The loss to be used during the training.
-
-        .. warning::
-
-           The loss should be set to always return batch averages (as opposed
-           to the batch sum), as our logging system expects it so.
-    validation_loss
-        The loss to be used for validation (may be different from the training
-        loss).  If extra-validation sets are provided, the same loss will be
-        used throughout.
+    loss_type
+        The loss type to be instantiated and used for training and validation.
 
         .. warning::
 
            The loss should be set to always return batch averages (as opposed
            to the batch sum), as our logging system expects it so.
+    loss_arguments
+        Arguments to the loss.
     optimizer_type
         The type of optimizer to use for training.
     optimizer_arguments
@@ -60,14 +52,21 @@ class Pasa(pl.LightningModule):
 
     def __init__(
         self,
-        train_loss: torch.nn.Module = torch.nn.BCEWithLogitsLoss(),
-        validation_loss: torch.nn.Module | None = None,
+        loss_type: type[torch.nn.Module] = torch.nn.BCEWithLogitsLoss,
+        loss_arguments: dict[str, typing.Any] = {},
         optimizer_type: type[torch.optim.Optimizer] = torch.optim.Adam,
         optimizer_arguments: dict[str, typing.Any] = {},
         augmentation_transforms: TransformSequence = [],
         num_classes: int = 1,
     ):
-        super().__init__()
+        super().__init__(
+            loss_type,
+            loss_arguments,
+            optimizer_type,
+            optimizer_arguments,
+            augmentation_transforms,
+            num_classes,
+        )
 
         self.name = "pasa"
         self.num_classes = num_classes
@@ -82,17 +81,6 @@ class Pasa(pl.LightningModule):
             ),
         ]
 
-        self._train_loss = train_loss
-        self._validation_loss = (
-            validation_loss if validation_loss is not None else train_loss
-        )
-        self._optimizer_type = optimizer_type
-        self._optimizer_arguments = optimizer_arguments
-
-        self._augmentation_transforms = torchvision.transforms.Compose(
-            augmentation_transforms,
-        )
-
         # First convolution block
         self.fc1 = torch.nn.Conv2d(1, 4, (3, 3), (2, 2), (1, 1))
         self.fc2 = torch.nn.Conv2d(4, 16, (3, 3), (2, 2), (1, 1))
@@ -213,53 +201,6 @@ class Pasa(pl.LightningModule):
 
         # x = F.log_softmax(x, dim=1) # 0 is batch size
 
-    def on_save_checkpoint(self, checkpoint: Checkpoint) -> None:
-        """Perform actions during checkpoint saving (called by lightning).
-
-        Called by Lightning when saving a checkpoint to give you a chance to
-        store anything else you might want to save. Use on_load_checkpoint() to
-        restore what additional data is saved here.
-
-        Parameters
-        ----------
-        checkpoint
-            The checkpoint to save.
-        """
-
-        checkpoint["normalizer"] = self.normalizer
-
-    def on_load_checkpoint(self, checkpoint: Checkpoint) -> None:
-        """Perform actions during model loading (called by lightning).
-
-        If you saved something with on_save_checkpoint() this is your chance to
-        restore this.
-
-        Parameters
-        ----------
-        checkpoint
-            The loaded checkpoint.
-        """
-
-        logger.info("Restoring normalizer from checkpoint.")
-        self.normalizer = checkpoint["normalizer"]
-
-    def set_normalizer(self, dataloader: torch.utils.data.DataLoader) -> None:
-        """Initialize the input normalizer for the current model.
-
-        Parameters
-        ----------
-        dataloader
-            A torch Dataloader from which to compute the mean and std.
-        """
-
-        from .normalizer import make_z_normalizer
-
-        logger.info(
-            f"Uninitialised {self.name} model - "
-            f"computing z-norm factors from train dataloader.",
-        )
-        self.normalizer = make_z_normalizer(dataloader)
-
     def training_step(self, batch, _):
         images = batch[0]
         labels = batch[1]["label"]
@@ -285,16 +226,9 @@ class Pasa(pl.LightningModule):
 
         # data forwarding on the existing network
         outputs = self(images)
-
         return self._validation_loss(outputs, labels.float())
 
     def predict_step(self, batch, batch_idx, dataloader_idx=0):
         outputs = self(batch[0])
         probabilities = torch.sigmoid(outputs)
         return separate((probabilities, batch[1]))
-
-    def configure_optimizers(self):
-        return self._optimizer_type(
-            self.parameters(),
-            **self._optimizer_arguments,
-        )
diff --git a/src/mednet/scripts/train.py b/src/mednet/scripts/train.py
index 68a4e7e721f412fbed430774491147f9c1130577..e83ae359b0be523565f669c57bf1de41520ccac7 100644
--- a/src/mednet/scripts/train.py
+++ b/src/mednet/scripts/train.py
@@ -296,10 +296,8 @@ def train(
     # of class examples available in the training set.  Also affects the
     # validation loss if a validation set is available on the DataModule.
     if balance_classes:
-        logger.info("Applying DataModule train sampler balancing...")
-        datamodule.balance_sampler_by_class = True
-        # logger.info("Applying train/valid loss balancing...")
-        # model.balance_losses_by_class(datamodule)
+        logger.info("Applying train/valid loss balancing...")
+        model.balance_losses(datamodule)
     else:
         logger.info(
             "Skipping sample class/dataset ownership balancing on user request",
diff --git a/tests/test_cli.py b/tests/test_cli.py
index 4ee0e2c6ab454161dbf871030b30b989593d1c85..a512d6add392dc09439daaca7e0d8f9b01fbc443 100644
--- a/tests/test_cli.py
+++ b/tests/test_cli.py
@@ -241,8 +241,7 @@ def test_train_pasa_montgomery(temporary_basedir):
         keywords = {
             r"^Loading dataset:`train` without caching. Trade-off: CPU RAM usage: less | Disk I/O: more.$": 1,
             r"^Loading dataset:`validation` without caching. Trade-off: CPU RAM usage: less | Disk I/O: more.$": 1,
-            r"^Applying DataModule train sampler balancing...$": 1,
-            r"^Balancing samples from dataset using metadata targets `label`$": 1,
+            r"^Applying train/valid loss balancing...$": 1,
             r"^Training for at most 1 epochs.$": 1,
             r"^Uninitialised pasa model - computing z-norm factors from train dataloader.$": 1,
             r"^Writing run metadata at.*$": 1,
@@ -323,8 +322,7 @@ def test_train_pasa_montgomery_from_checkpoint(temporary_basedir):
         keywords = {
             r"^Loading dataset:`train` without caching. Trade-off: CPU RAM usage: less | Disk I/O: more.$": 1,
             r"^Loading dataset:`validation` without caching. Trade-off: CPU RAM usage: less | Disk I/O: more.$": 1,
-            r"^Applying DataModule train sampler balancing...$": 1,
-            r"^Balancing samples from dataset using metadata targets `label`$": 1,
+            r"^Applying train/valid loss balancing...$": 1,
             r"^Training for at most 2 epochs.$": 1,
             r"^Resuming from epoch 0 \(checkpoint file: .*$": 1,
             r"^Writing run metadata at.*$": 1,