From 5f28a67336bf19537733a48b601835ebeadc0513 Mon Sep 17 00:00:00 2001
From: dcarron <daniel.carron@idiap.ch>
Date: Mon, 3 Jul 2023 14:53:41 +0200
Subject: [PATCH] Make augmentation transforms part of the model

---
 src/ptbench/configs/models/pasa.py   | 11 ++++++++++-
 src/ptbench/data/datamodule.py       |  6 +++---
 src/ptbench/data/shenzhen/default.py |  2 --
 src/ptbench/models/pasa.py           | 10 +++++-----
 src/ptbench/scripts/train.py         |  1 +
 5 files changed, 19 insertions(+), 11 deletions(-)

diff --git a/src/ptbench/configs/models/pasa.py b/src/ptbench/configs/models/pasa.py
index 3ee0b921..47324199 100644
--- a/src/ptbench/configs/models/pasa.py
+++ b/src/ptbench/configs/models/pasa.py
@@ -14,6 +14,7 @@ Reference: [PASA-2019]_
 from torch import empty
 from torch.nn import BCEWithLogitsLoss
 
+from ...data.transforms import ElasticDeformation
 from ...models.pasa import PASA
 
 # config
@@ -26,5 +27,13 @@ optimizer = "Adam"
 criterion = BCEWithLogitsLoss(pos_weight=empty(1))
 criterion_valid = BCEWithLogitsLoss(pos_weight=empty(1))
 
+augmentation_transforms = [ElasticDeformation(p=0.8)]
+
 # model
-model = PASA(criterion, criterion_valid, optimizer, optimizer_configs)
+model = PASA(
+    criterion,
+    criterion_valid,
+    optimizer,
+    optimizer_configs,
+    augmentation_transforms=augmentation_transforms,
+)
diff --git a/src/ptbench/data/datamodule.py b/src/ptbench/data/datamodule.py
index 98e003c5..8c297823 100644
--- a/src/ptbench/data/datamodule.py
+++ b/src/ptbench/data/datamodule.py
@@ -373,7 +373,7 @@ class CachingDataModule(lightning.LightningDataModule):
         )  # should only be true if GPU available and using it
 
         # datasets that have been setup() for the current stage
-        self._datasets: dict[str, _DelayedLoadingDataset | _CachedDataset] = {}
+        self._datasets: dict[str, _DelayedLoadingDataset | _CachedDataset] = {}  # type: ignore[no-redef]
 
     @property
     def parallel(self) -> int:
@@ -387,7 +387,7 @@ class CachingDataModule(lightning.LightningDataModule):
             value
         )
         # datasets that have been setup() for the current stage
-        self._datasets: dict[str, _DelayedLoadingDataset | _CachedDataset] = {}
+        self._datasets: dict[str, _DelayedLoadingDataset | _CachedDataset] = {}  # type: ignore[no-redef]
 
     def set_chunk_size(self, batch_size: int, batch_chunk_count: int) -> None:
         """Coherently sets the batch-chunk-size after validation.
@@ -527,7 +527,7 @@ class CachingDataModule(lightning.LightningDataModule):
             * ``predict``: uses only the test dataset
         """
 
-        self._datasets: dict[str, _DelayedLoadingDataset | _CachedDataset] = {}
+        self._datasets: dict[str, _DelayedLoadingDataset | _CachedDataset] = {}  # type: ignore[no-redef]
 
     def train_dataloader(self) -> torch.utils.data.DataLoader:
         """Returns the train data loader."""
diff --git a/src/ptbench/data/shenzhen/default.py b/src/ptbench/data/shenzhen/default.py
index 8d943292..793c3d41 100644
--- a/src/ptbench/data/shenzhen/default.py
+++ b/src/ptbench/data/shenzhen/default.py
@@ -16,7 +16,6 @@ import importlib.resources
 
 from ..datamodule import CachingDataModule
 from ..split import JSONDatabaseSplit
-from ..transforms import ElasticDeformation
 from .raw_data_loader import raw_data_loader
 
 datamodule = CachingDataModule(
@@ -28,7 +27,6 @@ datamodule = CachingDataModule(
     raw_data_loader=raw_data_loader,
     cache_samples=False,
     # train_sampler: typing.Optional[torch.utils.data.Sampler] = None,
-    data_augmentations=[ElasticDeformation(p=0.8)],
     # model_transforms = [],
     # batch_size = 1,
     # batch_chunk_count = 1,
diff --git a/src/ptbench/models/pasa.py b/src/ptbench/models/pasa.py
index fbc73f81..76327670 100644
--- a/src/ptbench/models/pasa.py
+++ b/src/ptbench/models/pasa.py
@@ -25,15 +25,14 @@ class PASA(pl.LightningModule):
         criterion_valid,
         optimizer,
         optimizer_configs,
+        augmentation_transforms,
     ):
         super().__init__()
 
-        # Saves all hyper parameters declared on __init__ into ``self.hparams`.
-        # You can access those by their name, like `self.hparams.criterion`
-        self.save_hyperparameters()
-
         self.name = "pasa"
 
+        self.augmentation_transforms = augmentation_transforms
+
         self.normalizer = None
 
         # First convolution block
@@ -159,7 +158,8 @@ class PASA(pl.LightningModule):
             labels = torch.reshape(labels, (labels.shape[0], 1))
 
         # Forward pass on the network
-        outputs = self(images)
+        augmented_images = self.augmentation_transforms(images)
+        outputs = self(augmented_images)
 
         # Manually move criterion to selected device, since not part of the model.
         self.hparams.criterion = self.hparams.criterion.to(self.device)
diff --git a/src/ptbench/scripts/train.py b/src/ptbench/scripts/train.py
index 01f294d7..9b743d64 100644
--- a/src/ptbench/scripts/train.py
+++ b/src/ptbench/scripts/train.py
@@ -221,6 +221,7 @@ def train(
     parallel,
     monitoring_interval,
     resume_from,
+    **_,
 ):
     """Trains an CNN to perform image classification.
 
-- 
GitLab