diff --git a/src/ptbench/data/datamodule.py b/src/ptbench/data/datamodule.py
index 5c94091877ff735e74942795203aa82c8931b61e..c8c33dec3e0f417c838da720cc630bc04a59bd96 100644
--- a/src/ptbench/data/datamodule.py
+++ b/src/ptbench/data/datamodule.py
@@ -338,14 +338,6 @@ class CachingDataModule(lightning.LightningDataModule):
         validation to balance sample picking probability, making sample
         across classes **and** datasets equitable.

-    model_transforms
-        A list of transforms (torch modules) that will be applied after
-        raw-data-loading, and just before data is fed into the model or
-        eventual data-augmentation transformations for all data loaders
-        produced by this data module. This part of the pipeline receives data
-        as output by the raw-data-loader, or model-related transforms (e.g.
-        resize adaptions), if any is specified.
-
     batch_size
         Number of samples in every **training** batch (this parameter affects
         memory requirements for the network). If the number of samples in the
@@ -382,6 +374,21 @@ class CachingDataModule(lightning.LightningDataModule):
         multiprocessing data loading. Set to 0 to enable as many data loading
         instances as processing cores as available in the system. Set to >= 1
         to enable that many multiprocessing instances for data loading.
+
+
+    Attributes
+    ----------
+
+    model_transforms
+        A list of transforms (torch modules) that will be applied after
+        raw-data-loading, and just before data is fed into the model or
+        eventual data-augmentation transformations for all data loaders
+        produced by this data module. This part of the pipeline receives data
+        as output by the raw-data-loader, or by model-related transforms (e.g.
+        resize adaptations), if any are specified. If data is cached, it is
+        cached **after** model-transforms are applied, as that is a potential
+        memory saver (e.g., if it contains a resizing operation to smaller
+        images).
     """

     DatasetDictionary = dict[str, Dataset]
@@ -392,7 +399,6 @@ class CachingDataModule(lightning.LightningDataModule):
         raw_data_loader: RawDataLoader,
         cache_samples: bool = False,
         balance_sampler_by_class: bool = False,
-        model_transforms: list[Transform] = [],
         batch_size: int = 1,
         batch_chunk_count: int = 1,
         drop_incomplete_batch: bool = False,
@@ -407,7 +413,7 @@ class CachingDataModule(lightning.LightningDataModule):
         self.cache_samples = cache_samples
         self._train_sampler = None
         self.balance_sampler_by_class = balance_sampler_by_class
-        self.model_transforms = model_transforms
+        self.model_transforms: list[Transform] | None = None
         self.drop_incomplete_batch = drop_incomplete_batch
         self.parallel = parallel  # immutable, otherwise would need to call
@@ -551,6 +557,13 @@ class CachingDataModule(lightning.LightningDataModule):
             Name of the dataset to setup.
         """

+        if self.model_transforms is None:
+            raise RuntimeError(
+                "Parameter `model_transforms` has not yet been "
+                "set. If you do not have model transforms, then "
+                "set it to an empty list."
+            )
+
         if name in self._datasets:
             logger.info(
                 f"Dataset `{name}` is already setup. "
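
Reviewer note: a minimal sketch of the call pattern after this change, not a verbatim snippet from the repository. `MyRawDataLoader` is a hypothetical `RawDataLoader` subclass, torchvision's `Resize` stands in for any torch-module transform, and constructor arguments not visible in the diff are elided:

    from torchvision.transforms import Resize

    from ptbench.data.datamodule import CachingDataModule

    datamodule = CachingDataModule(
        raw_data_loader=MyRawDataLoader(),  # hypothetical loader; any other
        cache_samples=True,                 # required arguments are elided
    )

    # `model_transforms` is now an attribute rather than a constructor
    # parameter. It must be assigned before any dataset is set up (an
    # empty list is fine); otherwise the guard added in this diff raises
    # a RuntimeError. Because samples are cached only *after* these
    # transforms run, a downsizing transform here also shrinks the cache.
    datamodule.model_transforms = [Resize((256, 256))]

    # Standard Lightning entry point; assumed here to trigger the guarded
    # dataset setup shown in the last hunk.
    datamodule.setup("fit")

The explicit `None` default plus the runtime guard turns a silently wrong configuration (running without the intended preprocessing) into an immediate, explicit failure, while still allowing the transforms to be bound late, after the data module has been constructed.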