Skip to content
Snippets Groups Projects
Commit 8afb13e5 authored by André Anjos's avatar André Anjos :speech_balloon:
Browse files

[data.datamodule] Double-check model_transforms are set before datasets are...

[data.datamodule] Double-check model_transforms are set before datasets are instantiated; Remove model_transforms from constructor (non-sensical); Improve documentation on model-transforms
parent 50fce2cd
No related branches found
No related tags found
1 merge request: !6 "Making use of LightningDataModule and simplification of data loading"
@@ -338,14 +338,6 @@ class CachingDataModule(lightning.LightningDataModule):
     validation to balance sample picking probability, making sample
     across classes **and** datasets equitable.
-    model_transforms
-        A list of transforms (torch modules) that will be applied after
-        raw-data-loading, and just before data is fed into the model or
-        eventual data-augmentation transformations for all data loaders
-        produced by this data module. This part of the pipeline receives data
-        as output by the raw-data-loader, or model-related transforms (e.g.
-        resize adaptions), if any is specified.
     batch_size
         Number of samples in every **training** batch (this parameter affects
         memory requirements for the network). If the number of samples in the
@@ -382,6 +374,21 @@ class CachingDataModule(lightning.LightningDataModule):
     multiprocessing data loading. Set to 0 to enable as many data loading
     instances as processing cores as available in the system. Set to >= 1
     to enable that many multiprocessing instances for data loading.
+    Attributes
+    ----------
+    model_transforms
+        A list of transforms (torch modules) that will be applied after
+        raw-data-loading, and just before data is fed into the model or
+        eventual data-augmentation transformations for all data loaders
+        produced by this data module. This part of the pipeline receives data
+        as output by the raw-data-loader, or model-related transforms (e.g.
+        resize adaptions), if any is specified. If data is cached, it is
+        cached **after** model-transforms are applied, as that is a potential
+        memory saver (e.g., if it contains a resizing operation to smaller
+        images).
     """

     DatasetDictionary = dict[str, Dataset]
@@ -392,7 +399,6 @@ class CachingDataModule(lightning.LightningDataModule):
         raw_data_loader: RawDataLoader,
         cache_samples: bool = False,
         balance_sampler_by_class: bool = False,
-        model_transforms: list[Transform] = [],
         batch_size: int = 1,
         batch_chunk_count: int = 1,
         drop_incomplete_batch: bool = False,
@@ -407,7 +413,7 @@ class CachingDataModule(lightning.LightningDataModule):
         self.cache_samples = cache_samples
         self._train_sampler = None
         self.balance_sampler_by_class = balance_sampler_by_class
-        self.model_transforms = model_transforms
+        self.model_transforms: list[Transform] | None = None
         self.drop_incomplete_batch = drop_incomplete_batch
         self.parallel = parallel  # immutable, otherwise would need to call
@@ -551,6 +557,13 @@ class CachingDataModule(lightning.LightningDataModule):
             Name of the dataset to setup.
         """
+        if self.model_transforms is None:
+            raise RuntimeError(
+                "Parameter `model_transforms` has not yet been "
+                "set. If you do not have model transforms, then "
+                "set it to an empty list."
+            )
         if name in self._datasets:
             logger.info(
                 f"Dataset `{name}` is already setup. "
...
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment