Skip to content
Snippets Groups Projects
Commit d3d65b93 authored by André Anjos's avatar André Anjos :speech_balloon:
Browse files

[data.datamodule] Only reset datasets if model_transforms change

parent f0f7784b
No related branches found
No related tags found
1 merge request: !6 "Making use of LightningDataModule and simplification of data loading"
Pipeline #77168 passed
...@@ -498,20 +498,6 @@ class ConcatDataModule(lightning.LightningDataModule): ...@@ -498,20 +498,6 @@ class ConcatDataModule(lightning.LightningDataModule):
DatasetDictionary: typing.TypeAlias = dict[str, Dataset] DatasetDictionary: typing.TypeAlias = dict[str, Dataset]
"""A dictionary of datasets mapping names to actual datasets.""" """A dictionary of datasets mapping names to actual datasets."""
model_transforms: list[Transform] | None
"""Transforms required to fit data into the model.
A list of transforms (torch modules) that will be applied after
raw-data-loading, and just before data is fed into the model or eventual
data-augmentation transformations for all data loaders produced by
this data module. This part of the pipeline receives data as output
by the raw-data-loader, or model-related transforms (e.g. resize
adaptions), if any is specified. If data is cached, it is cached
**after** model-transforms are applied, as that is a potential
memory saver (e.g., if it contains a resizing operation to smaller
images).
"""
def __init__( def __init__(
self, self,
splits: ConcatDatabaseSplit, splits: ConcatDatabaseSplit,
...@@ -535,7 +521,8 @@ class ConcatDataModule(lightning.LightningDataModule): ...@@ -535,7 +521,8 @@ class ConcatDataModule(lightning.LightningDataModule):
self.cache_samples = cache_samples self.cache_samples = cache_samples
self._train_sampler = None self._train_sampler = None
self.balance_sampler_by_class = balance_sampler_by_class self.balance_sampler_by_class = balance_sampler_by_class
self.model_transforms: list[Transform] | None = None
self._model_transforms: list[Transform] | None = None
self.drop_incomplete_batch = drop_incomplete_batch self.drop_incomplete_batch = drop_incomplete_batch
self.parallel = parallel # immutable, otherwise would need to call self.parallel = parallel # immutable, otherwise would need to call
...@@ -602,8 +589,35 @@ class ConcatDataModule(lightning.LightningDataModule): ...@@ -602,8 +589,35 @@ class ConcatDataModule(lightning.LightningDataModule):
"multiprocessing_context" "multiprocessing_context"
] = multiprocessing.get_context("spawn") ] = multiprocessing.get_context("spawn")
@property
def model_transforms(self) -> list[Transform] | None:
"""Transforms required to fit data into the model.
A list of transforms (torch modules) that will be applied after
raw-data-loading, and just before data is fed into the model or
eventual data-augmentation transformations for all data loaders
produced by this data module. This part of the pipeline
receives data as output by the raw-data-loader, or model-related
transforms (e.g. resize adaptions), if any is specified. If
data is cached, it is cached **after** model-transforms are
applied, as that is a potential memory saver (e.g., if it
contains a resizing operation to smaller images).
"""
return self._model_transforms
@model_transforms.setter
def model_transforms(self, value: list[Transform] | None):
old_value = self._model_transforms
self._model_transforms = value
# datasets that have been setup() for the current stage are reset # datasets that have been setup() for the current stage are reset
self._datasets = {} if value != old_value and len(self._datasets):
logger.warning(
f"Reseting {len(self._datasets)} loaded datasets due "
"to changes in model-transform properties. If you were caching "
"data loading, this will (eventually) trigger a reload."
)
self._datasets = {}
@property @property
def balance_sampler_by_class(self): def balance_sampler_by_class(self):
...@@ -801,8 +815,7 @@ class ConcatDataModule(lightning.LightningDataModule): ...@@ -801,8 +815,7 @@ class ConcatDataModule(lightning.LightningDataModule):
* ``test``: uses only the test dataset * ``test``: uses only the test dataset
* ``predict``: uses only the test dataset * ``predict``: uses only the test dataset
""" """
pass
self._datasets = {}
def train_dataloader(self) -> DataLoader: def train_dataloader(self) -> DataLoader:
"""Returns the train data loader.""" """Returns the train data loader."""
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment