diff --git a/src/ptbench/data/datamodule.py b/src/ptbench/data/datamodule.py index 0ab3c36caa3692c86750a70082e36da1a1d8c71d..82c2d19a56382ff82aaa4d5997943084a395dcc4 100644 --- a/src/ptbench/data/datamodule.py +++ b/src/ptbench/data/datamodule.py @@ -498,20 +498,6 @@ class ConcatDataModule(lightning.LightningDataModule): DatasetDictionary: typing.TypeAlias = dict[str, Dataset] """A dictionary of datasets mapping names to actual datasets.""" - model_transforms: list[Transform] | None - """Transforms required to fit data into the model. - - A list of transforms (torch modules) that will be applied after raw- - data-loading. and just before data is fed into the model or eventual - data-augmentation transformations for all data loaders produced by - this data module. This part of the pipeline receives data as output - by the raw-data-loader, or model-related transforms (e.g. resize - adaptions), if any is specified. If data is cached, it is cached - **after** model-transforms are applied, as that is a potential - memory saver (e.g., if it contains a resizing operation to smaller - images). - """ - def __init__( self, splits: ConcatDatabaseSplit, @@ -535,7 +521,8 @@ class ConcatDataModule(lightning.LightningDataModule): self.cache_samples = cache_samples self._train_sampler = None self.balance_sampler_by_class = balance_sampler_by_class - self.model_transforms: list[Transform] | None = None + + self._model_transforms: list[Transform] | None = None self.drop_incomplete_batch = drop_incomplete_batch self.parallel = parallel # immutable, otherwise would need to call @@ -602,8 +589,35 @@ class ConcatDataModule(lightning.LightningDataModule): "multiprocessing_context" ] = multiprocessing.get_context("spawn") + @property + def model_transforms(self) -> list[Transform] | None: + """Transforms required to fit data into the model. + + A list of transforms (torch modules) that will be applied after + raw-data-loading, and just before data is fed into the model or + eventual data-augmentation transformations for all data loaders + produced by this data module. This part of the pipeline + receives data as output by the raw-data-loader, or model-related + transforms (e.g. resize adaptations), if any is specified. If + data is cached, it is cached **after** model-transforms are + applied, as that is a potential memory saver (e.g., if it + contains a resizing operation to smaller images). + """ + return self._model_transforms + + @model_transforms.setter + def model_transforms(self, value: list[Transform] | None): + old_value = self._model_transforms + self._model_transforms = value + # datasets that have been setup() for the current stage are reset - self._datasets = {} + if value != old_value and len(self._datasets): + logger.warning( + f"Reseting {len(self._datasets)} loaded datasets due " + "to changes in model-transform properties. If you were caching " + "data loading, this will (eventually) trigger a reload." + ) + self._datasets = {} @property def balance_sampler_by_class(self): @@ -801,8 +815,7 @@ class ConcatDataModule(lightning.LightningDataModule): * ``test``: uses only the test dataset * ``predict``: uses only the test dataset """ - - self._datasets = {} + pass def train_dataloader(self) -> DataLoader: """Returns the train data loader."""