diff --git a/src/ptbench/data/datamodule.py b/src/ptbench/data/datamodule.py
index 0ab3c36caa3692c86750a70082e36da1a1d8c71d..82c2d19a56382ff82aaa4d5997943084a395dcc4 100644
--- a/src/ptbench/data/datamodule.py
+++ b/src/ptbench/data/datamodule.py
@@ -498,20 +498,6 @@ class ConcatDataModule(lightning.LightningDataModule):
     DatasetDictionary: typing.TypeAlias = dict[str, Dataset]
     """A dictionary of datasets mapping names to actual datasets."""
 
-    model_transforms: list[Transform] | None
-    """Transforms required to fit data into the model.
-
-    A list of transforms (torch modules) that will be applied after raw-
-    data-loading. and just before data is fed into the model or eventual
-    data-augmentation transformations for all data loaders produced by
-    this data module.  This part of the pipeline receives data as output
-    by the raw-data-loader, or model-related transforms (e.g. resize
-    adaptions), if any is specified.  If data is cached, it is cached
-    **after** model-transforms are applied, as that is a potential
-    memory saver (e.g., if it contains a resizing operation to smaller
-    images).
-    """
-
     def __init__(
         self,
         splits: ConcatDatabaseSplit,
@@ -535,7 +521,8 @@ class ConcatDataModule(lightning.LightningDataModule):
         self.cache_samples = cache_samples
         self._train_sampler = None
         self.balance_sampler_by_class = balance_sampler_by_class
-        self.model_transforms: list[Transform] | None = None
+
+        self._model_transforms: list[Transform] | None = None
 
         self.drop_incomplete_batch = drop_incomplete_batch
         self.parallel = parallel  # immutable, otherwise would need to call
@@ -602,8 +589,35 @@ class ConcatDataModule(lightning.LightningDataModule):
                 "multiprocessing_context"
             ] = multiprocessing.get_context("spawn")
 
+    @property
+    def model_transforms(self) -> list[Transform] | None:
+        """Transforms required to fit data into the model.
+
+        A list of transforms (torch modules) that will be applied after
+        raw-data-loading, and just before data is fed into the model or
+        eventual data-augmentation transformations for all data loaders
+        produced by this data module.  This part of the pipeline
+        receives data as output by the raw-data-loader, or model-related
+        transforms (e.g. resize adaptations), if any is specified.  If
+        data is cached, it is cached **after** model-transforms are
+        applied, as that is a potential memory saver (e.g., if it
+        contains a resizing operation to smaller images).
+        """
+        return self._model_transforms
+
+    @model_transforms.setter
+    def model_transforms(self, value: list[Transform] | None):
+        old_value = self._model_transforms
+        self._model_transforms = value
+
         # datasets that have been setup() for the current stage are reset
-        self._datasets = {}
+        if value != old_value and len(self._datasets):
+            logger.warning(
+                f"Resetting {len(self._datasets)} loaded datasets due "
+                "to changes in model-transform properties.  If you were caching "
+                "data loading, this will (eventually) trigger a reload."
+            )
+            self._datasets = {}
 
     @property
     def balance_sampler_by_class(self):
@@ -801,8 +815,7 @@ class ConcatDataModule(lightning.LightningDataModule):
             * ``test``: uses only the test dataset
             * ``predict``: uses only the test dataset
         """
-
-        self._datasets = {}
+        pass
 
     def train_dataloader(self) -> DataLoader:
         """Returns the train data loader."""