diff --git a/pyproject.toml b/pyproject.toml
index a418e59bd1e3f464ef4be58cd28431ae0efcb61e..549b347ebf05ef8554a725196a8576e8e809c2af 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -118,7 +118,7 @@ montgomery_rs_f8 = "ptbench.configs.datasets.montgomery_RS.fold_8"
 montgomery_rs_f9 = "ptbench.configs.datasets.montgomery_RS.fold_9"
 # shenzhen dataset (and cross-validation folds)
 shenzhen = "ptbench.configs.datasets.shenzhen.default"
-shenzhen_rgb = "ptbench.data.shenzhen.rgb"
+shenzhen_rgb = "ptbench.configs.datasets.shenzhen.rgb"
 shenzhen_f0 = "ptbench.data.shenzhen.fold_0"
 shenzhen_f1 = "ptbench.data.shenzhen.fold_1"
 shenzhen_f2 = "ptbench.data.shenzhen.fold_2"
diff --git a/src/ptbench/configs/datasets/shenzhen/default.py b/src/ptbench/configs/datasets/shenzhen/default.py
index 6f5b31ff41c2b226059794d8ef1fdffd7828e8b4..11c285bcc93380291f273e191281035462eb0d6f 100644
--- a/src/ptbench/configs/datasets/shenzhen/default.py
+++ b/src/ptbench/configs/datasets/shenzhen/default.py
@@ -55,6 +55,7 @@ class DefaultModule(BaseDataModule):
             fieldnames=("data", "label"),
             loader=samples_loader,
         )
+
         (
             self.train_dataset,
             self.validation_dataset,
diff --git a/src/ptbench/configs/datasets/shenzhen/rgb.py b/src/ptbench/configs/datasets/shenzhen/rgb.py
new file mode 100644
index 0000000000000000000000000000000000000000..0ceb952e28c0a1181772565b425043a9a08cb646
--- /dev/null
+++ b/src/ptbench/configs/datasets/shenzhen/rgb.py
@@ -0,0 +1,74 @@
+# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
+#
+# SPDX-License-Identifier: GPL-3.0-or-later
+
+"""Shenzhen dataset for TB detection (cross validation fold 0, RGB)
+
+* Split reference: first 80% of TB and healthy CXR for "train", rest for "test"
+* This configuration resolution: 512 x 512 (default)
+* See :py:mod:`ptbench.data.shenzhen` for dataset details
+"""
+
+from clapper.logging import setup
+from torchvision import transforms
+
+from ....data import return_subsets
+from ....data.base_datamodule import BaseDataModule
+from ....data.dataset import JSONDataset
+from ....data.shenzhen import _cached_loader, _delayed_loader, _protocols
+
+logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
+
+
+class DefaultModule(BaseDataModule):
+    def __init__(
+        self,
+        train_batch_size=1,
+        predict_batch_size=1,
+        drop_incomplete_batch=False,
+        cache_samples=False,
+        multiproc_kwargs=None,
+    ):
+        super().__init__(
+            train_batch_size=train_batch_size,
+            predict_batch_size=predict_batch_size,
+            drop_incomplete_batch=drop_incomplete_batch,
+            multiproc_kwargs=multiproc_kwargs,
+        )
+
+        self.cache_samples = cache_samples
+
+        self.post_transforms = [
+            transforms.ToPILImage(),
+            transforms.Lambda(lambda x: x.convert("RGB")),
+            transforms.ToTensor(),
+        ]
+
+    def setup(self, stage: str):
+        if self.cache_samples:
+            logger.info(
+                "Argument cache_samples set to True. Samples will be loaded in memory."
+            )
+            samples_loader = _cached_loader
+        else:
+            logger.info(
+                "Argument cache_samples set to False. Samples will be loaded at runtime."
+            )
+            samples_loader = _delayed_loader
+
+        self.json_dataset = JSONDataset(
+            protocols=_protocols,
+            fieldnames=("data", "label"),
+            loader=samples_loader,
+            post_transforms=self.post_transforms,
+        )
+
+        (
+            self.train_dataset,
+            self.validation_dataset,
+            self.extra_validation_datasets,
+            self.predict_dataset,
+        ) = return_subsets(self.json_dataset, "default")
+
+
+datamodule = DefaultModule
diff --git a/src/ptbench/data/dataset.py b/src/ptbench/data/dataset.py
index b1ffcadadff653340b39d4278ea9a6d1db795106..07f18acf8b1ed3c68be535c621d371f026c29ff7 100644
--- a/src/ptbench/data/dataset.py
+++ b/src/ptbench/data/dataset.py
@@ -75,7 +75,7 @@ class JSONDataset:
         * ``data``: which contains the data associated witht this sample
     """
 
-    def __init__(self, protocols, fieldnames, loader):
+    def __init__(self, protocols, fieldnames, loader, post_transforms=None):
         if isinstance(protocols, dict):
             self._protocols = protocols
         else:
@@ -87,6 +87,7 @@
             }
         self.fieldnames = fieldnames
         self._loader = loader
+        self.post_transforms = post_transforms or []
 
     def check(self, limit=0):
         """For each protocol, check if all data can be correctly accessed.
@@ -176,6 +177,7 @@ class JSONDataset:
                 self._loader(
                     dict(protocol=protocol, subset=subset, order=n),
                     dict(zip(self.fieldnames, k)),
+                    self.post_transforms,
                 )
                 for n, k in tqdm.tqdm(enumerate(samples))
             ]
diff --git a/src/ptbench/data/loader.py b/src/ptbench/data/loader.py
index 931c62912ec8beffa21c34d39d056a8bcac17506..a11aefee77becbbbb07e15a25d5404594c16999a 100644
--- a/src/ptbench/data/loader.py
+++ b/src/ptbench/data/loader.py
@@ -70,15 +70,15 @@ def load_pil_rgb(path):
     return load_pil(path).convert("RGB")
 
 
-def make_cached(sample, loader, key=None):
+def make_cached(sample, loader, additional_transforms=None, key=None):
     return Sample(
-        loader(sample),
+        loader(sample, additional_transforms or []),
         key=key or sample["data"],
         label=sample["label"],
     )
 
 
-def make_delayed(sample, loader, key=None):
+def make_delayed(sample, loader, additional_transforms=None, key=None):
     """Returns a delayed-loading Sample object.
 
     Parameters
@@ -105,7 +105,7 @@
         sample loading.
     """
     return DelayedSample(
-        functools.partial(loader, sample),
+        functools.partial(loader, sample, additional_transforms or []),
         key=key or sample["data"],
         label=sample["label"],
     )
diff --git a/src/ptbench/data/shenzhen/__init__.py b/src/ptbench/data/shenzhen/__init__.py
index 9abf568964b3eab377ffcfcca98cf1e31a1d0cb3..d284b1b28905594e0c9c7fd20cb31f7c566468e3 100644
--- a/src/ptbench/data/shenzhen/__init__.py
+++ b/src/ptbench/data/shenzhen/__init__.py
@@ -51,29 +51,31 @@ _datadir = load_rc().get("datadir.shenzhen", os.path.realpath(os.curdir))
 _resize_size = 512
 _cc_size = 512
 
-_data_transforms = transforms.Compose(
-    [
-        RemoveBlackBorders(),
-        transforms.Resize(_resize_size),
-        transforms.CenterCrop(_cc_size),
-        transforms.ToTensor(),
-    ]
-)
+_data_transforms = [
+    RemoveBlackBorders(),
+    transforms.Resize(_resize_size),
+    transforms.CenterCrop(_cc_size),
+    transforms.ToTensor(),
+]
 
 
-def _raw_data_loader(sample):
+def _raw_data_loader(sample, additional_transforms=None):
     raw_data = load_pil_baw(os.path.join(_datadir, sample["data"]))
+
+    base_transforms = transforms.Compose(
+        _data_transforms + (additional_transforms or [])
+    )
     return dict(
-        data=_data_transforms(raw_data),
+        data=base_transforms(raw_data),
         label=sample["label"],
     )
 
 
-def _cached_loader(context, sample):
-    return make_cached(sample, _raw_data_loader)
+def _cached_loader(context, sample, additional_transforms=None):
+    return make_cached(sample, _raw_data_loader, additional_transforms or [])
 
 
-def _delayed_loader(context, sample):
+def _delayed_loader(context, sample, additional_transforms=None):
     # "context" is ignored in this case - database is homogeneous
     # we returned delayed samples to avoid loading all images at once
-    return make_delayed(sample, _raw_data_loader)
+    return make_delayed(sample, _raw_data_loader, additional_transforms or [])
diff --git a/src/ptbench/data/shenzhen/rgb.py b/src/ptbench/data/shenzhen/rgb.py
deleted file mode 100644
index 7bdb8fe3ce6826fb98d0d6356f2e1b429670a3d1..0000000000000000000000000000000000000000
--- a/src/ptbench/data/shenzhen/rgb.py
+++ /dev/null
@@ -1,47 +0,0 @@
-# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
-#
-# SPDX-License-Identifier: GPL-3.0-or-later
-
-"""Shenzhen dataset for TB detection (default protocol, converted in RGB)
-
-* Split reference: first 64% of TB and healthy CXR for "train" 16% for
-* "validation", 20% for "test"
-* This configuration resolution: 512 x 512 (default)
-* See :py:mod:`ptbench.data.shenzhen` for dataset details
-"""
-
-from clapper.logging import setup
-
-from .. import return_subsets
-from ..base_datamodule import BaseDataModule
-from . import _maker
-
-logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
-
-
-class DefaultModule(BaseDataModule):
-    def __init__(
-        self,
-        train_batch_size=1,
-        predict_batch_size=1,
-        drop_incomplete_batch=False,
-        multiproc_kwargs=None,
-    ):
-        super().__init__(
-            train_batch_size=train_batch_size,
-            predict_batch_size=predict_batch_size,
-            drop_incomplete_batch=drop_incomplete_batch,
-            multiproc_kwargs=multiproc_kwargs,
-        )
-
-    def setup(self, stage: str):
-        self.dataset = _maker("default", RGB=True)
-        (
-            self.train_dataset,
-            self.validation_dataset,
-            self.extra_validation_datasets,
-            self.predict_dataset,
-        ) = return_subsets(self.dataset)
-
-
-datamodule = DefaultModule