diff --git a/src/ptbench/configs/models/pasa.py b/src/ptbench/configs/models/pasa.py
index 47324199ff6daf3475d96e5999063621493b985b..49ee76dbd002a9fbd13e99d052eef223c17a19ac 100644
--- a/src/ptbench/configs/models/pasa.py
+++ b/src/ptbench/configs/models/pasa.py
@@ -13,22 +13,25 @@ Reference: [PASA-2019]_
 
 from torch import empty
 from torch.nn import BCEWithLogitsLoss
+from torch.optim import Adam
 
-from ...data.transforms import ElasticDeformation
 from ...models.pasa import PASA
 
-# config
-optimizer_configs = {"lr": 8e-5}
-
 # optimizer
-optimizer = "Adam"
+optimizer = Adam
+optimizer_configs = {"lr": 8e-5}
 
 # criterion
 criterion = BCEWithLogitsLoss(pos_weight=empty(1))
 criterion_valid = BCEWithLogitsLoss(pos_weight=empty(1))
 
+from ...data.transforms import ElasticDeformation
+
 augmentation_transforms = [ElasticDeformation(p=0.8)]
 
+# from torchvision.transforms.v2 import ElasticTransform, InterpolationMode
+# augmentation_transforms = [ElasticTransform(alpha=1000.0, sigma=30.0, interpolation=InterpolationMode.NEAREST)]
+
 # model
 model = PASA(
     criterion,
diff --git a/src/ptbench/data/datamodule.py b/src/ptbench/data/datamodule.py
index 8c297823fb649d6f750c2799416504b335b24f60..fc9883d2f89c008fe7fb6c52ed62dea8664df17b 100644
--- a/src/ptbench/data/datamodule.py
+++ b/src/ptbench/data/datamodule.py
@@ -13,6 +13,8 @@ import torch
 import torch.utils.data
 import torchvision.transforms
 
+from tqdm import tqdm
+
 logger = logging.getLogger(__name__)
 
 
@@ -150,8 +152,13 @@ class _CachedDataset(torch.utils.data.Dataset):
             typing.Callable[[torch.Tensor], torch.Tensor]
         ] = [],
     ):
-        self.transform = torchvision.transforms.Compose(*transforms)
-        self.data = [raw_data_loader(k) for k in split]
+        # Cannot unpack empty list
+        if len(transforms) > 0:
+            self.transform = torchvision.transforms.Compose([*transforms])
+        else:
+            self.transform = torchvision.transforms.Compose([])
+
+        self.data = [raw_data_loader(k) for k in tqdm(split)]
 
     def __getitem__(self, key: int) -> tuple[torch.Tensor, typing.Mapping]:
         tensor, metadata = self.data[key]
@@ -446,6 +453,7 @@ class CachingDataModule(lightning.LightningDataModule):
             logger.info(f"Dataset {name} is already setup.  Not reloading it.")
             return
         if self.cache_samples:
+            logger.info(f"Caching {name} dataset")
             self._datasets[name] = _CachedDataset(
                 self.database_split[name],
                 self.raw_data_loader,
diff --git a/src/ptbench/data/dataset.py b/src/ptbench/data/dataset.py
index af8a736454608961535caee8cddb432a42199410..3529ee7f0a14c0fc03a01f9d02348cf09a43cb20 100644
--- a/src/ptbench/data/dataset.py
+++ b/src/ptbench/data/dataset.py
@@ -10,11 +10,11 @@ import torch.utils.data
 logger = logging.getLogger(__name__)
 
 
-def _get_positive_weights(dataset):
+def _get_positive_weights(dataloader):
     """Compute the positive weights of each class of the dataset to balance the
     BCEWithLogitsLoss criterion.
 
-    This function takes as input a :py:class:`torch.utils.data.dataset.Dataset`
+    This function takes as input a :py:class:`torch.utils.data.DataLoader`
     and computes the positive weights of each class to use them to have
     a balanced loss.
 
@@ -22,9 +22,8 @@ def _get_positive_weights(dataset):
     Parameters
     ----------
 
-    dataset : torch.utils.data.dataset.Dataset
-        An instance of torch.utils.data.dataset.Dataset
-        ConcatDataset are supported
+    dataloader : :py:class:`torch.utils.data.DataLoader`
+        A DataLoader from which to compute the positive weights. Must contain a 'label' key in the metadata returned by __getitem__().
 
 
     Returns
@@ -35,14 +34,8 @@ def _get_positive_weights(dataset):
     """
     targets = []
 
-    if isinstance(dataset, torch.utils.data.ConcatDataset):
-        for ds in dataset.datasets:
-            for s in ds._samples:
-                targets.append(s["label"])
-
-    else:
-        for s in dataset._samples:
-            targets.append(s["label"])
+    for batch in dataloader:
+        targets.extend(batch[1]["label"])
 
     targets = torch.tensor(targets)
 
@@ -71,33 +64,3 @@ def _get_positive_weights(dataset):
         )
 
     return positive_weights
