diff --git a/src/ptbench/configs/models/pasa.py b/src/ptbench/configs/models/pasa.py index 3ee0b92164b5531b65049b94e71b01b07e2ad27e..47324199ff6daf3475d96e5999063621493b985b 100644 --- a/src/ptbench/configs/models/pasa.py +++ b/src/ptbench/configs/models/pasa.py @@ -14,6 +14,7 @@ Reference: [PASA-2019]_ from torch import empty from torch.nn import BCEWithLogitsLoss +from ...data.transforms import ElasticDeformation from ...models.pasa import PASA # config @@ -26,5 +27,13 @@ optimizer = "Adam" criterion = BCEWithLogitsLoss(pos_weight=empty(1)) criterion_valid = BCEWithLogitsLoss(pos_weight=empty(1)) +augmentation_transforms = [ElasticDeformation(p=0.8)] + # model -model = PASA(criterion, criterion_valid, optimizer, optimizer_configs) +model = PASA( + criterion, + criterion_valid, + optimizer, + optimizer_configs, + augmentation_transforms=augmentation_transforms, +) diff --git a/src/ptbench/data/datamodule.py b/src/ptbench/data/datamodule.py index 98e003c5d04acf6f2aecce03479b0b4d2acf7024..8c297823fb649d6f750c2799416504b335b24f60 100644 --- a/src/ptbench/data/datamodule.py +++ b/src/ptbench/data/datamodule.py @@ -373,7 +373,7 @@ class CachingDataModule(lightning.LightningDataModule): ) # should only be true if GPU available and using it # datasets that have been setup() for the current stage - self._datasets: dict[str, _DelayedLoadingDataset | _CachedDataset] = {} + self._datasets: dict[str, _DelayedLoadingDataset | _CachedDataset] = {} # type: ignore[no-redef] @property def parallel(self) -> int: @@ -387,7 +387,7 @@ class CachingDataModule(lightning.LightningDataModule): value ) # datasets that have been setup() for the current stage - self._datasets: dict[str, _DelayedLoadingDataset | _CachedDataset] = {} + self._datasets: dict[str, _DelayedLoadingDataset | _CachedDataset] = {} # type: ignore[no-redef] def set_chunk_size(self, batch_size: int, batch_chunk_count: int) -> None: """Coherently sets the batch-chunk-size after validation. @@ -527,7 +527,7 @@ class CachingDataModule(lightning.LightningDataModule): * ``predict``: uses only the test dataset """ - self._datasets: dict[str, _DelayedLoadingDataset | _CachedDataset] = {} + self._datasets: dict[str, _DelayedLoadingDataset | _CachedDataset] = {} # type: ignore[no-redef] def train_dataloader(self) -> torch.utils.data.DataLoader: """Returns the train data loader.""" diff --git a/src/ptbench/data/shenzhen/default.py b/src/ptbench/data/shenzhen/default.py index 8d94329247f866a29aa9b07c05952e6d2fbfa296..793c3d417a069e97d637069bf95f1ab8571c69a9 100644 --- a/src/ptbench/data/shenzhen/default.py +++ b/src/ptbench/data/shenzhen/default.py @@ -16,7 +16,6 @@ import importlib.resources from ..datamodule import CachingDataModule from ..split import JSONDatabaseSplit -from ..transforms import ElasticDeformation from .raw_data_loader import raw_data_loader datamodule = CachingDataModule( @@ -28,7 +27,6 @@ datamodule = CachingDataModule( raw_data_loader=raw_data_loader, cache_samples=False, # train_sampler: typing.Optional[torch.utils.data.Sampler] = None, - data_augmentations=[ElasticDeformation(p=0.8)], # model_transforms = [], # batch_size = 1, # batch_chunk_count = 1, diff --git a/src/ptbench/models/pasa.py b/src/ptbench/models/pasa.py index fbc73f81c7018a319d50203d7d46c9326fe12351..76327670d047900f10b8a94d90a7bc95a44fe2a0 100644 --- a/src/ptbench/models/pasa.py +++ b/src/ptbench/models/pasa.py @@ -25,15 +25,14 @@ class PASA(pl.LightningModule): criterion_valid, optimizer, optimizer_configs, + augmentation_transforms, ): super().__init__() - # Saves all hyper parameters declared on __init__ into ``self.hparams`. - # You can access those by their name, like `self.hparams.criterion` - self.save_hyperparameters() - self.name = "pasa" + self.augmentation_transforms = augmentation_transforms + self.normalizer = None # First convolution block @@ -159,7 +158,8 @@ class PASA(pl.LightningModule): labels = torch.reshape(labels, (labels.shape[0], 1)) # Forward pass on the network - outputs = self(images) + augmented_images = self.augmentation_transforms(images) + outputs = self(augmented_images) # Manually move criterion to selected device, since not part of the model. self.hparams.criterion = self.hparams.criterion.to(self.device) diff --git a/src/ptbench/scripts/train.py b/src/ptbench/scripts/train.py index 01f294d7c93d9550f9be80a1229ae00c363f2579..9b743d64cdeae69cb50e65fedbe169325f353764 100644 --- a/src/ptbench/scripts/train.py +++ b/src/ptbench/scripts/train.py @@ -221,6 +221,7 @@ def train( parallel, monitoring_interval, resume_from, + **_, ): """Trains an CNN to perform image classification.