diff --git a/src/mednet/data/datamodule.py b/src/mednet/data/datamodule.py
index c73bb27932c52e2061134a7f007256c7f6292161..2de54fdf96d95ac4f553c43b548d168dcf2bace4 100644
--- a/src/mednet/data/datamodule.py
+++ b/src/mednet/data/datamodule.py
@@ -481,10 +481,6 @@ class ConcatDataModule(lightning.LightningDataModule):
         for CPU memory.  Sufficient CPU memory must be available before you set
         this attribute to ``True``.  It is typically useful for relatively small
         datasets.
-    balance_sampler_by_class
-        If set, then modifies the random sampler used during training and
-        validation to balance sample picking probability, making sample
-        across classes **and** datasets equitable.
     batch_size
         Number of samples in every **training** batch (this parameter affects
         memory requirements for the network).  If the number of samples in the
@@ -529,7 +525,6 @@ class ConcatDataModule(lightning.LightningDataModule):
         database_name: str = "",
         split_name: str = "",
         cache_samples: bool = False,
-        balance_sampler_by_class: bool = False,
         batch_size: int = 1,
         batch_chunk_count: int = 1,
         drop_incomplete_batch: bool = False,
@@ -552,7 +547,6 @@ class ConcatDataModule(lightning.LightningDataModule):
 
         self.cache_samples = cache_samples
         self._train_sampler = None
-        self.balance_sampler_by_class = balance_sampler_by_class
 
         self._model_transforms: list[Transform] | None = None
 
@@ -667,40 +661,6 @@ class ConcatDataModule(lightning.LightningDataModule):
             )
             self._datasets = {}
 
-    @property
-    def balance_sampler_by_class(self) -> bool:
-        """Whether to balance samples across labels/datasets.
-
-        If set, then modifies the random sampler used during training
-        and validation to balance sample picking probability, making
-        sample across classes **and** datasets equitable.
-
-        .. warning::
-
-           This method does **NOT** balance the sampler per dataset, in case
-           multiple datasets compose the same training set. It only balances
-           samples acording to their ground-truth (labels).  If you'd like to
-           have samples balanced per dataset, then implement your own data
-           module inheriting from this one.
-
-        Returns
-        -------
-        bool
-            True if self._train_sample is set, else False.
-        """
-        return self._train_sampler is not None
-
-    @balance_sampler_by_class.setter
-    def balance_sampler_by_class(self, value: bool):
-        if value:
-            if "train" not in self._datasets:
-                self._setup_dataset("train")
-            self._train_sampler = _make_balanced_random_sampler(
-                self._datasets["train"],
-            )
-        else:
-            self._train_sampler = None
-
     def set_chunk_size(self, batch_size: int, batch_chunk_count: int) -> None:
         """Coherently set the batch-chunk-size after validation.
 
diff --git a/src/mednet/models/loss_weights.py b/src/mednet/models/loss_weights.py
index bf965790cade10d68e19c0bf372c9fa7bf4d5409..32f1e33e459b0f7b08a9b4f13eb72e05ee65492d 100644
--- a/src/mednet/models/loss_weights.py
+++ b/src/mednet/models/loss_weights.py
@@ -34,7 +34,6 @@ def _get_label_weights(
     torch.Tensor
         The positive weight of each class in the dataset given as input.
     """
-
     targets = torch.tensor(
         [sample for batch in dataloader for sample in batch[1]["label"]],
     )
diff --git a/src/mednet/models/model.py b/src/mednet/models/model.py
index 50e314bb3380246f738a44b720b851ba97d9502c..a0b3701eae4376481a3875e5c53a91ed6976fe05 100644
--- a/src/mednet/models/model.py
+++ b/src/mednet/models/model.py
@@ -13,6 +13,7 @@ import torch.utils.data
 import torchvision.transforms
 
 from ..data.typing import TransformSequence
+from .loss_weights import _get_label_weights
 from .typing import Checkpoint
 
 logger = logging.getLogger(__name__)
@@ -141,3 +142,40 @@ class Model(pl.LightningModule):
             self.parameters(),
             **self._optimizer_arguments,
         )
+
+    def balance_losses(self, datamodule) -> None:
+        """Balance the loss based on the distribution of targets in the datamodule, if the loss function supports it.
+
+        Parameters
+        ----------
+        datamodule
+            Instance of a datamodule.
+        """
+        logger.info(f"Balancing training loss function {self._train_loss}.")
+        try:
+            getattr(self._train_loss, "pos_weight")
+        except AttributeError:
+            logger.warning(
+                "Training loss does not posess a 'pos_weight' attribute and will not be balanced."
+            )
+        else:
+            train_weights = _get_label_weights(datamodule.train_dataloader())
+            setattr(self._train_loss, "pos_weight", train_weights)
+
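+        # ``val_dataloader()`` returns a mapping of named loaders; the
+        # primary validation split is stored under the "validation" key.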
+        logger.info(
+            f"Balancing validation loss function {self._validation_loss}."
+        )
+        if hasattr(self._validation_loss, "pos_weight"):
+            validation_weights = _get_label_weights(
+                datamodule.val_dataloader()["validation"]
+            )
+            self._validation_loss.pos_weight = validation_weights
+        else:
+            logger.warning(
+                "Validation loss does not possess a 'pos_weight' attribute "
+                "and will not be balanced."
+            )
diff --git a/src/mednet/scripts/train.py b/src/mednet/scripts/train.py
index 68a4e7e721f412fbed430774491147f9c1130577..e83ae359b0be523565f669c57bf1de41520ccac7 100644
--- a/src/mednet/scripts/train.py
+++ b/src/mednet/scripts/train.py
@@ -296,10 +296,10 @@ def train(
     # of class examples available in the training set.  Also affects the
     # validation loss if a validation set is available on the DataModule.
     if balance_classes:
-        logger.info("Applying DataModule train sampler balancing...")
-        datamodule.balance_sampler_by_class = True
-        # logger.info("Applying train/valid loss balancing...")
-        # model.balance_losses_by_class(datamodule)
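+        # Compensate for class imbalance in the loss itself (via its
+        # ``pos_weight`` attribute) rather than by re-weighting the sampler.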
+        logger.info("Applying train/valid loss balancing...")
+        model.balance_losses(datamodule)
     else:
         logger.info(
             "Skipping sample class/dataset ownership balancing on user request",