-
-
-def reweight_BCEWithLogitsLoss(datamodule, model, criterion, criterion_valid):
-    from torch.nn import BCEWithLogitsLoss
-
-    datamodule.prepare_data()
-    datamodule.setup(stage="fit")
-
-    train_dataset = datamodule.train_dataset
-    validation_dataset = datamodule.validation_dataset
-
-    # Redefine a weighted criterion if possible
-    if isinstance(criterion, torch.nn.BCEWithLogitsLoss):
-        positive_weights = _get_positive_weights(train_dataset)
-        model.hparams.criterion = BCEWithLogitsLoss(pos_weight=positive_weights)
-    else:
-        logger.warning("Weighted criterion not supported")
-
-    if validation_dataset is not None:
-        # Redefine a weighted valid criterion if possible
-        if (
-            isinstance(criterion_valid, torch.nn.BCEWithLogitsLoss)
-            or criterion_valid is None
-        ):
-            positive_weights = _get_positive_weights(validation_dataset)
-            model.hparams.criterion_valid = BCEWithLogitsLoss(
-                pos_weight=positive_weights
-            )
-        else:
-            logger.warning("Weighted valid criterion not supported")
diff --git a/src/ptbench/engine/callbacks.py b/src/ptbench/engine/callbacks.py
index d0ac43f98e21b8ce6803797d6a1fde38c6302660..580cc26fd24efc6ebb98fbcbb02d51348c6c11ec 100644
--- a/src/ptbench/engine/callbacks.py
+++ b/src/ptbench/engine/callbacks.py
@@ -94,9 +94,7 @@ class LoggingCallback(Callback):
             self.log("total_time", current_time)
             self.log("eta", eta_seconds)
             self.log("loss", numpy.average(self.training_loss))
-            self.log(
-                "learning_rate", pl_module.hparams["optimizer_configs"]["lr"]
-            )
+            self.log("learning_rate", pl_module.optimizer_configs["lr"])
             self.log("validation_loss", numpy.sum(self.validation_loss))
 
             if len(self.extra_validation_loss) > 0:
diff --git a/src/ptbench/engine/trainer.py b/src/ptbench/engine/trainer.py
index 6b156f86822209445a709c0af99b8b1eb14c55dc..ecf29153bc56c84185a905edc97867b0730a0a57 100644
--- a/src/ptbench/engine/trainer.py
+++ b/src/ptbench/engine/trainer.py
@@ -51,20 +51,24 @@ def save_model_summary(
 
     Returns
     -------
-    r
+    summary:
         The model summary in a text format.
 
-    n
+    total_parameters:
         The number of parameters of the model.
     """
     summary_path = os.path.join(output_folder, "model_summary.txt")
     logger.info(f"Saving model summary at {summary_path}...")
     with open(summary_path, "w") as f:
-        summary = lightning.pytorch.callbacks.ModelSummary(model, max_depth=-1)
+        summary = lightning.pytorch.utilities.model_summary.ModelSummary(
+            model, max_depth=-1
+        )
         f.write(str(summary))
     return (
         summary,
-        lightning.pytorch.callbacks.ModelSummary(model).total_parameters,
+        lightning.pytorch.utilities.model_summary.ModelSummary(
+            model
+        ).total_parameters,
     )
 
 
diff --git a/src/ptbench/models/pasa.py b/src/ptbench/models/pasa.py
index 76327670d047900f10b8a94d90a7bc95a44fe2a0..e93b4e61f3bd5c9a0259e7d21728362f8e26999a 100644
--- a/src/ptbench/models/pasa.py
+++ b/src/ptbench/models/pasa.py
@@ -9,6 +9,7 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 import torch.utils.data
+import torchvision.transforms
 
 logger = logging.getLogger(__name__)
 
@@ -31,7 +32,15 @@ class PASA(pl.LightningModule):
 
         self.name = "pasa"
 
-        self.augmentation_transforms = augmentation_transforms
+        self.augmentation_transforms = torchvision.transforms.Compose(
+            augmentation_transforms
+        )
+
+        self.criterion = criterion
+        self.criterion_valid = criterion_valid
+
+        self.optimizer = optimizer
+        self.optimizer_configs = optimizer_configs
 
         self.normalizer = None
 
@@ -137,7 +146,7 @@ class PASA(pl.LightningModule):
         Parameters
         ----------
 
-        dataloader:
+        dataloader: :py:class:`torch.utils.data.DataLoader`
             A torch Dataloader from which to compute the mean and std
         """
         from .normalizer import make_z_normalizer
@@ -148,6 +157,35 @@ class PASA(pl.LightningModule):
         )
         self.normalizer = make_z_normalizer(dataloader)
 
