Skip to content
Snippets Groups Projects
Commit 5f28a673 authored by Daniel CARRON's avatar Daniel CARRON :b:
Browse files

Make augmentation transforms part of the model

parent c464ae3e
No related branches found
No related tags found
1 merge request!7Reviewed DataModule design+docs+types
......@@ -14,6 +14,7 @@ Reference: [PASA-2019]_
from torch import empty
from torch.nn import BCEWithLogitsLoss
from ...data.transforms import ElasticDeformation
from ...models.pasa import PASA
# config
......@@ -26,5 +27,13 @@ optimizer = "Adam"
criterion = BCEWithLogitsLoss(pos_weight=empty(1))
criterion_valid = BCEWithLogitsLoss(pos_weight=empty(1))
augmentation_transforms = [ElasticDeformation(p=0.8)]
# model
model = PASA(criterion, criterion_valid, optimizer, optimizer_configs)
model = PASA(
criterion,
criterion_valid,
optimizer,
optimizer_configs,
augmentation_transforms=augmentation_transforms,
)
......@@ -373,7 +373,7 @@ class CachingDataModule(lightning.LightningDataModule):
) # should only be true if GPU available and using it
# datasets that have been setup() for the current stage
self._datasets: dict[str, _DelayedLoadingDataset | _CachedDataset] = {}
self._datasets: dict[str, _DelayedLoadingDataset | _CachedDataset] = {} # type: ignore[no-redef]
@property
def parallel(self) -> int:
......@@ -387,7 +387,7 @@ class CachingDataModule(lightning.LightningDataModule):
value
)
# datasets that have been setup() for the current stage
self._datasets: dict[str, _DelayedLoadingDataset | _CachedDataset] = {}
self._datasets: dict[str, _DelayedLoadingDataset | _CachedDataset] = {} # type: ignore[no-redef]
def set_chunk_size(self, batch_size: int, batch_chunk_count: int) -> None:
"""Coherently sets the batch-chunk-size after validation.
......@@ -527,7 +527,7 @@ class CachingDataModule(lightning.LightningDataModule):
* ``predict``: uses only the test dataset
"""
self._datasets: dict[str, _DelayedLoadingDataset | _CachedDataset] = {}
self._datasets: dict[str, _DelayedLoadingDataset | _CachedDataset] = {} # type: ignore[no-redef]
def train_dataloader(self) -> torch.utils.data.DataLoader:
"""Returns the train data loader."""
......
......@@ -16,7 +16,6 @@ import importlib.resources
from ..datamodule import CachingDataModule
from ..split import JSONDatabaseSplit
from ..transforms import ElasticDeformation
from .raw_data_loader import raw_data_loader
datamodule = CachingDataModule(
......@@ -28,7 +27,6 @@ datamodule = CachingDataModule(
raw_data_loader=raw_data_loader,
cache_samples=False,
# train_sampler: typing.Optional[torch.utils.data.Sampler] = None,
data_augmentations=[ElasticDeformation(p=0.8)],
# model_transforms = [],
# batch_size = 1,
# batch_chunk_count = 1,
......
......@@ -25,15 +25,14 @@ class PASA(pl.LightningModule):
criterion_valid,
optimizer,
optimizer_configs,
augmentation_transforms,
):
super().__init__()
# Saves all hyper parameters declared on __init__ into ``self.hparams`.
# You can access those by their name, like `self.hparams.criterion`
self.save_hyperparameters()
self.name = "pasa"
self.augmentation_transforms = augmentation_transforms
self.normalizer = None
# First convolution block
......@@ -159,7 +158,8 @@ class PASA(pl.LightningModule):
labels = torch.reshape(labels, (labels.shape[0], 1))
# Forward pass on the network
outputs = self(images)
augmented_images = self.augmentation_transforms(images)
outputs = self(augmented_images)
# Manually move criterion to selected device, since not part of the model.
self.hparams.criterion = self.hparams.criterion.to(self.device)
......
......@@ -221,6 +221,7 @@ def train(
parallel,
monitoring_interval,
resume_from,
**_,
):
"""Trains an CNN to perform image classification.
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment