diff --git a/src/ptbench/configs/models/pasa.py b/src/ptbench/configs/models/pasa.py
index 3ee0b92164b5531b65049b94e71b01b07e2ad27e..47324199ff6daf3475d96e5999063621493b985b 100644
--- a/src/ptbench/configs/models/pasa.py
+++ b/src/ptbench/configs/models/pasa.py
@@ -14,6 +14,7 @@ Reference: [PASA-2019]_
 from torch import empty
 from torch.nn import BCEWithLogitsLoss
 
+from ...data.transforms import ElasticDeformation
 from ...models.pasa import PASA
 
 # config
@@ -26,5 +27,13 @@ optimizer = "Adam"
 criterion = BCEWithLogitsLoss(pos_weight=empty(1))
 criterion_valid = BCEWithLogitsLoss(pos_weight=empty(1))
 
+augmentation_transforms = [ElasticDeformation(p=0.8)]
+
 # model
-model = PASA(criterion, criterion_valid, optimizer, optimizer_configs)
+model = PASA(
+    criterion,
+    criterion_valid,
+    optimizer,
+    optimizer_configs,
+    augmentation_transforms=augmentation_transforms,
+)
diff --git a/src/ptbench/data/datamodule.py b/src/ptbench/data/datamodule.py
index 98e003c5d04acf6f2aecce03479b0b4d2acf7024..8c297823fb649d6f750c2799416504b335b24f60 100644
--- a/src/ptbench/data/datamodule.py
+++ b/src/ptbench/data/datamodule.py
@@ -373,7 +373,7 @@ class CachingDataModule(lightning.LightningDataModule):
         )  # should only be true if GPU available and using it
 
         # datasets that have been setup() for the current stage
-        self._datasets: dict[str, _DelayedLoadingDataset | _CachedDataset] = {}
+        self._datasets: dict[str, _DelayedLoadingDataset | _CachedDataset] = {}  # type: ignore[no-redef]
 
     @property
     def parallel(self) -> int:
@@ -387,7 +387,7 @@ class CachingDataModule(lightning.LightningDataModule):
             value
         )
         # datasets that have been setup() for the current stage
-        self._datasets: dict[str, _DelayedLoadingDataset | _CachedDataset] = {}
+        self._datasets: dict[str, _DelayedLoadingDataset | _CachedDataset] = {}  # type: ignore[no-redef]
 
     def set_chunk_size(self, batch_size: int, batch_chunk_count: int) -> None:
         """Coherently sets the batch-chunk-size after validation.
@@ -527,7 +527,7 @@ class CachingDataModule(lightning.LightningDataModule):
             * ``predict``: uses only the test dataset
         """
 
-        self._datasets: dict[str, _DelayedLoadingDataset | _CachedDataset] = {}
+        self._datasets: dict[str, _DelayedLoadingDataset | _CachedDataset] = {}  # type: ignore[no-redef]
 
     def train_dataloader(self) -> torch.utils.data.DataLoader:
         """Returns the train data loader."""
diff --git a/src/ptbench/data/shenzhen/default.py b/src/ptbench/data/shenzhen/default.py
index 8d94329247f866a29aa9b07c05952e6d2fbfa296..793c3d417a069e97d637069bf95f1ab8571c69a9 100644
--- a/src/ptbench/data/shenzhen/default.py
+++ b/src/ptbench/data/shenzhen/default.py
@@ -16,7 +16,6 @@ import importlib.resources
 
 from ..datamodule import CachingDataModule
 from ..split import JSONDatabaseSplit
-from ..transforms import ElasticDeformation
 from .raw_data_loader import raw_data_loader
 
 datamodule = CachingDataModule(
@@ -28,7 +27,6 @@ datamodule = CachingDataModule(
     raw_data_loader=raw_data_loader,
     cache_samples=False,
     # train_sampler: typing.Optional[torch.utils.data.Sampler] = None,
-    data_augmentations=[ElasticDeformation(p=0.8)],
     # model_transforms = [],
     # batch_size = 1,
     # batch_chunk_count = 1,
diff --git a/src/ptbench/models/pasa.py b/src/ptbench/models/pasa.py
index fbc73f81c7018a319d50203d7d46c9326fe12351..76327670d047900f10b8a94d90a7bc95a44fe2a0 100644
--- a/src/ptbench/models/pasa.py
+++ b/src/ptbench/models/pasa.py
@@ -25,15 +25,14 @@ class PASA(pl.LightningModule):
         criterion_valid,
         optimizer,
         optimizer_configs,
+        augmentation_transforms,
     ):
         super().__init__()
 
-        # Saves all hyper parameters declared on __init__ into ``self.hparams`.
-        # You can access those by their name, like `self.hparams.criterion`
-        self.save_hyperparameters()
-
         self.name = "pasa"
 
+        self.augmentation_transforms = augmentation_transforms
+
         self.normalizer = None
 
         # First convolution block
@@ -159,7 +158,8 @@ class PASA(pl.LightningModule):
             labels = torch.reshape(labels, (labels.shape[0], 1))
 
         # Forward pass on the network
-        outputs = self(images)
+        augmented_images = self.augmentation_transforms(images)
+        outputs = self(augmented_images)
 
         # Manually move criterion to selected device, since not part of the model.
         self.hparams.criterion = self.hparams.criterion.to(self.device)
diff --git a/src/ptbench/scripts/train.py b/src/ptbench/scripts/train.py
index 01f294d7c93d9550f9be80a1229ae00c363f2579..9b743d64cdeae69cb50e65fedbe169325f353764 100644
--- a/src/ptbench/scripts/train.py
+++ b/src/ptbench/scripts/train.py
@@ -221,6 +221,7 @@ def train(
     parallel,
     monitoring_interval,
     resume_from,
+    **_,
 ):
     """Trains an CNN to perform image classification.