diff --git a/src/mednet/libs/classification/engine/saliency/generator.py b/src/mednet/libs/classification/engine/saliency/generator.py
index 28a11674d027f0a92ba0bb9c5a6096f475899214..2b9badf19390c4e564041dec0a28570ad853da8e 100644
--- a/src/mednet/libs/classification/engine/saliency/generator.py
+++ b/src/mednet/libs/classification/engine/saliency/generator.py
@@ -212,14 +212,14 @@ def run(
 
         for sample in tqdm.tqdm(v, desc="samples", leave=False, disable=None):
             name = sample[1]["name"][0]
-            label = sample[1]["label"].item()
+            target = sample[1]["target"].item()
             image = sample[0].to(
                 device=device,
                 non_blocking=torch.cuda.is_available(),
             )
 
-            # in binary classification systems, negative labels may be skipped
-            if positive_only and (model.num_classes == 1) and (label == 0):
+            # in binary classification systems, negative targets may be skipped
+            if positive_only and (model.num_classes == 1) and (target == 0):
                 continue
 
             # chooses target outputs to generate saliency maps for
diff --git a/src/mednet/libs/classification/engine/saliency/viewer.py b/src/mednet/libs/classification/engine/saliency/viewer.py
index 588ba70ccf4584db6026de5cb88423321c2184f8..063ce7379c7a777d2cd28d8f28f4d7b1f118ccb2 100644
--- a/src/mednet/libs/classification/engine/saliency/viewer.py
+++ b/src/mednet/libs/classification/engine/saliency/viewer.py
@@ -233,7 +233,7 @@ def run(
             # WARNING: following code assumes a batch size of 1. Will break if
             # not the case.
             name = str(sample[1]["name"][0])
-            label = int(sample[1]["label"].item())
+            label = int(sample[1]["target"].item())
             data = sample[0][0]
 
             if label != target_label:
diff --git a/src/mednet/libs/classification/scripts/experiment.py b/src/mednet/libs/classification/scripts/experiment.py
index 3118b14b51332d79faafd31755856e50bac47ae3..814f3b76da7e1e4d2c9b95c51c17c97879930e6f 100644
--- a/src/mednet/libs/classification/scripts/experiment.py
+++ b/src/mednet/libs/classification/scripts/experiment.py
@@ -48,7 +48,6 @@ def experiment(
     seed,
     parallel,
     monitoring_interval,
-    balance_classes,
     augmentations,
     **_,
 ):  # numpydoc ignore=PR01
@@ -89,7 +88,6 @@ def experiment(
         seed=seed,
         parallel=parallel,
         monitoring_interval=monitoring_interval,
-        balance_classes=balance_classes,
         augmentations=augmentations,
     )
     train_stop_timestamp = datetime.now()
diff --git a/src/mednet/libs/classification/scripts/train.py b/src/mednet/libs/classification/scripts/train.py
index 6e1e28e8d560e634d80c1aba18a79c0a137fbfa5..85a570d28d0619c2f170d640129434ae26a4d250 100644
--- a/src/mednet/libs/classification/scripts/train.py
+++ b/src/mednet/libs/classification/scripts/train.py
@@ -86,14 +86,8 @@ def train(
     # of class examples available in the training set.  Also affects the
     # validation loss if a validation set is available on the DataModule.
     if balance_classes:
-        logger.info("Applying DataModule train sampler balancing...")
-        datamodule.balance_sampler_by_class = True
-        # logger.info("Applying train/valid loss balancing...")
-        # model.balance_losses_by_class(datamodule)
-    else:
-        logger.info(
-            "Skipping sample class/dataset ownership balancing on user request",
-        )
+        logger.info("Applying train/valid loss balancing...")
+        model.balance_losses(datamodule)
 
     checkpoint_file = get_checkpoint_file(output_folder)
     load_checkpoint(checkpoint_file, datamodule, model)
diff --git a/src/mednet/libs/common/models/loss_weights.py b/src/mednet/libs/common/models/loss_weights.py
index b3c2edec5b982762521eedbc4778ea21abb4d5fc..596c9e7b75babb85e5d15a714a1e3d1de70aa3db 100644
--- a/src/mednet/libs/common/models/loss_weights.py
+++ b/src/mednet/libs/common/models/loss_weights.py
@@ -129,13 +129,13 @@ def get_positive_weights(
     number of negative and positive samples (scalar).  The weight can be used
     to adjust minimisation criteria to in cases there is a huge data imbalance.
 
-    It returns a vector with weights (inverse counts) for each label.
+    It returns a vector with weights (inverse counts) for each target.
 
     Parameters
     ----------
     dataloader
         A DataLoader from which to compute the positive weights.  Entries must
-        be a dictionary which must contain a ``label`` key.
+        be a dictionary which must contain a ``target`` key.
 
     Returns
     -------
@@ -147,9 +147,9 @@ def get_positive_weights(
     targets = defaultdict(list)
 
     for batch in dataloader:
-        for class_idx, class_targets in enumerate(batch[1]["label"]):
+        for class_idx, class_targets in enumerate(batch[1]["target"]):
             # Targets are either a single tensor (binary case) or a list of tensors (multilabel)
-            if isinstance(batch[1]["label"], list):
+            if isinstance(batch[1]["target"], list):
                 targets[class_idx].extend(tensor_to_list(class_targets))
             else:
                 targets[0].extend(tensor_to_list(class_targets))
@@ -161,17 +161,17 @@ def get_positive_weights(
     targets_tensor = torch.tensor(targets_list)
 
     if targets_tensor.shape[0] == 1:
-        logger.info("Computing positive weights assuming binary labels.")
+        logger.info("Computing positive weights assuming binary targets.")
         positive_weights = compute_binary_weights(targets_tensor)
     else:
         if is_multicalss_exclusive(targets_tensor):
             logger.info(
-                "Computing positive weights assuming multiclass, exclusive labels."
+                "Computing positive weights assuming multiclass, exclusive targets."
             )
             positive_weights = compute_multiclass_weights(targets_tensor)
         else:
             logger.info(
-                "Computing positive weights assuming multiclass, non-exclusive labels."
+                "Computing positive weights assuming multiclass, non-exclusive targets."
             )
             positive_weights = compute_non_exclusive_multiclass_weights(targets_tensor)
 
diff --git a/src/mednet/libs/common/models/normalizer.py b/src/mednet/libs/common/models/normalizer.py
index 63a5f10c67f369a34e1f7ad04ef92b3ee649b2fb..603de9c132d0a51a1c2cd56958574fdc874bb56f 100644
--- a/src/mednet/libs/common/models/normalizer.py
+++ b/src/mednet/libs/common/models/normalizer.py
@@ -38,7 +38,7 @@ def make_z_normalizer(
 
     # Ensure targets are ints
     try:
-        target = batch[1]["label"][0].item()
+        target = batch[1]["target"][0].item()
         if not isinstance(target, int):
             logger.info(
                 "Targets are not Integer type, skipping z-normalization."