From 5f28a67336bf19537733a48b601835ebeadc0513 Mon Sep 17 00:00:00 2001 From: dcarron <daniel.carron@idiap.ch> Date: Mon, 3 Jul 2023 14:53:41 +0200 Subject: [PATCH] Make augmentation transforms part of the model --- src/ptbench/configs/models/pasa.py | 11 ++++++++++- src/ptbench/data/datamodule.py | 6 +++--- src/ptbench/data/shenzhen/default.py | 2 -- src/ptbench/models/pasa.py | 10 +++++----- src/ptbench/scripts/train.py | 1 + 5 files changed, 19 insertions(+), 11 deletions(-) diff --git a/src/ptbench/configs/models/pasa.py b/src/ptbench/configs/models/pasa.py index 3ee0b921..47324199 100644 --- a/src/ptbench/configs/models/pasa.py +++ b/src/ptbench/configs/models/pasa.py @@ -14,6 +14,7 @@ Reference: [PASA-2019]_ from torch import empty from torch.nn import BCEWithLogitsLoss +from ...data.transforms import ElasticDeformation from ...models.pasa import PASA # config @@ -26,5 +27,13 @@ optimizer = "Adam" criterion = BCEWithLogitsLoss(pos_weight=empty(1)) criterion_valid = BCEWithLogitsLoss(pos_weight=empty(1)) +augmentation_transforms = [ElasticDeformation(p=0.8)] + # model -model = PASA(criterion, criterion_valid, optimizer, optimizer_configs) +model = PASA( + criterion, + criterion_valid, + optimizer, + optimizer_configs, + augmentation_transforms=augmentation_transforms, +) diff --git a/src/ptbench/data/datamodule.py b/src/ptbench/data/datamodule.py index 98e003c5..8c297823 100644 --- a/src/ptbench/data/datamodule.py +++ b/src/ptbench/data/datamodule.py @@ -373,7 +373,7 @@ class CachingDataModule(lightning.LightningDataModule): ) # should only be true if GPU available and using it # datasets that have been setup() for the current stage - self._datasets: dict[str, _DelayedLoadingDataset | _CachedDataset] = {} + self._datasets: dict[str, _DelayedLoadingDataset | _CachedDataset] = {} # type: ignore[no-redef] @property def parallel(self) -> int: @@ -387,7 +387,7 @@ class CachingDataModule(lightning.LightningDataModule): value ) # datasets that have been setup() for the current stage - self._datasets: dict[str, _DelayedLoadingDataset | _CachedDataset] = {} + self._datasets: dict[str, _DelayedLoadingDataset | _CachedDataset] = {} # type: ignore[no-redef] def set_chunk_size(self, batch_size: int, batch_chunk_count: int) -> None: """Coherently sets the batch-chunk-size after validation. @@ -527,7 +527,7 @@ class CachingDataModule(lightning.LightningDataModule): * ``predict``: uses only the test dataset """ - self._datasets: dict[str, _DelayedLoadingDataset | _CachedDataset] = {} + self._datasets: dict[str, _DelayedLoadingDataset | _CachedDataset] = {} # type: ignore[no-redef] def train_dataloader(self) -> torch.utils.data.DataLoader: """Returns the train data loader.""" diff --git a/src/ptbench/data/shenzhen/default.py b/src/ptbench/data/shenzhen/default.py index 8d943292..793c3d41 100644 --- a/src/ptbench/data/shenzhen/default.py +++ b/src/ptbench/data/shenzhen/default.py @@ -16,7 +16,6 @@ import importlib.resources from ..datamodule import CachingDataModule from ..split import JSONDatabaseSplit -from ..transforms import ElasticDeformation from .raw_data_loader import raw_data_loader datamodule = CachingDataModule( @@ -28,7 +27,6 @@ datamodule = CachingDataModule( raw_data_loader=raw_data_loader, cache_samples=False, # train_sampler: typing.Optional[torch.utils.data.Sampler] = None, - data_augmentations=[ElasticDeformation(p=0.8)], # model_transforms = [], # batch_size = 1, # batch_chunk_count = 1, diff --git a/src/ptbench/models/pasa.py b/src/ptbench/models/pasa.py index fbc73f81..76327670 100644 --- a/src/ptbench/models/pasa.py +++ b/src/ptbench/models/pasa.py @@ -25,15 +25,14 @@ class PASA(pl.LightningModule): criterion_valid, optimizer, optimizer_configs, + augmentation_transforms, ): super().__init__() - # Saves all hyper parameters declared on __init__ into ``self.hparams`. - # You can access those by their name, like `self.hparams.criterion` - self.save_hyperparameters() - self.name = "pasa" + self.augmentation_transforms = augmentation_transforms + self.normalizer = None # First convolution block @@ -159,7 +158,8 @@ class PASA(pl.LightningModule): labels = torch.reshape(labels, (labels.shape[0], 1)) # Forward pass on the network - outputs = self(images) + augmented_images = self.augmentation_transforms(images) + outputs = self(augmented_images) # Manually move criterion to selected device, since not part of the model. self.hparams.criterion = self.hparams.criterion.to(self.device) diff --git a/src/ptbench/scripts/train.py b/src/ptbench/scripts/train.py index 01f294d7..9b743d64 100644 --- a/src/ptbench/scripts/train.py +++ b/src/ptbench/scripts/train.py @@ -221,6 +221,7 @@ def train( parallel, monitoring_interval, resume_from, + **_, ): """Trains an CNN to perform image classification. -- GitLab