diff --git a/src/mednet/models/loss_weights.py b/src/mednet/models/loss_weights.py
index 32f1e33e459b0f7b08a9b4f13eb72e05ee65492d..f91e1eda1f563f5e1f53922576605704f4071e64 100644
--- a/src/mednet/models/loss_weights.py
+++ b/src/mednet/models/loss_weights.py
@@ -3,86 +3,180 @@
 # SPDX-License-Identifier: GPL-3.0-or-later
 
 import logging
+import typing
+from collections import Counter, defaultdict
 
 import torch
 import torch.utils.data
 
-from ..data.typing import DataLoader
-
 logger = logging.getLogger(__name__)
 
 
-def _get_label_weights(
-    dataloader: torch.utils.data.DataLoader,
-) -> torch.Tensor:
-    """Compute the weights of each class of a DataLoader.
+def compute_binary_weights(targets: torch.Tensor) -> torch.Tensor:
+    """Compute the positive weights when using binary targets.
 
-    This function inputs a pytorch DataLoader and computes the ratio between
-    number of negative and positive samples (scalar).  The weight can be used
-    to adjust minimisation criteria to in cases there is a huge data imbalance.
+    Parameters
+    ----------
+    targets
+        A 1-D tensor of integer values (0 or 1) of length n.
 
-    It returns a vector with weights (inverse counts) for each label.
+    Returns
+    -------
+        The positive weights per class.
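+
+    Examples
+    --------
+    A minimal, illustrative doctest (assumed counts: three negatives, one
+    positive, so the positive weight is 3/1):
+
+    >>> import torch
+    >>> compute_binary_weights(torch.tensor([0, 0, 0, 1]))
+    tensor([3.])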
+    """
+    class_sample_count = [
+        float((targets == t).sum().item())
+        for t in torch.unique(targets, sorted=True)
+    ]
+
+    # Divide negatives by positives
+    return torch.tensor(
+        [class_sample_count[0] / class_sample_count[1]],
+    ).reshape(-1)
+
+
+def compute_multiclass_weights(targets: torch.Tensor) -> torch.Tensor:
+    """Compute the positive weights when using exclusive, multiclass targets.
 
     Parameters
     ----------
-    dataloader
-        A DataLoader from which to compute the positive weights.  Entries must
-        be a dictionary which must contain a ``label`` key.
+    targets
+        A [C x n] tensor of integer values, where `C` is the number of target
+        classes and `n` the number of samples.
 
     Returns
     -------
-    torch.Tensor
-        The positive weight of each class in the dataset given as input.
+        The positive weights per class.
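+
+    Examples
+    --------
+    A minimal, illustrative doctest (an assumed [2 x 4] one-hot target
+    matrix; each weight is the negative count divided by the sample count):
+
+    >>> import torch
+    >>> compute_multiclass_weights(torch.tensor([[1, 1, 1, 0], [0, 0, 0, 1]]))
+    tensor([0.2500, 0.7500])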
     """
-    targets = torch.tensor(
-        [sample for batch in dataloader for sample in batch[1]["label"]],
+
+    class_sample_count = torch.sum(targets, dim=1)
+    negative_class_sample_count = (
+        torch.full((targets.size()[0],), float(targets.size()[1]))
+        - class_sample_count
     )
 
-    # Binary labels
-    if len(list(targets.shape)) == 1:
-        class_sample_count = [
-            float((targets == t).sum().item())
-            for t in torch.unique(targets, sorted=True)
-        ]
+    return negative_class_sample_count / (
+        class_sample_count + negative_class_sample_count
+    )
 
-        # Divide negatives by positives
-        positive_weights = torch.tensor(
-            [class_sample_count[0] / class_sample_count[1]],
-        ).reshape(-1)
 
-    # Multiclass labels
-    else:
-        class_sample_count = torch.sum(targets, dim=0)
-        negative_class_sample_count = (
-            torch.full((targets.size()[1],), float(targets.size()[0]))
-            - class_sample_count
-        )
+def compute_non_exclusive_multiclass_weights(targets: torch.Tensor) -> torch.Tensor:
+    """Compute the positive weights when using non-exclusive, multiclass targets.
 
-        positive_weights = negative_class_sample_count / (
-            class_sample_count + negative_class_sample_count
-        )
+    Parameters
+    ----------
+    targets
+        A [C x n] tensor of integer values, where `C` is the number of target
+        classes and `n` the number of samples.
 
-    return positive_weights
+    Returns
+    -------
+        The positive weights per class.
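+
+    Examples
+    --------
+    This helper is a placeholder and currently always raises, as the doctest
+    below illustrates:
+
+    >>> import torch
+    >>> compute_non_exclusive_multiclass_weights(torch.zeros(2, 2))
+    Traceback (most recent call last):
+      ...
+    ValueError: Computing weights of multi-class, non-exclusive labels is not yet supported.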
+    """
+    raise ValueError(
+        "Computing weights of multi-class, non-exclusive labels is not yet supported."
+    )
 
 
-def make_balanced_bcewithlogitsloss(
-    dataloader: DataLoader,
-) -> torch.nn.BCEWithLogitsLoss:
-    """Return a balanced binary-cross-entropy loss.
+def is_multiclass_exclusive(targets: torch.Tensor) -> bool:
+    """Given a [C x n] tensor of integer targets, check whether each sample belongs to a single class.
 
-    The loss is weighted using the ratio between positives and total examples
-    available.
+    Parameters
+    ----------
+    targets
+        A [C x n] tensor of integer values, where `C` is the number of target classes and `n` the number of samples.
+
+    Returns
+    -------
+        True if every sample belongs to exactly one class, False otherwise
+        (i.e. at least one sample belongs to multiple classes).
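+
+    Examples
+    --------
+    Illustrative doctests over assumed [2 x 2] targets:
+
+    >>> import torch
+    >>> is_multiclass_exclusive(torch.tensor([[1, 0], [0, 1]]))
+    True
+    >>> is_multiclass_exclusive(torch.tensor([[1, 1], [0, 1]]))
+    False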
+    """
+    max_counts = []
+    # Iterate over samples (columns of the [C x n] target matrix)
+    transposed_targets = torch.transpose(targets, 0, 1)
+    for t in transposed_targets:
+        # Drop entries equal to 2 (assumed to mark ignored targets) before counting
+        filtered_list = [i for i in t.tolist() if i != 2]
+        counts = Counter(filtered_list)
+        max_counts.append(max(counts.values()))
+
+    # Exclusive if no target value repeats within any single sample
+    return set(max_counts) == {1}
+
+
+def tensor_to_list(tensor: torch.Tensor) -> list[typing.Any]:
+    """Convert a torch.Tensor to a list.
+
+    This is necessary, as ``torch.Tensor.tolist`` returns an int when the
+    tensor contains a single value.
+
+    Parameters
+    ----------
+    tensor
+        The tensor to convert to a list.
+
+    Returns
+    -------
+        The tensor converted to a list.
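+
+    Examples
+    --------
+    Doctests covering the scalar and vector cases:
+
+    >>> import torch
+    >>> tensor_to_list(torch.tensor(1))
+    [1]
+    >>> tensor_to_list(torch.tensor([0, 1]))
+    [0, 1]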
+    """
+
+    tensor = tensor.tolist()
+    if isinstance(tensor, int):
+        return [tensor]
+    return tensor
+
+
+def get_positive_weights(
+    dataloader: torch.utils.data.DataLoader,
+) -> torch.Tensor:
+    """Compute the weights of each class of a DataLoader.
+
+    This function inputs a PyTorch DataLoader and computes the ratio between
+    the number of negative and positive samples (a scalar).  The weight can be
+    used to adjust the minimisation criterion in cases of severe data
+    imbalance.
+
+    It returns a vector with weights (inverse counts) for each label.
 
     Parameters
     ----------
     dataloader
-        The DataLoader to use to compute the BCE weights.
+        A DataLoader from which to compute the positive weights.  Each entry
+        must be a dictionary containing a ``label`` key.
 
     Returns
     -------
-    torch.nn.BCEWithLogitsLoss
-        An instance of the weighted loss.
+        The positive weight of each class in the dataset given as input.
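+
+    Examples
+    --------
+    A minimal sketch for the binary case, assuming a dataset that yields
+    ``(input, {"label": ...})`` pairs (three negatives, one positive):
+
+    >>> import torch, torch.utils.data
+    >>> data = [(torch.empty(1), {"label": torch.tensor(t)}) for t in (0, 0, 0, 1)]
+    >>> loader = torch.utils.data.DataLoader(data, batch_size=2)
+    >>> get_positive_weights(loader)
+    tensor([3.])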
     """
 
-    weights = _get_label_weights(dataloader)
-    return torch.nn.BCEWithLogitsLoss(pos_weight=weights)
+    targets = defaultdict(list)
+
+    for batch in dataloader:
+        for class_idx, class_targets in enumerate(batch[1]["label"]):
+            # Targets are either a single tensor (binary case) or a list of tensors (multilabel)
+            if isinstance(batch[1]["label"], list):
+                targets[class_idx].extend(tensor_to_list(class_targets))
+            else:
+                targets[0].extend(tensor_to_list(class_targets))
+
+    targets_list = [targets[k] for k in sorted(targets.keys())]
+
+    targets_tensor = torch.tensor(targets_list)
+
+    # Binary targets are all collected under class index 0 above, so the
+    # resulting tensor is [1 x n]; check the class dimension rather than the rank.
+    if targets_tensor.shape[0] == 1:
+        logger.info("Computing positive weights assuming binary labels.")
+        positive_weights = compute_binary_weights(targets_tensor.flatten())
+    else:
+        if is_multiclass_exclusive(targets_tensor):
+            logger.info(
+                "Computing positive weights assuming multiclass, exclusive labels."
+            )
+            positive_weights = compute_multiclass_weights(targets_tensor)
+        else:
+            logger.info(
+                "Computing positive weights assuming multiclass, non-exclusive labels."
+            )
+            positive_weights = compute_non_exclusive_multiclass_weights(
+                targets_tensor
+            )
+
+    return positive_weights
diff --git a/src/mednet/models/model.py b/src/mednet/models/model.py
index 155ff19c0c0887b9a0504b414eb93242b7b0ff63..e5a13dc8336aa28335bba58d26e98d0814ac9ed8 100644
--- a/src/mednet/models/model.py
+++ b/src/mednet/models/model.py
@@ -13,7 +13,7 @@ import torch.utils.data
 import torchvision.transforms
 
 from ..data.typing import TransformSequence
-from .loss_weights import _get_label_weights
+from .loss_weights import get_positive_weights
 from .typing import Checkpoint
 
 logger = logging.getLogger(__name__)
@@ -152,6 +152,7 @@ class Model(pl.LightningModule):
         datamodule
             Instance of a datamodule.
         """
+
         logger.info(f"Balancing training loss function {self._train_loss}.")
         try:
             getattr(self._train_loss, "pos_weight")
@@ -160,7 +161,7 @@ class Model(pl.LightningModule):
                 "Training loss does not posess a 'pos_weight' attribute and will not be balanced."
             )
         else:
-            train_weights = _get_label_weights(datamodule.train_dataloader())
+            train_weights = get_positive_weights(datamodule.train_dataloader())
             setattr(self._train_loss, "pos_weight", train_weights)
 
         logger.info(
@@ -185,7 +186,7 @@ class Model(pl.LightningModule):
             )
 
             for val_dataset_key in datamodule_validation_keys:
-                validation_weights = _get_label_weights(
+                validation_weights = get_positive_weights(
                     datamodule.val_dataloader()[val_dataset_key]
                 )
                 new_validation_losses.append(