+    def set_bce_loss_weights(self, datamodule):
+        """Reweights loss weights if BCEWithLogitsLoss is used.
+
+        Parameters
+        ----------
+
+        datamodule:
+            A datamodule implementing train_dataloader() and val_dataloader()
+        """
+        from ..data.dataset import _get_positive_weights
+
+        if isinstance(self.criterion, torch.nn.BCEWithLogitsLoss):
+            logger.info("Reweighting BCEWithLogitsLoss training criterion.")
+            train_positive_weights = _get_positive_weights(
+                datamodule.train_dataloader()
+            )
+            self.criterion = torch.nn.BCEWithLogitsLoss(
+                pos_weight=train_positive_weights
+            )
+
+        if isinstance(self.criterion_valid, torch.nn.BCEWithLogitsLoss):
+            logger.info("Reweighting BCEWithLogitsLoss validation criterion.")
+            validation_positive_weights = _get_positive_weights(
+                datamodule.val_dataloader()["validation"]
+            )
+            self.criterion_valid = torch.nn.BCEWithLogitsLoss(
+                pos_weight=validation_positive_weights
+            )
+
     def training_step(self, batch, _):
         images = batch[0]
         labels = batch[1]["label"]
@@ -158,12 +196,11 @@ class PASA(pl.LightningModule):
             labels = torch.reshape(labels, (labels.shape[0], 1))
 
         # Forward pass on the network
-        augmented_images = self.augmentation_transforms(images)
+        augmented_images = [self.augmentation_transforms(img) for img in images]
+        augmented_images = torch.unsqueeze(torch.cat(augmented_images, 0), 1)
         outputs = self(augmented_images)
 
-        # Manually move criterion to selected device, since not part of the model.
-        self.hparams.criterion = self.hparams.criterion.to(self.device)
-        training_loss = self.hparams.criterion(outputs, labels.double())
+        training_loss = self.criterion(outputs, labels.double())
 
         return {"loss": training_loss}
 
@@ -179,11 +216,7 @@ class PASA(pl.LightningModule):
         # data forwarding on the existing network
         outputs = self(images)
 
-        # Manually move criterion to selected device, since not part of the model.
-        self.hparams.criterion_valid = self.hparams.criterion_valid.to(
-            self.device
-        )
-        validation_loss = self.hparams.criterion_valid(outputs, labels.double())
+        validation_loss = self.criterion_valid(outputs, labels.double())
 
         if dataloader_idx == 0:
             return {"validation_loss": validation_loss}
@@ -233,9 +266,5 @@ class PASA(pl.LightningModule):
     # raise NotImplementedError
 
     def configure_optimizers(self):
-        # Dynamically instantiates the optimizer given the configs
-        optimizer = getattr(torch.optim, self.hparams.optimizer)(
-            self.parameters(), **self.hparams.optimizer_configs
-        )
-
+        optimizer = self.optimizer(self.parameters(), **self.optimizer_configs)
         return optimizer