diff --git a/src/ptbench/data/datamodule.py b/src/ptbench/data/datamodule.py
index 5c94091877ff735e74942795203aa82c8931b61e..4c92476360b9895016560976706763f1098c74d3 100644
--- a/src/ptbench/data/datamodule.py
+++ b/src/ptbench/data/datamodule.py
@@ -3,6 +3,7 @@
 # SPDX-License-Identifier: GPL-3.0-or-later
 
 import collections
+import functools
 import logging
 import multiprocessing
 import sys
@@ -27,6 +28,39 @@ from .typing import (
 logger = logging.getLogger(__name__)
 
 
+def _sample_size_bytes(s: Sample) -> int:
+    """Recurse into the sample and figures out its total occupance in bytes.
+
+    Parameters
+    ----------
+
+    s
+        The sample to be analyzed
+
+
+    Returns
+    -------
+
+    size
+        The size in bytes occupied by this sample
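+
+    Examples
+    --------
+
+    A minimal, hypothetical sample (a 1x16x16 float32 tensor occupies
+    1*16*16*4 = 1024 bytes; the metadata dictionary adds a few more):
+
+    >>> t = torch.ones(1, 16, 16, dtype=torch.float32)
+    >>> _sample_size_bytes((t, dict(label=1))) >= 1024
+    True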
+    """
+
+    def _tensor_size_bytes(t: torch.Tensor) -> int:
+        """Returns a tensor's size in bytes."""
+        return t.element_size() * t.numel()
+
+    size = _tensor_size_bytes(s[0])
+    size += sys.getsizeof(s[1])
+
+    # check each metadata entry - if it is a tensor, add its total size in
+    # bytes
+    for v in s[1].values():
+        if isinstance(v, torch.Tensor):
+            size += _tensor_size_bytes(v)
+
+    return size
+
+
 class _DelayedLoadingDataset(Dataset):
     """A list that loads its samples on demand.
 
@@ -59,6 +93,15 @@ class _DelayedLoadingDataset(Dataset):
         self.loader = loader
         self.transform = torchvision.transforms.Compose(transforms)
 
+        # Test-loads the first sample to log its output tensor characteristics
+        first_sample = self[0]
+        logger.info(
+            f"Delayed loading dataset (first tensor): "
+            f"{list(first_sample[0].shape)}@{first_sample[0].dtype}"
+        )
+        sample_size_mb = _sample_size_bytes(first_sample) / (1024.0 * 1024.0)
+        logger.info(f"Estimated sample size: {sample_size_mb:.1f} Mb")
+
     def labels(self) -> list[int]:
         """Returns the integer labels for all samples in the dataset."""
         return [self.loader.label(k) for k in self.split]
@@ -75,6 +118,39 @@ class _DelayedLoadingDataset(Dataset):
             yield self[x]
 
 
+def _apply_loader_and_transforms(
+    info: typing.Any,
+    load: typing.Callable[[typing.Any], Sample],
+    model_transform: typing.Callable[[torch.Tensor], torch.Tensor],
+) -> Sample:
+    """Local wrapper to apply raw-data loading and transformation in a single
+    step.
+
+    Parameters
+    ----------
+
+    info
+        The sample information, as loaded from its split dictionary
+
+    load
+        The raw-data loader function to use for loading the sample
+
+    model_transform
+        A callable that will transform the loaded tensor into something
+        suitable for the model it will train.  Typically, this will be a
+        composed transform.
+
+
+    Returns
+    -------
+
+    sample
+        The loaded and transformed sample.
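+
+    Examples
+    --------
+
+    A minimal sketch with a hypothetical loader and an identity transform
+    (in practice, ``load`` is a raw-data-loader method and
+    ``model_transform`` a composed torchvision transform):
+
+    >>> load = lambda info: (torch.zeros(1, 8, 8), dict(label=info[1]))
+    >>> tensor, meta = _apply_loader_and_transforms(("img.png", 0), load, lambda t: t)
+    >>> list(tensor.shape)
+    [1, 8, 8]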
+    """
+    sample = load(info)
+    return model_transform(sample[0]), sample[1]
+
+
 class _CachedDataset(Dataset):
     """Basically, a list of preloaded samples.
 
@@ -112,27 +188,41 @@ class _CachedDataset(Dataset):
         parallel: int = -1,
         transforms: typing.Sequence[Transform] = [],
     ):
-        self.transform = torchvision.transforms.Compose(transforms)
+        self.loader = functools.partial(
+            _apply_loader_and_transforms,
+            load=loader.sample,
+            model_transform=torchvision.transforms.Compose(transforms),
+        )
 
         if parallel < 0:
             self.data = [
-                loader.sample(k) for k in tqdm.tqdm(split, unit="sample")
+                self.loader(k) for k in tqdm.tqdm(split, unit="sample")
             ]
         else:
             instances = parallel or multiprocessing.cpu_count()
             logger.info(f"Caching dataset using {instances} processes...")
             with multiprocessing.Pool(instances) as p:
                 self.data = list(
-                    tqdm.tqdm(p.imap(loader.sample, split), total=len(split))
+                    tqdm.tqdm(p.imap(self.loader, split), total=len(split))
                 )
 
+        # Estimates the memory footprint of the cached dataset
+        logger.info(
+            f"Cached dataset (first tensor): "
+            f"{list(self.data[0][0].shape)}@{self.data[0][0].dtype}"
+        )
+        sample_size_mb = _sample_size_bytes(self.data[0]) / (1024.0 * 1024.0)
+        logger.info(
+            f"Estimated RAM occupancy (sample / dataset): "
+            f"{sample_size_mb:.1f} / {(len(self.data) * sample_size_mb):.1f} MB"
+        )
+
     def labels(self) -> list[int]:
         """Returns the integer labels for all samples in the dataset."""
         return [k[1]["label"] for k in self.data]
 
     def __getitem__(self, key: int) -> Sample:
-        tensor, metadata = self.data[key]
-        return self.transform(tensor), metadata
+        return self.data[key]
 
     def __len__(self):
         return len(self.data)
@@ -338,14 +428,6 @@ class CachingDataModule(lightning.LightningDataModule):
         validation to balance sample picking probability, making sample
         across classes **and** datasets equitable.
 
-    model_transforms
-        A list of transforms (torch modules) that will be applied after
-        raw-data-loading, and just before data is fed into the model or
-        eventual data-augmentation transformations for all data loaders
-        produced by this data module.  This part of the pipeline receives data
-        as output by the raw-data-loader, or model-related transforms (e.g.
-        resize adaptions), if any is specified.
-
     batch_size
         Number of samples in every **training** batch (this parameter affects
         memory requirements for the network).  If the number of samples in the
@@ -382,6 +464,21 @@ class CachingDataModule(lightning.LightningDataModule):
         multiprocessing data loading.  Set to 0 to enable as many data loading
         instances as processing cores as available in the system.  Set to >= 1
         to enable that many multiprocessing instances for data loading.
+
+
+    Attributes
+    ----------
+
+    model_transforms
+        A list of transforms (torch modules) that will be applied after
+        raw-data loading, and just before data is fed into the model or
+        eventual data-augmentation transformations, for all data loaders
+        produced by this data module.  This part of the pipeline receives
+        data as output by the raw-data loader, or by model-related
+        transforms (e.g. resize adaptations), if any are specified.  If data
+        is cached, it is cached **after** the model-transforms are applied,
+        as this is a potential memory saver (e.g., if the transforms include
+        a resizing operation to smaller images).
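+
+        A minimal, hypothetical usage sketch (the resize value is purely
+        illustrative):
+
+        .. code-block:: python
+
+           datamodule.model_transforms = [torchvision.transforms.Resize(256)]
+           datamodule.setup("fit")  # datasets are created (and cached) here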
     """
 
     DatasetDictionary = dict[str, Dataset]
@@ -392,7 +489,6 @@ class CachingDataModule(lightning.LightningDataModule):
         raw_data_loader: RawDataLoader,
         cache_samples: bool = False,
         balance_sampler_by_class: bool = False,
-        model_transforms: list[Transform] = [],
         batch_size: int = 1,
         batch_chunk_count: int = 1,
         drop_incomplete_batch: bool = False,
@@ -407,7 +503,7 @@ class CachingDataModule(lightning.LightningDataModule):
         self.cache_samples = cache_samples
         self._train_sampler = None
         self.balance_sampler_by_class = balance_sampler_by_class
-        self.model_transforms = model_transforms
+        self.model_transforms: list[Transform] | None = None
 
         self.drop_incomplete_batch = drop_incomplete_batch
         self.parallel = parallel  # immutable, otherwise would need to call
@@ -551,6 +647,13 @@ class CachingDataModule(lightning.LightningDataModule):
             Name of the dataset to setup.
         """
 
+        if self.model_transforms is None:
+            raise RuntimeError(
+                "Attribute `model_transforms` has not yet been "
+                "set.  If you do not have model transforms, then "
+                "set it to an empty list."
+            )
+
         if name in self._datasets:
             logger.info(
                 f"Dataset `{name}` is already setup. "
diff --git a/src/ptbench/data/image_utils.py b/src/ptbench/data/image_utils.py
index ac31b9ce7fbce85fb688b394c99d591b83049f7f..ed284afc4ab63fd804a8110ff676c29917d968b4 100644
--- a/src/ptbench/data/image_utils.py
+++ b/src/ptbench/data/image_utils.py
@@ -31,6 +31,40 @@ class SingleAutoLevel16to8:
         ).convert("L")
 
 
+def remove_black_borders(
+    img: PIL.Image.Image, threshold: int = 0
+) -> PIL.Image.Image:
+    """Removes the black borders of a CXR image.
+
+    Parameters
+    ----------
+
+    img
+        A PIL image
+
+    threshold
+        Threshold value from which borders are considered black.
+        Defaults to 0.
+
+
+    Returns
+    -------
+
+    img
+        A PIL image with black borders removed
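+
+    Examples
+    --------
+
+    A minimal sketch on a synthetic image (a 10x10 black canvas with a 6x4
+    bright rectangle):
+
+    >>> arr = numpy.zeros((10, 10), dtype=numpy.uint8)
+    >>> arr[2:8, 3:7] = 128
+    >>> remove_black_borders(PIL.Image.fromarray(arr)).size  # (width, height)
+    (4, 6)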
+    """
+
+    img = numpy.asarray(img)
+
+    if len(img.shape) == 2:  # single channel
+        mask = img > threshold
+        return PIL.Image.fromarray(img[numpy.ix_(mask.any(1), mask.any(0))])
+
+    elif len(img.shape) == 3 and img.shape[2] == 3:  # RGB
+        r_mask = img[:, :, 0] > threshold
+        g_mask = img[:, :, 1] > threshold
+        b_mask = img[:, :, 2] > threshold
+
+        mask = r_mask | g_mask | b_mask
+        return PIL.Image.fromarray(img[numpy.ix_(mask.any(1), mask.any(0))])
+
+    else:
+        raise NotImplementedError(
+            f"Cannot remove black borders of images with shape {img.shape}"
+        )
+
+
 class RemoveBlackBorders:
     """Remove black borders of CXR."""
 
@@ -38,9 +72,7 @@ class RemoveBlackBorders:
         self.threshold = threshold
 
     def __call__(self, img):
-        img = numpy.asarray(img)
-        mask = numpy.asarray(img) > self.threshold
-        return PIL.Image.fromarray(img[numpy.ix_(mask.any(1), mask.any(0))])
+        return remove_black_borders(img, self.threshold)
 
 
 def load_pil(path: str | pathlib.Path) -> PIL.Image.Image:
diff --git a/src/ptbench/data/montgomery/__init__.py b/src/ptbench/data/montgomery/__init__.py
index 65239cbf5d908075346675ad10e7c86569383f77..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 100644
--- a/src/ptbench/data/montgomery/__init__.py
+++ b/src/ptbench/data/montgomery/__init__.py
@@ -1,88 +0,0 @@
-# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
-#
-# SPDX-License-Identifier: GPL-3.0-or-later
-
-"""Montgomery dataset for computer-aided diagnosis.
-
-The Montgomery database has been established to foster research
-in computer-aided diagnosis of pulmonary diseases with a special
-focus on pulmonary tuberculosis (TB).
-
-* Reference: [MONTGOMERY-SHENZHEN-2014]_
-* Original resolution (height x width or width x height): 4020 x 4892
-* Split reference: none
-* Protocol ``default``:
-
-  * Training samples: 64% of TB and healthy CXR (including labels)
-  * Validation samples: 16% of TB and healthy CXR (including labels)
-  * Test samples: 20% of TB and healthy CXR (including labels)
-"""
-
-import importlib.resources
-import os
-
-from ...utils.rc import load_rc
-from .. import make_dataset
-from ..dataset import JSONDataset
-from ..loader import load_pil_baw, make_delayed
-
-_protocols = [
-    importlib.resources.files(__name__).joinpath("default.json.bz2"),
-    importlib.resources.files(__name__).joinpath("fold_0.json.bz2"),
-    importlib.resources.files(__name__).joinpath("fold_1.json.bz2"),
-    importlib.resources.files(__name__).joinpath("fold_2.json.bz2"),
-    importlib.resources.files(__name__).joinpath("fold_3.json.bz2"),
-    importlib.resources.files(__name__).joinpath("fold_4.json.bz2"),
-    importlib.resources.files(__name__).joinpath("fold_5.json.bz2"),
-    importlib.resources.files(__name__).joinpath("fold_6.json.bz2"),
-    importlib.resources.files(__name__).joinpath("fold_7.json.bz2"),
-    importlib.resources.files(__name__).joinpath("fold_8.json.bz2"),
-    importlib.resources.files(__name__).joinpath("fold_9.json.bz2"),
-]
-
-_datadir = load_rc().get("datadir.montgomery", os.path.realpath(os.curdir))
-
-
-def _raw_data_loader(sample):
-    return dict(
-        data=load_pil_baw(os.path.join(_datadir, sample["data"])),  # type: ignore
-        label=sample["label"],
-    )
-
-
-def _loader(context, sample):
-    # "context" is ignored in this case - database is homogeneous
-    # we return delayed samples to avoid loading all images at once
-    return make_delayed(sample, _raw_data_loader)
-
-
-json_dataset = JSONDataset(
-    protocols=_protocols,
-    fieldnames=("data", "label"),
-    loader=_loader,
-)
-"""Montgomery dataset object."""
-
-
-def _maker(protocol, resize_size=512, cc_size=512, RGB=False):
-    from torchvision import transforms
-
-    from ..transforms import ElasticDeformation, RemoveBlackBorders
-
-    post_transforms = []
-    if RGB:
-        post_transforms = [
-            transforms.Lambda(lambda x: x.convert("RGB")),
-            transforms.ToTensor(),
-        ]
-
-    return make_dataset(
-        [json_dataset.subsets(protocol)],
-        [
-            RemoveBlackBorders(),
-            transforms.Resize(resize_size),
-            transforms.CenterCrop(cc_size),
-        ],
-        [ElasticDeformation(p=0.8)],
-        post_transforms,
-    )
diff --git a/src/ptbench/data/montgomery/default.py b/src/ptbench/data/montgomery/default.py
index 1f5c0809869be5f011880e808e160024b3c1c1b0..bb57b9a7e8d95f9af40d36ac5a57349c8f514846 100644
--- a/src/ptbench/data/montgomery/default.py
+++ b/src/ptbench/data/montgomery/default.py
@@ -2,46 +2,60 @@
 #
 # SPDX-License-Identifier: GPL-3.0-or-later
 
-"""Montgomery dataset for TB detection (default protocol)
+"""Montgomery datamodule for TB detection (``default`` protocol)
 
-* Split reference: first 64% of TB and healthy CXR for "train" 16% for
-* "validation", 20% for "test"
-* This configuration resolution: 512 x 512 (default)
-* See :py:mod:`ptbench.data.montgomery` for dataset details
-"""
+The Montgomery database has been established to foster research in
+computer-aided diagnosis of pulmonary diseases, with a special focus on
+pulmonary tuberculosis (TB).
 
-from clapper.logging import setup
+* Database reference: [MONTGOMERY-SHENZHEN-2014]_
+* Original resolution (height x width or width x height): 4020x4892 px or 4892x4020 px
+* This split:
 
-from .. import return_subsets
-from ..base_datamodule import BaseDataModule
-from . import _maker
+  * Split reference: None
+  * Training samples: ?? of TB and healthy CXR
+  * Validation samples: ?? of TB and healthy CXR
+  * Test samples: ?? of TB and healthy CXR
 
-logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
+Data specifications:
 
+* Raw data input (on disk):
 
-class DefaultModule(BaseDataModule):
-    def __init__(
-        self,
-        train_batch_size=1,
-        predict_batch_size=1,
-        drop_incomplete_batch=False,
-        multiproc_kwargs=None,
-    ):
-        super().__init__(
-            train_batch_size=train_batch_size,
-            predict_batch_size=predict_batch_size,
-            drop_incomplete_batch=drop_incomplete_batch,
-            multiproc_kwargs=multiproc_kwargs,
-        )
+    * PNG images 8 bit grayscale
+    * resolution: fixed to one of the cases above
+
+* Output image:
+
+    * Transforms:
+
+        * Load raw PNG with :py:mod:`PIL`
+        * Remove black borders
+        * Torch center cropping to get a square image
+
+    * Final specifications
 
-    def setup(self, stage: str):
-        self.dataset = _maker("default")
-        (
-            self.train_dataset,
-            self.validation_dataset,
-            self.extra_validation_datasets,
-            self.predict_dataset,
-        ) = return_subsets(self.dataset)
+        * Grayscale, encoded as a single plane image, 8 bits
+        * Square (4020x4020 px)
 
 
-datamodule = DefaultModule
+Protocol ``default``:
+
+    * Training samples: first 64% of TB and healthy CXR (including labels)
+    * Validation samples: 16% of TB and healthy CXR (including labels)
+    * Test samples: 20% of TB and healthy CXR (including labels)
+"""
+
+import importlib.resources
+
+from ..datamodule import CachingDataModule
+from ..split import JSONDatabaseSplit
+from .loader import RawDataLoader
+
+datamodule = CachingDataModule(
+    database_split=JSONDatabaseSplit(
+        importlib.resources.files(__name__.rsplit(".", 1)[0]).joinpath(
+            "default.json.bz2"
+        )
+    ),
+    raw_data_loader=RawDataLoader(),
+)
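+
+# N.B.: a hypothetical usage sketch -- as per `CachingDataModule`, the
+# `model_transforms` attribute must be assigned before setup, e.g.:
+#
+#     datamodule.model_transforms = []  # no model-specific transforms here
+#     datamodule.setup("fit")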
diff --git a/src/ptbench/data/montgomery/fold_0.py b/src/ptbench/data/montgomery/fold_0.py
index c60791be50ccd5186ce8e4af263efb7d7513b07a..e50d2e302f1c6b529c862c529bb77cf20aef8a57 100644
--- a/src/ptbench/data/montgomery/fold_0.py
+++ b/src/ptbench/data/montgomery/fold_0.py
@@ -2,46 +2,22 @@
 #
 # SPDX-License-Identifier: GPL-3.0-or-later
 
-"""Montgomery dataset for TB detection (cross validation fold 0)
+"""Montgomery datamodule for TB detection (``fold 0`` protocol)
 
-* Split reference: first 64% of TB and healthy CXR for "train" 16% for
-* "validation", 20% for "test"
-* This configuration resolution: 512 x 512 (default)
-* See :py:mod:`ptbench.data.montgomery` for dataset details
+See :py:mod:`ptbench.data.montgomery.default` for input/output details.
 """
 
-from clapper.logging import setup
+import importlib.resources
 
-from .. import return_subsets
-from ..base_datamodule import BaseDataModule
-from . import _maker
+from ..datamodule import CachingDataModule
+from ..split import JSONDatabaseSplit
+from .loader import RawDataLoader
 
-logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
-
-
-class DefaultModule(BaseDataModule):
-    def __init__(
-        self,
-        train_batch_size=1,
-        predict_batch_size=1,
-        drop_incomplete_batch=False,
-        multiproc_kwargs=None,
-    ):
-        super().__init__(
-            train_batch_size=train_batch_size,
-            predict_batch_size=predict_batch_size,
-            drop_incomplete_batch=drop_incomplete_batch,
-            multiproc_kwargs=multiproc_kwargs,
+datamodule = CachingDataModule(
+    database_split=JSONDatabaseSplit(
+        importlib.resources.files(__name__.rsplit(".", 1)[0]).joinpath(
+            "fold_0.json.bz2"
         )
-
-    def setup(self, stage: str):
-        self.dataset = _maker("fold_0")
-        (
-            self.train_dataset,
-            self.validation_dataset,
-            self.extra_validation_datasets,
-            self.predict_dataset,
-        ) = return_subsets(self.dataset)
-
-
-datamodule = DefaultModule
+    ),
+    raw_data_loader=RawDataLoader(),
+)
diff --git a/src/ptbench/data/montgomery/fold_0_rgb.py b/src/ptbench/data/montgomery/fold_0_rgb.py
deleted file mode 100644
index 8e8b0c8914b6a63dd9ab854984ff2bc51cb4e255..0000000000000000000000000000000000000000
--- a/src/ptbench/data/montgomery/fold_0_rgb.py
+++ /dev/null
@@ -1,47 +0,0 @@
-# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
-#
-# SPDX-License-Identifier: GPL-3.0-or-later
-
-"""Montgomery dataset for TB detection (cross validation fold 0, RGB)
-
-* Split reference: first 64% of TB and healthy CXR for "train" 16% for
-* "validation", 20% for "test"
-* This configuration resolution: 512 x 512 (default)
-* See :py:mod:`ptbench.data.montgomery` for dataset details
-"""
-
-from clapper.logging import setup
-
-from .. import return_subsets
-from ..base_datamodule import BaseDataModule
-from . import _maker
-
-logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
-
-
-class DefaultModule(BaseDataModule):
-    def __init__(
-        self,
-        train_batch_size=1,
-        predict_batch_size=1,
-        drop_incomplete_batch=False,
-        multiproc_kwargs=None,
-    ):
-        super().__init__(
-            train_batch_size=train_batch_size,
-            predict_batch_size=predict_batch_size,
-            drop_incomplete_batch=drop_incomplete_batch,
-            multiproc_kwargs=multiproc_kwargs,
-        )
-
-    def setup(self, stage: str):
-        self.dataset = _maker("fold_0", RGB=True)
-        (
-            self.train_dataset,
-            self.validation_dataset,
-            self.extra_validation_datasets,
-            self.predict_dataset,
-        ) = return_subsets(self.dataset)
-
-
-datamodule = DefaultModule
diff --git a/src/ptbench/data/montgomery/fold_1.py b/src/ptbench/data/montgomery/fold_1.py
index d6627e673978bcf960b8fb5f72add7cb4a13a141..3698a9edfa614f980b9b2352d97c7329965d371d 100644
--- a/src/ptbench/data/montgomery/fold_1.py
+++ b/src/ptbench/data/montgomery/fold_1.py
@@ -2,46 +2,22 @@
 #
 # SPDX-License-Identifier: GPL-3.0-or-later
 
-"""Montgomery dataset for TB detection (cross validation fold 1)
+"""Montgomery datamodule for TB detection (default protocol)
 
-* Split reference: first 64% of TB and healthy CXR for "train" 16% for
-* "validation", 20% for "test"
-* This configuration resolution: 512 x 512 (default)
-* See :py:mod:`ptbench.data.montgomery` for dataset details
+See :py:mod:`ptbench.data.montgomery.default` for input/output details.
 """
 
-from clapper.logging import setup
+import importlib.resources
 
-from .. import return_subsets
-from ..base_datamodule import BaseDataModule
-from . import _maker
+from ..datamodule import CachingDataModule
+from ..split import JSONDatabaseSplit
+from .loader import RawDataLoader
 
-logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
-
-
-class DefaultModule(BaseDataModule):
-    def __init__(
-        self,
-        train_batch_size=1,
-        predict_batch_size=1,
-        drop_incomplete_batch=False,
-        multiproc_kwargs=None,
-    ):
-        super().__init__(
-            train_batch_size=train_batch_size,
-            predict_batch_size=predict_batch_size,
-            drop_incomplete_batch=drop_incomplete_batch,
-            multiproc_kwargs=multiproc_kwargs,
+datamodule = CachingDataModule(
+    database_split=JSONDatabaseSplit(
+        importlib.resources.files(__name__.rsplit(".", 1)[0]).joinpath(
+            "fold_1.json.bz2"
         )
-
-    def setup(self, stage: str):
-        self.dataset = _maker("fold_1")
-        (
-            self.train_dataset,
-            self.validation_dataset,
-            self.extra_validation_datasets,
-            self.predict_dataset,
-        ) = return_subsets(self.dataset)
-
-
-datamodule = DefaultModule
+    ),
+    raw_data_loader=RawDataLoader(),
+)
diff --git a/src/ptbench/data/montgomery/fold_1_rgb.py b/src/ptbench/data/montgomery/fold_1_rgb.py
deleted file mode 100644
index bc47a322c3fd779e3bc19924f6d7ac7c13e71847..0000000000000000000000000000000000000000
--- a/src/ptbench/data/montgomery/fold_1_rgb.py
+++ /dev/null
@@ -1,47 +0,0 @@
-# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
-#
-# SPDX-License-Identifier: GPL-3.0-or-later
-
-"""Montgomery dataset for TB detection (cross validation fold 1, RGB)
-
-* Split reference: first 64% of TB and healthy CXR for "train" 16% for
-* "validation", 20% for "test"
-* This configuration resolution: 512 x 512 (default)
-* See :py:mod:`ptbench.data.montgomery` for dataset details
-"""
-
-from clapper.logging import setup
-
-from .. import return_subsets
-from ..base_datamodule import BaseDataModule
-from . import _maker
-
-logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
-
-
-class DefaultModule(BaseDataModule):
-    def __init__(
-        self,
-        train_batch_size=1,
-        predict_batch_size=1,
-        drop_incomplete_batch=False,
-        multiproc_kwargs=None,
-    ):
-        super().__init__(
-            train_batch_size=train_batch_size,
-            predict_batch_size=predict_batch_size,
-            drop_incomplete_batch=drop_incomplete_batch,
-            multiproc_kwargs=multiproc_kwargs,
-        )
-
-    def setup(self, stage: str):
-        self.dataset = _maker("fold_1", RGB=True)
-        (
-            self.train_dataset,
-            self.validation_dataset,
-            self.extra_validation_datasets,
-            self.predict_dataset,
-        ) = return_subsets(self.dataset)
-
-
-datamodule = DefaultModule
diff --git a/src/ptbench/data/montgomery/fold_2.py b/src/ptbench/data/montgomery/fold_2.py
index 8c5f4a66fd2af0b9f26b67241f45c630f69bd06a..b2d7ac2cfd8def5627b56d5353740e9676e1d9cc 100644
--- a/src/ptbench/data/montgomery/fold_2.py
+++ b/src/ptbench/data/montgomery/fold_2.py
@@ -2,46 +2,22 @@
 #
 # SPDX-License-Identifier: GPL-3.0-or-later
 
-"""Montgomery dataset for TB detection (cross validation fold 2)
+"""Montgomery datamodule for TB detection (default protocol)
 
-* Split reference: first 64% of TB and healthy CXR for "train" 16% for
-* "validation", 20% for "test"
-* This configuration resolution: 512 x 512 (default)
-* See :py:mod:`ptbench.data.montgomery` for dataset details
+See :py:mod:`ptbench.data.montgomery.default` for input/output details.
 """
 
-from clapper.logging import setup
+import importlib.resources
 
-from .. import return_subsets
-from ..base_datamodule import BaseDataModule
-from . import _maker
+from ..datamodule import CachingDataModule
+from ..split import JSONDatabaseSplit
+from .loader import RawDataLoader
 
-logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
-
-
-class DefaultModule(BaseDataModule):
-    def __init__(
-        self,
-        train_batch_size=1,
-        predict_batch_size=1,
-        drop_incomplete_batch=False,
-        multiproc_kwargs=None,
-    ):
-        super().__init__(
-            train_batch_size=train_batch_size,
-            predict_batch_size=predict_batch_size,
-            drop_incomplete_batch=drop_incomplete_batch,
-            multiproc_kwargs=multiproc_kwargs,
+datamodule = CachingDataModule(
+    database_split=JSONDatabaseSplit(
+        importlib.resources.files(__name__.rsplit(".", 1)[0]).joinpath(
+            "fold_2.json.bz2"
         )
-
-    def setup(self, stage: str):
-        self.dataset = _maker("fold_2")
-        (
-            self.train_dataset,
-            self.validation_dataset,
-            self.extra_validation_datasets,
-            self.predict_dataset,
-        ) = return_subsets(self.dataset)
-
-
-datamodule = DefaultModule
+    ),
+    raw_data_loader=RawDataLoader(),
+)
diff --git a/src/ptbench/data/montgomery/fold_2_rgb.py b/src/ptbench/data/montgomery/fold_2_rgb.py
deleted file mode 100644
index b81a877b2bc7372a99812a27935e6daf42401568..0000000000000000000000000000000000000000
--- a/src/ptbench/data/montgomery/fold_2_rgb.py
+++ /dev/null
@@ -1,47 +0,0 @@
-# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
-#
-# SPDX-License-Identifier: GPL-3.0-or-later
-
-"""Montgomery dataset for TB detection (cross validation fold 2, RGB)
-
-* Split reference: first 64% of TB and healthy CXR for "train" 16% for
-* "validation", 20% for "test"
-* This configuration resolution: 512 x 512 (default)
-* See :py:mod:`ptbench.data.montgomery` for dataset details
-"""
-
-from clapper.logging import setup
-
-from .. import return_subsets
-from ..base_datamodule import BaseDataModule
-from . import _maker
-
-logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
-
-
-class DefaultModule(BaseDataModule):
-    def __init__(
-        self,
-        train_batch_size=1,
-        predict_batch_size=1,
-        drop_incomplete_batch=False,
-        multiproc_kwargs=None,
-    ):
-        super().__init__(
-            train_batch_size=train_batch_size,
-            predict_batch_size=predict_batch_size,
-            drop_incomplete_batch=drop_incomplete_batch,
-            multiproc_kwargs=multiproc_kwargs,
-        )
-
-    def setup(self, stage: str):
-        self.dataset = _maker("fold_2", RGB=True)
-        (
-            self.train_dataset,
-            self.validation_dataset,
-            self.extra_validation_datasets,
-            self.predict_dataset,
-        ) = return_subsets(self.dataset)
-
-
-datamodule = DefaultModule
diff --git a/src/ptbench/data/montgomery/fold_3.py b/src/ptbench/data/montgomery/fold_3.py
index 8e685d7e3baa3a23924c62a77ffc61bf51e12056..1c566e4f528e587cfd8a3bd882e2c73ea5a46aa6 100644
--- a/src/ptbench/data/montgomery/fold_3.py
+++ b/src/ptbench/data/montgomery/fold_3.py
@@ -2,46 +2,22 @@
 #
 # SPDX-License-Identifier: GPL-3.0-or-later
 
-"""Montgomery dataset for TB detection (cross validation fold 3)
+"""Montgomery datamodule for TB detection (default protocol)
 
-* Split reference: first 64% of TB and healthy CXR for "train" 16% for
-* "validation", 20% for "test"
-* This configuration resolution: 512 x 512 (default)
-* See :py:mod:`ptbench.data.montgomery` for dataset details
+See :py:mod:`ptbench.data.montgomery.default` for input/output details.
 """
 
-from clapper.logging import setup
+import importlib.resources
 
-from .. import return_subsets
-from ..base_datamodule import BaseDataModule
-from . import _maker
+from ..datamodule import CachingDataModule
+from ..split import JSONDatabaseSplit
+from .loader import RawDataLoader
 
-logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
-
-
-class DefaultModule(BaseDataModule):
-    def __init__(
-        self,
-        train_batch_size=1,
-        predict_batch_size=1,
-        drop_incomplete_batch=False,
-        multiproc_kwargs=None,
-    ):
-        super().__init__(
-            train_batch_size=train_batch_size,
-            predict_batch_size=predict_batch_size,
-            drop_incomplete_batch=drop_incomplete_batch,
-            multiproc_kwargs=multiproc_kwargs,
+datamodule = CachingDataModule(
+    database_split=JSONDatabaseSplit(
+        importlib.resources.files(__name__.rsplit(".", 1)[0]).joinpath(
+            "fold_3.json.bz2"
         )
-
-    def setup(self, stage: str):
-        self.dataset = _maker("fold_3")
-        (
-            self.train_dataset,
-            self.validation_dataset,
-            self.extra_validation_datasets,
-            self.predict_dataset,
-        ) = return_subsets(self.dataset)
-
-
-datamodule = DefaultModule
+    ),
+    raw_data_loader=RawDataLoader(),
+)
diff --git a/src/ptbench/data/montgomery/fold_3_rgb.py b/src/ptbench/data/montgomery/fold_3_rgb.py
deleted file mode 100644
index 7b600371c8d434d79049c6e6423b36e99f2a32cb..0000000000000000000000000000000000000000
--- a/src/ptbench/data/montgomery/fold_3_rgb.py
+++ /dev/null
@@ -1,47 +0,0 @@
-# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
-#
-# SPDX-License-Identifier: GPL-3.0-or-later
-
-"""Montgomery dataset for TB detection (cross validation fold 3, RGB)
-
-* Split reference: first 64% of TB and healthy CXR for "train" 16% for
-* "validation", 20% for "test"
-* This configuration resolution: 512 x 512 (default)
-* See :py:mod:`ptbench.data.montgomery` for dataset details
-"""
-
-from clapper.logging import setup
-
-from .. import return_subsets
-from ..base_datamodule import BaseDataModule
-from . import _maker
-
-logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
-
-
-class DefaultModule(BaseDataModule):
-    def __init__(
-        self,
-        train_batch_size=1,
-        predict_batch_size=1,
-        drop_incomplete_batch=False,
-        multiproc_kwargs=None,
-    ):
-        super().__init__(
-            train_batch_size=train_batch_size,
-            predict_batch_size=predict_batch_size,
-            drop_incomplete_batch=drop_incomplete_batch,
-            multiproc_kwargs=multiproc_kwargs,
-        )
-
-    def setup(self, stage: str):
-        self.dataset = _maker("fold_3", RGB=True)
-        (
-            self.train_dataset,
-            self.validation_dataset,
-            self.extra_validation_datasets,
-            self.predict_dataset,
-        ) = return_subsets(self.dataset)
-
-
-datamodule = DefaultModule
diff --git a/src/ptbench/data/montgomery/fold_4.py b/src/ptbench/data/montgomery/fold_4.py
index 9459cb938605df06823a86a96fbd1cf374fe9738..4b68bd538f71115a01bae0fce87742be6ab711a8 100644
--- a/src/ptbench/data/montgomery/fold_4.py
+++ b/src/ptbench/data/montgomery/fold_4.py
@@ -2,46 +2,22 @@
 #
 # SPDX-License-Identifier: GPL-3.0-or-later
 
-"""Montgomery dataset for TB detection (cross validation fold 4)
+"""Montgomery datamodule for TB detection (default protocol)
 
-* Split reference: first 64% of TB and healthy CXR for "train" 16% for
-* "validation", 20% for "test"
-* This configuration resolution: 512 x 512 (default)
-* See :py:mod:`ptbench.data.montgomery` for dataset details
+See :py:mod:`ptbench.data.montgomery.default` for input/output details.
 """
 
-from clapper.logging import setup
+import importlib.resources
 
-from .. import return_subsets
-from ..base_datamodule import BaseDataModule
-from . import _maker
+from ..datamodule import CachingDataModule
+from ..split import JSONDatabaseSplit
+from .loader import RawDataLoader
 
-logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
-
-
-class DefaultModule(BaseDataModule):
-    def __init__(
-        self,
-        train_batch_size=1,
-        predict_batch_size=1,
-        drop_incomplete_batch=False,
-        multiproc_kwargs=None,
-    ):
-        super().__init__(
-            train_batch_size=train_batch_size,
-            predict_batch_size=predict_batch_size,
-            drop_incomplete_batch=drop_incomplete_batch,
-            multiproc_kwargs=multiproc_kwargs,
+datamodule = CachingDataModule(
+    database_split=JSONDatabaseSplit(
+        importlib.resources.files(__name__.rsplit(".", 1)[0]).joinpath(
+            "fold_4.json.bz2"
         )
-
-    def setup(self, stage: str):
-        self.dataset = _maker("fold_4")
-        (
-            self.train_dataset,
-            self.validation_dataset,
-            self.extra_validation_datasets,
-            self.predict_dataset,
-        ) = return_subsets(self.dataset)
-
-
-datamodule = DefaultModule
+    ),
+    raw_data_loader=RawDataLoader(),
+)
diff --git a/src/ptbench/data/montgomery/fold_4_rgb.py b/src/ptbench/data/montgomery/fold_4_rgb.py
deleted file mode 100644
index 3eb136f654ab8d8d648468948e05dad774d85076..0000000000000000000000000000000000000000
--- a/src/ptbench/data/montgomery/fold_4_rgb.py
+++ /dev/null
@@ -1,47 +0,0 @@
-# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
-#
-# SPDX-License-Identifier: GPL-3.0-or-later
-
-"""Montgomery dataset for TB detection (cross validation fold 4, RGB)
-
-* Split reference: first 64% of TB and healthy CXR for "train" 16% for
-* "validation", 20% for "test"
-* This configuration resolution: 512 x 512 (default)
-* See :py:mod:`ptbench.data.montgomery` for dataset details
-"""
-
-from clapper.logging import setup
-
-from .. import return_subsets
-from ..base_datamodule import BaseDataModule
-from . import _maker
-
-logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
-
-
-class DefaultModule(BaseDataModule):
-    def __init__(
-        self,
-        train_batch_size=1,
-        predict_batch_size=1,
-        drop_incomplete_batch=False,
-        multiproc_kwargs=None,
-    ):
-        super().__init__(
-            train_batch_size=train_batch_size,
-            predict_batch_size=predict_batch_size,
-            drop_incomplete_batch=drop_incomplete_batch,
-            multiproc_kwargs=multiproc_kwargs,
-        )
-
-    def setup(self, stage: str):
-        self.dataset = _maker("fold_4", RGB=True)
-        (
-            self.train_dataset,
-            self.validation_dataset,
-            self.extra_validation_datasets,
-            self.predict_dataset,
-        ) = return_subsets(self.dataset)
-
-
-datamodule = DefaultModule
diff --git a/src/ptbench/data/montgomery/fold_5.py b/src/ptbench/data/montgomery/fold_5.py
index 147690f6d54f15d50b52f88288dbc8a41dfb7f33..59891e8e1b5531b94fc996bfe25ef140ff39a83a 100644
--- a/src/ptbench/data/montgomery/fold_5.py
+++ b/src/ptbench/data/montgomery/fold_5.py
@@ -2,46 +2,22 @@
 #
 # SPDX-License-Identifier: GPL-3.0-or-later
 
-"""Montgomery dataset for TB detection (cross validation fold 5)
+"""Montgomery datamodule for TB detection (default protocol)
 
-* Split reference: first 64% of TB and healthy CXR for "train" 16% for
-* "validation", 20% for "test"
-* This configuration resolution: 512 x 512 (default)
-* See :py:mod:`ptbench.data.montgomery` for dataset details
+See :py:mod:`ptbench.data.montgomery.default` for input/output details.
 """
 
-from clapper.logging import setup
+import importlib.resources
 
-from .. import return_subsets
-from ..base_datamodule import BaseDataModule
-from . import _maker
+from ..datamodule import CachingDataModule
+from ..split import JSONDatabaseSplit
+from .loader import RawDataLoader
 
-logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
-
-
-class DefaultModule(BaseDataModule):
-    def __init__(
-        self,
-        train_batch_size=1,
-        predict_batch_size=1,
-        drop_incomplete_batch=False,
-        multiproc_kwargs=None,
-    ):
-        super().__init__(
-            train_batch_size=train_batch_size,
-            predict_batch_size=predict_batch_size,
-            drop_incomplete_batch=drop_incomplete_batch,
-            multiproc_kwargs=multiproc_kwargs,
+datamodule = CachingDataModule(
+    database_split=JSONDatabaseSplit(
+        importlib.resources.files(__name__.rsplit(".", 1)[0]).joinpath(
+            "fold_5.json.bz2"
         )
-
-    def setup(self, stage: str):
-        self.dataset = _maker("fold_5")
-        (
-            self.train_dataset,
-            self.validation_dataset,
-            self.extra_validation_datasets,
-            self.predict_dataset,
-        ) = return_subsets(self.dataset)
-
-
-datamodule = DefaultModule
+    ),
+    raw_data_loader=RawDataLoader(),
+)
diff --git a/src/ptbench/data/montgomery/fold_5_rgb.py b/src/ptbench/data/montgomery/fold_5_rgb.py
deleted file mode 100644
index 3e7cb73f6957086b99147812b07f733dc51af9ec..0000000000000000000000000000000000000000
--- a/src/ptbench/data/montgomery/fold_5_rgb.py
+++ /dev/null
@@ -1,47 +0,0 @@
-# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
-#
-# SPDX-License-Identifier: GPL-3.0-or-later
-
-"""Montgomery dataset for TB detection (cross validation fold 5, RGB)
-
-* Split reference: first 64% of TB and healthy CXR for "train" 16% for
-* "validation", 20% for "test"
-* This configuration resolution: 512 x 512 (default)
-* See :py:mod:`ptbench.data.montgomery` for dataset details
-"""
-
-from clapper.logging import setup
-
-from .. import return_subsets
-from ..base_datamodule import BaseDataModule
-from . import _maker
-
-logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
-
-
-class DefaultModule(BaseDataModule):
-    def __init__(
-        self,
-        train_batch_size=1,
-        predict_batch_size=1,
-        drop_incomplete_batch=False,
-        multiproc_kwargs=None,
-    ):
-        super().__init__(
-            train_batch_size=train_batch_size,
-            predict_batch_size=predict_batch_size,
-            drop_incomplete_batch=drop_incomplete_batch,
-            multiproc_kwargs=multiproc_kwargs,
-        )
-
-    def setup(self, stage: str):
-        self.dataset = _maker("fold_5", RGB=True)
-        (
-            self.train_dataset,
-            self.validation_dataset,
-            self.extra_validation_datasets,
-            self.predict_dataset,
-        ) = return_subsets(self.dataset)
-
-
-datamodule = DefaultModule
diff --git a/src/ptbench/data/montgomery/fold_6.py b/src/ptbench/data/montgomery/fold_6.py
index 69f24390ac01271c3e961950d429d973e535c380..e6c1d31a69ff20bbfd3ec4e53ba4eab0f9beec7f 100644
--- a/src/ptbench/data/montgomery/fold_6.py
+++ b/src/ptbench/data/montgomery/fold_6.py
@@ -2,46 +2,22 @@
 #
 # SPDX-License-Identifier: GPL-3.0-or-later
 
-"""Montgomery dataset for TB detection (cross validation fold 6)
+"""Montgomery datamodule for TB detection (default protocol)
 
-* Split reference: first 64% of TB and healthy CXR for "train" 16% for
-* "validation", 20% for "test"
-* This configuration resolution: 512 x 512 (default)
-* See :py:mod:`ptbench.data.montgomery` for dataset details
+See :py:mod:`ptbench.data.montgomery.default` for input/output details.
 """
 
-from clapper.logging import setup
+import importlib.resources
 
-from .. import return_subsets
-from ..base_datamodule import BaseDataModule
-from . import _maker
+from ..datamodule import CachingDataModule
+from ..split import JSONDatabaseSplit
+from .loader import RawDataLoader
 
-logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
-
-
-class DefaultModule(BaseDataModule):
-    def __init__(
-        self,
-        train_batch_size=1,
-        predict_batch_size=1,
-        drop_incomplete_batch=False,
-        multiproc_kwargs=None,
-    ):
-        super().__init__(
-            train_batch_size=train_batch_size,
-            predict_batch_size=predict_batch_size,
-            drop_incomplete_batch=drop_incomplete_batch,
-            multiproc_kwargs=multiproc_kwargs,
+datamodule = CachingDataModule(
+    database_split=JSONDatabaseSplit(
+        importlib.resources.files(__name__.rsplit(".", 1)[0]).joinpath(
+            "fold_6.json.bz2"
         )
-
-    def setup(self, stage: str):
-        self.dataset = _maker("fold_6")
-        (
-            self.train_dataset,
-            self.validation_dataset,
-            self.extra_validation_datasets,
-            self.predict_dataset,
-        ) = return_subsets(self.dataset)
-
-
-datamodule = DefaultModule
+    ),
+    raw_data_loader=RawDataLoader(),
+)
diff --git a/src/ptbench/data/montgomery/fold_6_rgb.py b/src/ptbench/data/montgomery/fold_6_rgb.py
deleted file mode 100644
index ff3a8cdb0c00f511f4ebb7abcfabb10ae7853e99..0000000000000000000000000000000000000000
--- a/src/ptbench/data/montgomery/fold_6_rgb.py
+++ /dev/null
@@ -1,47 +0,0 @@
-# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
-#
-# SPDX-License-Identifier: GPL-3.0-or-later
-
-"""Montgomery dataset for TB detection (cross validation fold 6, RGB)
-
-* Split reference: first 64% of TB and healthy CXR for "train" 16% for
-* "validation", 20% for "test"
-* This configuration resolution: 512 x 512 (default)
-* See :py:mod:`ptbench.data.montgomery` for dataset details
-"""
-
-from clapper.logging import setup
-
-from .. import return_subsets
-from ..base_datamodule import BaseDataModule
-from . import _maker
-
-logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
-
-
-class DefaultModule(BaseDataModule):
-    def __init__(
-        self,
-        train_batch_size=1,
-        predict_batch_size=1,
-        drop_incomplete_batch=False,
-        multiproc_kwargs=None,
-    ):
-        super().__init__(
-            train_batch_size=train_batch_size,
-            predict_batch_size=predict_batch_size,
-            drop_incomplete_batch=drop_incomplete_batch,
-            multiproc_kwargs=multiproc_kwargs,
-        )
-
-    def setup(self, stage: str):
-        self.dataset = _maker("fold_6", RGB=True)
-        (
-            self.train_dataset,
-            self.validation_dataset,
-            self.extra_validation_datasets,
-            self.predict_dataset,
-        ) = return_subsets(self.dataset)
-
-
-datamodule = DefaultModule
diff --git a/src/ptbench/data/montgomery/fold_7.py b/src/ptbench/data/montgomery/fold_7.py
index 20ba9d3a7da5ffcb8673e685a0534d82fdb7ed2b..44dd80512be61c32616188968a418b9963b41aed 100644
--- a/src/ptbench/data/montgomery/fold_7.py
+++ b/src/ptbench/data/montgomery/fold_7.py
@@ -2,46 +2,22 @@
 #
 # SPDX-License-Identifier: GPL-3.0-or-later
 
-"""Montgomery dataset for TB detection (cross validation fold 7)
+"""Montgomery datamodule for TB detection (default protocol)
 
-* Split reference: first 64% of TB and healthy CXR for "train" 16% for
-* "validation", 20% for "test"
-* This configuration resolution: 512 x 512 (default)
-* See :py:mod:`ptbench.data.montgomery` for dataset details
+See :py:mod:`ptbench.data.montgomery.default` for input/output details.
 """
 
-from clapper.logging import setup
+import importlib.resources
 
-from .. import return_subsets
-from ..base_datamodule import BaseDataModule
-from . import _maker
+from ..datamodule import CachingDataModule
+from ..split import JSONDatabaseSplit
+from .loader import RawDataLoader
 
-logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
-
-
-class DefaultModule(BaseDataModule):
-    def __init__(
-        self,
-        train_batch_size=1,
-        predict_batch_size=1,
-        drop_incomplete_batch=False,
-        multiproc_kwargs=None,
-    ):
-        super().__init__(
-            train_batch_size=train_batch_size,
-            predict_batch_size=predict_batch_size,
-            drop_incomplete_batch=drop_incomplete_batch,
-            multiproc_kwargs=multiproc_kwargs,
+datamodule = CachingDataModule(
+    database_split=JSONDatabaseSplit(
+        importlib.resources.files(__name__.rsplit(".", 1)[0]).joinpath(
+            "fold_7.json.bz2"
         )
-
-    def setup(self, stage: str):
-        self.dataset = _maker("fold_7")
-        (
-            self.train_dataset,
-            self.validation_dataset,
-            self.extra_validation_datasets,
-            self.predict_dataset,
-        ) = return_subsets(self.dataset)
-
-
-datamodule = DefaultModule
+    ),
+    raw_data_loader=RawDataLoader(),
+)
diff --git a/src/ptbench/data/montgomery/fold_7_rgb.py b/src/ptbench/data/montgomery/fold_7_rgb.py
deleted file mode 100644
index 05664b06ab6393911a77b32418d6f2afb9d455fa..0000000000000000000000000000000000000000
--- a/src/ptbench/data/montgomery/fold_7_rgb.py
+++ /dev/null
@@ -1,47 +0,0 @@
-# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
-#
-# SPDX-License-Identifier: GPL-3.0-or-later
-
-"""Montgomery dataset for TB detection (cross validation fold 7, RGB)
-
-* Split reference: first 64% of TB and healthy CXR for "train" 16% for
-* "validation", 20% for "test"
-* This configuration resolution: 512 x 512 (default)
-* See :py:mod:`ptbench.data.montgomery` for dataset details
-"""
-
-from clapper.logging import setup
-
-from .. import return_subsets
-from ..base_datamodule import BaseDataModule
-from . import _maker
-
-logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
-
-
-class DefaultModule(BaseDataModule):
-    def __init__(
-        self,
-        train_batch_size=1,
-        predict_batch_size=1,
-        drop_incomplete_batch=False,
-        multiproc_kwargs=None,
-    ):
-        super().__init__(
-            train_batch_size=train_batch_size,
-            predict_batch_size=predict_batch_size,
-            drop_incomplete_batch=drop_incomplete_batch,
-            multiproc_kwargs=multiproc_kwargs,
-        )
-
-    def setup(self, stage: str):
-        self.dataset = _maker("fold_7", RGB=True)
-        (
-            self.train_dataset,
-            self.validation_dataset,
-            self.extra_validation_datasets,
-            self.predict_dataset,
-        ) = return_subsets(self.dataset)
-
-
-datamodule = DefaultModule
diff --git a/src/ptbench/data/montgomery/fold_8.py b/src/ptbench/data/montgomery/fold_8.py
index e92ff959a9b1028c174c95719867f5086831d6c9..fd7edde69259023fa36ff05027fe1f0ad19d6661 100644
--- a/src/ptbench/data/montgomery/fold_8.py
+++ b/src/ptbench/data/montgomery/fold_8.py
@@ -2,46 +2,22 @@
 #
 # SPDX-License-Identifier: GPL-3.0-or-later
 
-"""Montgomery dataset for TB detection (cross validation fold 8)
+"""Montgomery datamodule for TB detection (default protocol)
 
-* Split reference: first 64% of TB and healthy CXR for "train" 16% for
-* "validation", 20% for "test"
-* This configuration resolution: 512 x 512 (default)
-* See :py:mod:`ptbench.data.montgomery` for dataset details
+See :py:mod:`ptbench.data.montgomery.default` for input/output details.
 """
 
-from clapper.logging import setup
+import importlib.resources
 
-from .. import return_subsets
-from ..base_datamodule import BaseDataModule
-from . import _maker
+from ..datamodule import CachingDataModule
+from ..split import JSONDatabaseSplit
+from .loader import RawDataLoader
 
-logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
-
-
-class DefaultModule(BaseDataModule):
-    def __init__(
-        self,
-        train_batch_size=1,
-        predict_batch_size=1,
-        drop_incomplete_batch=False,
-        multiproc_kwargs=None,
-    ):
-        super().__init__(
-            train_batch_size=train_batch_size,
-            predict_batch_size=predict_batch_size,
-            drop_incomplete_batch=drop_incomplete_batch,
-            multiproc_kwargs=multiproc_kwargs,
+datamodule = CachingDataModule(
+    database_split=JSONDatabaseSplit(
+        importlib.resources.files(__name__.rsplit(".", 1)[0]).joinpath(
+            "fold_8.json.bz2"
         )
-
-    def setup(self, stage: str):
-        self.dataset = _maker("fold_8")
-        (
-            self.train_dataset,
-            self.validation_dataset,
-            self.extra_validation_datasets,
-            self.predict_dataset,
-        ) = return_subsets(self.dataset)
-
-
-datamodule = DefaultModule
+    ),
+    raw_data_loader=RawDataLoader(),
+)
diff --git a/src/ptbench/data/montgomery/fold_8_rgb.py b/src/ptbench/data/montgomery/fold_8_rgb.py
deleted file mode 100644
index b7d59359dcde32694affea0e3df88ad747f48e31..0000000000000000000000000000000000000000
--- a/src/ptbench/data/montgomery/fold_8_rgb.py
+++ /dev/null
@@ -1,47 +0,0 @@
-# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
-#
-# SPDX-License-Identifier: GPL-3.0-or-later
-
-"""Montgomery dataset for TB detection (cross validation fold 8, RGB)
-
-* Split reference: first 64% of TB and healthy CXR for "train" 16% for
-* "validation", 20% for "test"
-* This configuration resolution: 512 x 512 (default)
-* See :py:mod:`ptbench.data.montgomery` for dataset details
-"""
-
-from clapper.logging import setup
-
-from .. import return_subsets
-from ..base_datamodule import BaseDataModule
-from . import _maker
-
-logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
-
-
-class DefaultModule(BaseDataModule):
-    def __init__(
-        self,
-        train_batch_size=1,
-        predict_batch_size=1,
-        drop_incomplete_batch=False,
-        multiproc_kwargs=None,
-    ):
-        super().__init__(
-            train_batch_size=train_batch_size,
-            predict_batch_size=predict_batch_size,
-            drop_incomplete_batch=drop_incomplete_batch,
-            multiproc_kwargs=multiproc_kwargs,
-        )
-
-    def setup(self, stage: str):
-        self.dataset = _maker("fold_8", RGB=True)
-        (
-            self.train_dataset,
-            self.validation_dataset,
-            self.extra_validation_datasets,
-            self.predict_dataset,
-        ) = return_subsets(self.dataset)
-
-
-datamodule = DefaultModule
diff --git a/src/ptbench/data/montgomery/fold_9.py b/src/ptbench/data/montgomery/fold_9.py
index 81bbf72e78826f7e9560189be149d51cb729064e..91228362f8c376d9ac9186f6675d80295e848f13 100644
--- a/src/ptbench/data/montgomery/fold_9.py
+++ b/src/ptbench/data/montgomery/fold_9.py
@@ -2,46 +2,22 @@
 #
 # SPDX-License-Identifier: GPL-3.0-or-later
 
-"""Montgomery dataset for TB detection (cross validation fold 9)
+"""Montgomery datamodule for TB detection (default protocol)
 
-* Split reference: first 64% of TB and healthy CXR for "train" 16% for
-* "validation", 20% for "test"
-* This configuration resolution: 512 x 512 (default)
-* See :py:mod:`ptbench.data.montgomery` for dataset details
+See :py:mod:`ptbench.data.montgomery.default` for input/output details.
 """
 
-from clapper.logging import setup
+import importlib.resources
 
-from .. import return_subsets
-from ..base_datamodule import BaseDataModule
-from . import _maker
+from ..datamodule import CachingDataModule
+from ..split import JSONDatabaseSplit
+from .loader import RawDataLoader
 
-logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
-
-
-class DefaultModule(BaseDataModule):
-    def __init__(
-        self,
-        train_batch_size=1,
-        predict_batch_size=1,
-        drop_incomplete_batch=False,
-        multiproc_kwargs=None,
-    ):
-        super().__init__(
-            train_batch_size=train_batch_size,
-            predict_batch_size=predict_batch_size,
-            drop_incomplete_batch=drop_incomplete_batch,
-            multiproc_kwargs=multiproc_kwargs,
+datamodule = CachingDataModule(
+    database_split=JSONDatabaseSplit(
+        importlib.resources.files(__name__.rsplit(".", 1)[0]).joinpath(
+            "fold_9.json.bz2"
         )
-
-    def setup(self, stage: str):
-        self.dataset = _maker("fold_9")
-        (
-            self.train_dataset,
-            self.validation_dataset,
-            self.extra_validation_datasets,
-            self.predict_dataset,
-        ) = return_subsets(self.dataset)
-
-
-datamodule = DefaultModule
+    ),
+    raw_data_loader=RawDataLoader(),
+)
diff --git a/src/ptbench/data/montgomery/fold_9_rgb.py b/src/ptbench/data/montgomery/fold_9_rgb.py
deleted file mode 100644
index e961e08ffe49a94001252c641ba8bee86758b44f..0000000000000000000000000000000000000000
--- a/src/ptbench/data/montgomery/fold_9_rgb.py
+++ /dev/null
@@ -1,47 +0,0 @@
-# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
-#
-# SPDX-License-Identifier: GPL-3.0-or-later
-
-"""Montgomery dataset for TB detection (cross validation fold 9, RGB)
-
-* Split reference: first 64% of TB and healthy CXR for "train" 16% for
-* "validation", 20% for "test"
-* This configuration resolution: 512 x 512 (default)
-* See :py:mod:`ptbench.data.montgomery` for dataset details
-"""
-
-from clapper.logging import setup
-
-from .. import return_subsets
-from ..base_datamodule import BaseDataModule
-from . import _maker
-
-logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
-
-
-class DefaultModule(BaseDataModule):
-    def __init__(
-        self,
-        train_batch_size=1,
-        predict_batch_size=1,
-        drop_incomplete_batch=False,
-        multiproc_kwargs=None,
-    ):
-        super().__init__(
-            train_batch_size=train_batch_size,
-            predict_batch_size=predict_batch_size,
-            drop_incomplete_batch=drop_incomplete_batch,
-            multiproc_kwargs=multiproc_kwargs,
-        )
-
-    def setup(self, stage: str):
-        self.dataset = _maker("fold_9", RGB=True)
-        (
-            self.train_dataset,
-            self.validation_dataset,
-            self.extra_validation_datasets,
-            self.predict_dataset,
-        ) = return_subsets(self.dataset)
-
-
-datamodule = DefaultModule
diff --git a/src/ptbench/data/montgomery/loader.py b/src/ptbench/data/montgomery/loader.py
new file mode 100644
index 0000000000000000000000000000000000000000..ad856d5fcc45603015cf75c1c87885751f25bcd8
--- /dev/null
+++ b/src/ptbench/data/montgomery/loader.py
@@ -0,0 +1,87 @@
+# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
+#
+# SPDX-License-Identifier: GPL-3.0-or-later
+
+"""Specialized raw-data loaders for the Montgomery dataset."""
+
+import os
+
+import PIL.Image
+
+from torchvision.transforms.functional import center_crop, to_tensor
+
+from ...utils.rc import load_rc
+from ..image_utils import remove_black_borders
+from ..typing import RawDataLoader as _BaseRawDataLoader
+from ..typing import Sample
+
+
+class RawDataLoader(_BaseRawDataLoader):
+    """A specialized raw-data-loader for the Montgomery dataset.
+
+    Attributes
+    ----------
+
+    datadir
+        This variable contains the base directory where the database raw data
+        is stored.
+    """
+
+    datadir: str
+
+    def __init__(self):
+        self.datadir = load_rc().get(
+            "datadir.montgomery", os.path.realpath(os.curdir)
+        )
+
+    def sample(self, sample: tuple[str, int]) -> Sample:
+        """Loads a single image sample from the disk.
+
+        Parameters
+        ----------
+
+        sample
+            A tuple containing the path suffix, within the dataset root
+            folder, where to find the image to be loaded, and an integer
+            representing the sample label.
+
+
+        Returns
+        -------
+
+        sample
+            The sample representation
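+
+        Examples
+        --------
+
+        A minimal, hypothetical sketch (the path suffix must exist under the
+        configured ``datadir.montgomery`` directory):
+
+        .. code-block:: python
+
+           loader = RawDataLoader()
+           tensor, metadata = loader.sample(("CXR_png/MCUCXR_0001_0.png", 0))
+           assert metadata["label"] == 0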
+        """
+        # N.B.: Montgomery images are encoded as grayscale PNGs, so there is
+        # no need to convert them with Image.convert("L").
+        image = PIL.Image.open(os.path.join(self.datadir, sample[0]))
+        image = remove_black_borders(image)
+        tensor = to_tensor(image)
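+        # crop centrally to the shortest edge so the output is square; no
+        # resizing happens here (it is deferred to the model transforms)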
+        tensor = center_crop(tensor, min(*tensor.shape[1:]))
+
+        # use the code below to view generated images
+        # from torchvision.transforms.functional import to_pil_image
+        # to_pil_image(tensor).show()
+        # __import__("pdb").set_trace()
+
+        return tensor, dict(label=sample[1], name=sample[0])  # type: ignore[arg-type]
+
+    def label(self, sample: tuple[str, int]) -> int:
+        """Loads a single image sample label from the disk.
+
+        Parameters
+        ----------
+
+        sample
+            A tuple containing the path suffix (within the dataset root
+            folder) of the image to be loaded, and an integer representing
+            the sample label.
+
+
+        Returns
+        -------
+
+        label
+            The integer label associated with the sample
+        """
+        return sample[1]
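+
+
+# A minimal usage sketch (the file name below is hypothetical; the dataset
+# root is read from the ``datadir.montgomery`` RC variable):
+#
+#   loader = RawDataLoader()
+#   tensor, metadata = loader.sample(("CXR_png/MCUCXR_0001_0.png", 0))
+#   assert metadata == dict(label=0, name="CXR_png/MCUCXR_0001_0.png")
+#   assert loader.label(("CXR_png/MCUCXR_0001_0.png", 0)) == 0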
diff --git a/src/ptbench/data/montgomery/rgb.py b/src/ptbench/data/montgomery/rgb.py
deleted file mode 100644
index c162126648f0baae5a921fa7f009da171fb8ccc7..0000000000000000000000000000000000000000
--- a/src/ptbench/data/montgomery/rgb.py
+++ /dev/null
@@ -1,47 +0,0 @@
-# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
-#
-# SPDX-License-Identifier: GPL-3.0-or-later
-
-"""Montgomery dataset for TB detection (default protocol, converted in RGB)
-
-* Split reference: first 64% of TB and healthy CXR for "train" 16% for
-* "validation", 20% for "test"
-* This configuration resolution: 512 x 512 (default)
-* See :py:mod:`ptbench.data.montgomery` for dataset details
-"""
-
-from clapper.logging import setup
-
-from .. import return_subsets
-from ..base_datamodule import BaseDataModule
-from . import _maker
-
-logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
-
-
-class DefaultModule(BaseDataModule):
-    def __init__(
-        self,
-        train_batch_size=1,
-        predict_batch_size=1,
-        drop_incomplete_batch=False,
-        multiproc_kwargs=None,
-    ):
-        super().__init__(
-            train_batch_size=train_batch_size,
-            predict_batch_size=predict_batch_size,
-            drop_incomplete_batch=drop_incomplete_batch,
-            multiproc_kwargs=multiproc_kwargs,
-        )
-
-    def setup(self, stage: str):
-        self.dataset = _maker("default", RGB=True)
-        (
-            self.train_dataset,
-            self.validation_dataset,
-            self.extra_validation_datasets,
-            self.predict_dataset,
-        ) = return_subsets(self.dataset)
-
-
-datamodule = DefaultModule
diff --git a/src/ptbench/data/shenzhen/default.py b/src/ptbench/data/shenzhen/default.py
index bfe93f44faaa9df235f357c1cc3a927412f4a011..a163b9bc6290f53e611d214bdfa03e0cf93eb492 100644
--- a/src/ptbench/data/shenzhen/default.py
+++ b/src/ptbench/data/shenzhen/default.py
@@ -2,33 +2,42 @@
 #
 # SPDX-License-Identifier: GPL-3.0-or-later
 
-"""Shenzhen datamodule for computer-aided diagnosis (default protocol)
+"""Shenzhen datamodule for computer-aided diagnosis (``default`` protocol)
 
-See :py:mod:`ptbench.data.shenzhen` for more database details.
+The standard digital image database for Tuberculosis was created by the National
+Library of Medicine, Maryland, USA in collaboration with Shenzhen No.3 People’s
+Hospital, Guangdong Medical College, Shenzhen, China. The Chest X-rays are from
+out-patient clinics, and were captured as part of the daily routine using
+Philips DR Digital Diagnose systems.
 
-This configuration:
+* Database reference: [MONTGOMERY-SHENZHEN-2014]_
+* Original resolution (height x width, or width x height): up to 3000 x 3000 pixels
+* This split:
 
-* Raw data input (on disk):
+  * Split reference: None
+  * Training samples: 64% of TB and healthy CXR (including labels)
+  * Validation samples: 16% of TB and healthy CXR (including labels)
+  * Test samples: 20% of TB and healthy CXR (including labels)
+
+Data specifications:
 
-  * PNG images (black and white, encoded as color images)
-  * Variable width and height:
+* Raw data input (on disk):
 
-    * widths: from 1130 to 3001 pixels
-    * heights: from 948 to 3001 pixels
+    * PNG images (grayscale, encoded as RGB images with an "inverted" grayscale mapping)
+    * Variable width and height
 
 * Output image:
 
-  * Transforms:
+    * Transforms:
 
-    * Load raw PNG with :py:mod:`PIL`
-    * Remove black borders
-    * Torch resizing(512px, 512px)
-    * Torch center cropping (512px, 512px)
+        * Load raw PNG with :py:mod:`PIL`
+        * Remove black borders
+        * Torch center cropping to get a square image
 
-  * Final specifications:
+    * Final specifications:
 
-    * Fixed resolution: 512x512 pixels
-    * Color RGB encoding
+        * Grayscale, encoded as a single-plane, 8-bit image
+        * Square, with varying resolution depending on the input image
 """
 
 import importlib.resources
diff --git a/src/ptbench/data/shenzhen/fold_0.py b/src/ptbench/data/shenzhen/fold_0.py
index 888a0e60024480a3aaff65f6e3d819370fd22669..b505974491eea26e1da8931022eb168a42d57a0f 100644
--- a/src/ptbench/data/shenzhen/fold_0.py
+++ b/src/ptbench/data/shenzhen/fold_0.py
@@ -4,31 +4,7 @@
 
 """Shenzhen datamodule for computer-aided diagnosis (fold 0)
 
-See :py:mod:`ptbench.data.shenzhen` for more database details.
-
-This configuration:
-
-* Raw data input (on disk):
-
-  * PNG images (black and white, encoded as color images)
-  * Variable width and height:
-
-    * widths: from 1130 to 3001 pixels
-    * heights: from 948 to 3001 pixels
-
-* Output image:
-
-  * Transforms:
-
-    * Load raw PNG with :py:mod:`PIL`
-    * Remove black borders
-    * Torch resizing(512px, 512px)
-    * Torch center cropping (512px, 512px)
-
-  * Final specifications:
-
-    * Fixed resolution: 512x512 pixels
-    * Color RGB encoding
+See :py:mod:`ptbench.data.shenzhen.default` for input/output details.
 """
 
 import importlib.resources
diff --git a/src/ptbench/data/shenzhen/fold_1.py b/src/ptbench/data/shenzhen/fold_1.py
index 62d7fbd55c83ed746754cbc99dcc65fe48efbc6a..1041c3e4ef6d14942dadd4c680dc10fee0cfd17c 100644
--- a/src/ptbench/data/shenzhen/fold_1.py
+++ b/src/ptbench/data/shenzhen/fold_1.py
@@ -4,31 +4,7 @@
 
 """Shenzhen datamodule for computer-aided diagnosis (fold 1)
 
-See :py:mod:`ptbench.data.shenzhen` for more database details.
-
-This configuration:
-
-* Raw data input (on disk):
-
-  * PNG images (black and white, encoded as color images)
-  * Variable width and height:
-
-    * widths: from 1130 to 3001 pixels
-    * heights: from 948 to 3001 pixels
-
-* Output image:
-
-  * Transforms:
-
-    * Load raw PNG with :py:mod:`PIL`
-    * Remove black borders
-    * Torch resizing(512px, 512px)
-    * Torch center cropping (512px, 512px)
-
-  * Final specifications:
-
-    * Fixed resolution: 512x512 pixels
-    * Color RGB encoding
+See :py:mod:`ptbench.data.shenzhen.default` for input/output details.
 """
 
 import importlib.resources
diff --git a/src/ptbench/data/shenzhen/fold_2.py b/src/ptbench/data/shenzhen/fold_2.py
index b41284cd9d1c4a56c70eff715078f82213dabb3c..5026116a9cd75ac406f334682b38ce760104444d 100644
--- a/src/ptbench/data/shenzhen/fold_2.py
+++ b/src/ptbench/data/shenzhen/fold_2.py
@@ -4,31 +4,7 @@
 
 """Shenzhen datamodule for computer-aided diagnosis (fold 2)
 
-See :py:mod:`ptbench.data.shenzhen` for more database details.
-
-This configuration:
-
-* Raw data input (on disk):
-
-  * PNG images (black and white, encoded as color images)
-  * Variable width and height:
-
-    * widths: from 1130 to 3001 pixels
-    * heights: from 948 to 3001 pixels
-
-* Output image:
-
-  * Transforms:
-
-    * Load raw PNG with :py:mod:`PIL`
-    * Remove black borders
-    * Torch resizing(512px, 512px)
-    * Torch center cropping (512px, 512px)
-
-  * Final specifications:
-
-    * Fixed resolution: 512x512 pixels
-    * Color RGB encoding
+See :py:mod:`ptbench.data.shenzhen.default` for input/output details.
 """
 
 import importlib.resources
diff --git a/src/ptbench/data/shenzhen/fold_3.py b/src/ptbench/data/shenzhen/fold_3.py
index cca555064e9923433ef39f591b3e342365cf7afc..16c00157c5fa9fda38afc16614b75f2e766c33d5 100644
--- a/src/ptbench/data/shenzhen/fold_3.py
+++ b/src/ptbench/data/shenzhen/fold_3.py
@@ -4,31 +4,7 @@
 
 """Shenzhen datamodule for computer-aided diagnosis (fold 3)
 
-See :py:mod:`ptbench.data.shenzhen` for more database details.
-
-This configuration:
-
-* Raw data input (on disk):
-
-  * PNG images (black and white, encoded as color images)
-  * Variable width and height:
-
-    * widths: from 1130 to 3001 pixels
-    * heights: from 948 to 3001 pixels
-
-* Output image:
-
-  * Transforms:
-
-    * Load raw PNG with :py:mod:`PIL`
-    * Remove black borders
-    * Torch resizing(512px, 512px)
-    * Torch center cropping (512px, 512px)
-
-  * Final specifications:
-
-    * Fixed resolution: 512x512 pixels
-    * Color RGB encoding
+See :py:mod:`ptbench.data.shenzhen.default` for input/output details.
 """
 
 import importlib.resources
diff --git a/src/ptbench/data/shenzhen/fold_4.py b/src/ptbench/data/shenzhen/fold_4.py
index 897420076303e47406cc9efb3b6bf0d294ab3611..c0b0fdacdf90fdce168988057219923af73ad6a0 100644
--- a/src/ptbench/data/shenzhen/fold_4.py
+++ b/src/ptbench/data/shenzhen/fold_4.py
@@ -4,31 +4,7 @@
 
 """Shenzhen datamodule for computer-aided diagnosis (fold 4)
 
-See :py:mod:`ptbench.data.shenzhen` for more database details.
-
-This configuration:
-
-* Raw data input (on disk):
-
-  * PNG images (black and white, encoded as color images)
-  * Variable width and height:
-
-    * widths: from 1130 to 3001 pixels
-    * heights: from 948 to 3001 pixels
-
-* Output image:
-
-  * Transforms:
-
-    * Load raw PNG with :py:mod:`PIL`
-    * Remove black borders
-    * Torch resizing(512px, 512px)
-    * Torch center cropping (512px, 512px)
-
-  * Final specifications:
-
-    * Fixed resolution: 512x512 pixels
-    * Color RGB encoding
+See :py:mod:`ptbench.data.shenzhen.default` for input/output details.
 """
 
 import importlib.resources
diff --git a/src/ptbench/data/shenzhen/fold_5.py b/src/ptbench/data/shenzhen/fold_5.py
index c520399d98ead9eeb1e3bdcfbe4dc48393adcebc..0397955e25d1077af68b825b5ecbf0d8974499db 100644
--- a/src/ptbench/data/shenzhen/fold_5.py
+++ b/src/ptbench/data/shenzhen/fold_5.py
@@ -4,31 +4,7 @@
 
 """Shenzhen datamodule for computer-aided diagnosis (fold 5)
 
-See :py:mod:`ptbench.data.shenzhen` for more database details.
-
-This configuration:
-
-* Raw data input (on disk):
-
-  * PNG images (black and white, encoded as color images)
-  * Variable width and height:
-
-    * widths: from 1130 to 3001 pixels
-    * heights: from 948 to 3001 pixels
-
-* Output image:
-
-  * Transforms:
-
-    * Load raw PNG with :py:mod:`PIL`
-    * Remove black borders
-    * Torch resizing(512px, 512px)
-    * Torch center cropping (512px, 512px)
-
-  * Final specifications:
-
-    * Fixed resolution: 512x512 pixels
-    * Color RGB encoding
+See :py:mod:`ptbench.data.shenzhen.default` for input/output details.
 """
 
 import importlib.resources
diff --git a/src/ptbench/data/shenzhen/fold_6.py b/src/ptbench/data/shenzhen/fold_6.py
index a28f8fc5ca3e0ebd4b49fceaec99d3a2e94dd34c..145685ea96be63501a8afd771518b4b2f3f65c49 100644
--- a/src/ptbench/data/shenzhen/fold_6.py
+++ b/src/ptbench/data/shenzhen/fold_6.py
@@ -4,31 +4,7 @@
 
 """Shenzhen datamodule for computer-aided diagnosis (fold 6)
 
-See :py:mod:`ptbench.data.shenzhen` for more database details.
-
-This configuration:
-
-* Raw data input (on disk):
-
-  * PNG images (black and white, encoded as color images)
-  * Variable width and height:
-
-    * widths: from 1130 to 3001 pixels
-    * heights: from 948 to 3001 pixels
-
-* Output image:
-
-  * Transforms:
-
-    * Load raw PNG with :py:mod:`PIL`
-    * Remove black borders
-    * Torch resizing(512px, 512px)
-    * Torch center cropping (512px, 512px)
-
-  * Final specifications:
-
-    * Fixed resolution: 512x512 pixels
-    * Color RGB encoding
+See :py:mod:`ptbench.data.shenzhen.default` for input/output details.
 """
 
 import importlib.resources
diff --git a/src/ptbench/data/shenzhen/fold_7.py b/src/ptbench/data/shenzhen/fold_7.py
index b0ea7b4324334980a2e55e4496ac4ab6af705d17..5b8d74034a18e2637a9a193557571521722e93bc 100644
--- a/src/ptbench/data/shenzhen/fold_7.py
+++ b/src/ptbench/data/shenzhen/fold_7.py
@@ -4,31 +4,7 @@
 
 """Shenzhen datamodule for computer-aided diagnosis (fold 7)
 
-See :py:mod:`ptbench.data.shenzhen` for more database details.
-
-This configuration:
-
-* Raw data input (on disk):
-
-  * PNG images (black and white, encoded as color images)
-  * Variable width and height:
-
-    * widths: from 1130 to 3001 pixels
-    * heights: from 948 to 3001 pixels
-
-* Output image:
-
-  * Transforms:
-
-    * Load raw PNG with :py:mod:`PIL`
-    * Remove black borders
-    * Torch resizing(512px, 512px)
-    * Torch center cropping (512px, 512px)
-
-  * Final specifications:
-
-    * Fixed resolution: 512x512 pixels
-    * Color RGB encoding
+See :py:mod:`ptbench.data.shenzhen.default` for input/output details.
 """
 
 import importlib.resources
diff --git a/src/ptbench/data/shenzhen/fold_8.py b/src/ptbench/data/shenzhen/fold_8.py
index 9bbfbe84ab942cf5da5a8c5fc8318724908998f9..e9ce1a2f408543bc93d8f116a2f8834ab79c989f 100644
--- a/src/ptbench/data/shenzhen/fold_8.py
+++ b/src/ptbench/data/shenzhen/fold_8.py
@@ -4,31 +4,7 @@
 
 """Shenzhen datamodule for computer-aided diagnosis (fold 8)
 
-See :py:mod:`ptbench.data.shenzhen` for more database details.
-
-This configuration:
-
-* Raw data input (on disk):
-
-  * PNG images (black and white, encoded as color images)
-  * Variable width and height:
-
-    * widths: from 1130 to 3001 pixels
-    * heights: from 948 to 3001 pixels
-
-* Output image:
-
-  * Transforms:
-
-    * Load raw PNG with :py:mod:`PIL`
-    * Remove black borders
-    * Torch resizing(512px, 512px)
-    * Torch center cropping (512px, 512px)
-
-  * Final specifications:
-
-    * Fixed resolution: 512x512 pixels
-    * Color RGB encoding
+See :py:mod:`ptbench.data.shenzhen.default` for input/output details.
 """
 
 import importlib.resources
diff --git a/src/ptbench/data/shenzhen/fold_9.py b/src/ptbench/data/shenzhen/fold_9.py
index 87c2afb328f9b09f420a1ddce5f5d0ea54346c43..6da8dd3d7a4260e7b9a478baea4b2848383f8459 100644
--- a/src/ptbench/data/shenzhen/fold_9.py
+++ b/src/ptbench/data/shenzhen/fold_9.py
@@ -4,31 +4,7 @@
 
 """Shenzhen datamodule for computer-aided diagnosis (fold 9)
 
-See :py:mod:`ptbench.data.shenzhen` for more database details.
-
-This configuration:
-
-* Raw data input (on disk):
-
-  * PNG images (black and white, encoded as color images)
-  * Variable width and height:
-
-    * widths: from 1130 to 3001 pixels
-    * heights: from 948 to 3001 pixels
-
-* Output image:
-
-  * Transforms:
-
-    * Load raw PNG with :py:mod:`PIL`
-    * Remove black borders
-    * Torch resizing(512px, 512px)
-    * Torch center cropping (512px, 512px)
-
-  * Final specifications:
-
-    * Fixed resolution: 512x512 pixels
-    * Color RGB encoding
+See :py:mod:`ptbench.data.shenzhen.default` for input/output details.
 """
 
 import importlib.resources
diff --git a/src/ptbench/data/shenzhen/loader.py b/src/ptbench/data/shenzhen/loader.py
index 49ccf8bfb217e411004228b4acf7c924e3ffec66..3409fed2e1a552c44135888df6d6bc4a874b427c 100644
--- a/src/ptbench/data/shenzhen/loader.py
+++ b/src/ptbench/data/shenzhen/loader.py
@@ -2,30 +2,16 @@
 #
 # SPDX-License-Identifier: GPL-3.0-or-later
 
-"""Shenzhen dataset for computer-aided diagnosis.
-
-The standard digital image database for Tuberculosis is created by the National
-Library of Medicine, Maryland, USA in collaboration with Shenzhen No.3 People’s
-Hospital, Guangdong Medical College, Shenzhen, China. The Chest X-rays are from
-out-patient clinics, and were captured as part of the daily routine using
-Philips DR Digital Diagnose systems.
-
-* Reference: [MONTGOMERY-SHENZHEN-2014]_
-* Original resolution (height x width or width x height): 3000 x 3000 or less
-* Split reference: none
-* Protocol ``default``:
-
-  * Training samples: 64% of TB and healthy CXR (including labels)
-  * Validation samples: 16% of TB and healthy CXR (including labels)
-  * Test samples: 20% of TB and healthy CXR (including labels)
-"""
+"""Specialized raw-data loaders for the Shenzen dataset."""
 
 import os
 
-import torchvision.transforms
+import PIL.Image
+
+from torchvision.transforms.functional import center_crop, to_tensor
 
 from ...utils.rc import load_rc
-from ..image_utils import RemoveBlackBorders, load_pil_baw
+from ..image_utils import remove_black_borders
 from ..typing import RawDataLoader as _BaseRawDataLoader
 from ..typing import Sample
 
@@ -45,22 +31,12 @@ class RawDataLoader(_BaseRawDataLoader):
     """
 
     datadir: str
-    transform: torchvision.transforms.Compose
 
     def __init__(self):
         self.datadir = load_rc().get(
             "datadir.shenzhen", os.path.realpath(os.curdir)
         )
 
-        self.transform = torchvision.transforms.Compose(
-            [
-                RemoveBlackBorders(),
-                torchvision.transforms.Resize(512),
-                torchvision.transforms.CenterCrop(512),
-                torchvision.transforms.ToTensor(),
-            ]
-        )
-
     def sample(self, sample: tuple[str, int]) -> Sample:
         """Loads a single image sample from the disk.
 
@@ -79,9 +55,19 @@ class RawDataLoader(_BaseRawDataLoader):
         sample
             The sample representation
         """
-        tensor = self.transform(
-            load_pil_baw(os.path.join(self.datadir, sample[0]))
+        # N.B.: Image.convert("L") is required to normalize the "inverted"
+        # grayscale encoding back to the standard one.
+        image = PIL.Image.open(os.path.join(self.datadir, sample[0])).convert(
+            "L"
         )
+        image = remove_black_borders(image)
+        tensor = to_tensor(image)
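+        # crop centrally to the shortest edge, yielding a square image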
+        tensor = center_crop(tensor, min(*tensor.shape[1:]))
+
+        # use the code below to view generated images
+        # from torchvision.transforms.functional import to_pil_image
+        # to_pil_image(tensor).show()
+        # __import__("pdb").set_trace()
 
         return tensor, dict(label=sample[1], name=sample[0])  # type: ignore[arg-type]
 
diff --git a/src/ptbench/models/pasa.py b/src/ptbench/models/pasa.py
index e2cb9b053a57c6c3a786c19c4cb47a99e8d5cddb..650aa7d4f60dfe5feda914deb8518415fda8876d 100644
--- a/src/ptbench/models/pasa.py
+++ b/src/ptbench/models/pasa.py
@@ -73,7 +73,8 @@ class Pasa(pl.LightningModule):
         self.name = "pasa"
 
         self.model_transforms = [
-            torchvision.transforms.Resize(512),
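+            # force single-channel inputs regardless of the raw-data encoding;
+            # antialias=True makes tensor resizing match PIL's behavior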
+            torchvision.transforms.Grayscale(),
+            torchvision.transforms.Resize(512, antialias=True),
         ]
 
         self._train_loss = train_loss
diff --git a/tests/test_ch.py b/tests/test_ch.py
index 659e2c35ae092f3a90f7d072ba033786bb80bdf9..b28c81e93ef0765c59dd7252b93ecdecaeff6946 100644
--- a/tests/test_ch.py
+++ b/tests/test_ch.py
@@ -120,11 +120,6 @@ def test_loading():
 
     from ptbench.data.datamodule import _DelayedLoadingDataset
 
-    def _check_size(shape):
-        if shape[0] == 1 and shape[1] == 512 and shape[2] == 512:
-            return True
-        return False
-
     def _check_sample(s):
         assert len(s) == 2
 
@@ -132,10 +127,12 @@ def test_loading():
         metadata = s[1]
 
         assert isinstance(data, torch.Tensor)
-        assert _check_size(data.shape)  # Check size
+
+        assert data.size(0) == 3  # check 3 channels
+        assert data.size(1) == data.size(2)  # check square image
 
         assert (
-            torchvision.transforms.ToPILImage()(data).mode == "L"
+            torchvision.transforms.ToPILImage()(data).mode == "RGB"
         )  # Check colors
 
         assert "label" in metadata
diff --git a/tests/test_mc.py b/tests/test_mc.py
index 1b2aa4fd5a0317b939816bf625907c534ead7910..2fcd14ac131f5d919e954bd4d9562effd4a19296 100644
--- a/tests/test_mc.py
+++ b/tests/test_mc.py
@@ -4,131 +4,188 @@
 
 """Tests for Montgomery dataset."""
 
+import importlib
+
 import pytest
 
 
 def test_protocol_consistency():
-    from ptbench.data.montgomery import dataset
 
     # Default protocol
-    subset = dataset.subsets("default")
+    datamodule = importlib.import_module(
+        "ptbench.data.montgomery.default"
+    ).datamodule
+    subset = datamodule.database_split.subsets
+
     assert len(subset) == 3
 
     assert "train" in subset
     assert len(subset["train"]) == 88
     for s in subset["train"]:
-        assert s.key.startswith("CXR_png/MCUCXR_0")
+        assert s[0].startswith("CXR_png/MCUCXR_0")
 
     assert "validation" in subset
     assert len(subset["validation"]) == 22
     for s in subset["validation"]:
-        assert s.key.startswith("CXR_png/MCUCXR_0")
+        assert s[0].startswith("CXR_png/MCUCXR_0")
 
     assert "test" in subset
     assert len(subset["test"]) == 28
     for s in subset["test"]:
-        assert s.key.startswith("CXR_png/MCUCXR_0")
+        assert s[0].startswith("CXR_png/MCUCXR_0")
 
     # Check labels
     for s in subset["train"]:
-        assert s.label in [0.0, 1.0]
+        assert s[1] in [0.0, 1.0]
 
     for s in subset["validation"]:
-        assert s.label in [0.0, 1.0]
+        assert s[1] in [0.0, 1.0]
 
     for s in subset["test"]:
-        assert s.label in [0.0, 1.0]
+        assert s[1] in [0.0, 1.0]
 
     # Cross-validation fold 0-7
     for f in range(8):
-        subset = dataset.subsets("fold_" + str(f))
+        datamodule = importlib.import_module(
+            f"ptbench.data.montgomery.fold_{str(f)}"
+        ).datamodule
+        subset = datamodule.database_split.subsets
+
         assert len(subset) == 3
 
         assert "train" in subset
         assert len(subset["train"]) == 99
         for s in subset["train"]:
-            assert s.key.startswith("CXR_png/MCUCXR_0")
+            assert s[0].startswith("CXR_png/MCUCXR_0")
 
         assert "validation" in subset
         assert len(subset["validation"]) == 25
         for s in subset["validation"]:
-            assert s.key.startswith("CXR_png/MCUCXR_0")
+            assert s[0].startswith("CXR_png/MCUCXR_0")
 
         assert "test" in subset
         assert len(subset["test"]) == 14
         for s in subset["test"]:
-            assert s.key.startswith("CXR_png/MCUCXR_0")
+            assert s[0].startswith("CXR_png/MCUCXR_0")
 
         # Check labels
         for s in subset["train"]:
-            assert s.label in [0.0, 1.0]
+            assert s[1] in [0.0, 1.0]
 
         for s in subset["validation"]:
-            assert s.label in [0.0, 1.0]
+            assert s[1] in [0.0, 1.0]
 
         for s in subset["test"]:
-            assert s.label in [0.0, 1.0]
+            assert s[1] in [0.0, 1.0]
 
     # Cross-validation fold 8-9
     for f in range(8, 10):
-        subset = dataset.subsets("fold_" + str(f))
+        datamodule = importlib.import_module(
+            f"ptbench.data.montgomery.fold_{str(f)}"
+        ).datamodule
+        subset = datamodule.database_split.subsets
+
         assert len(subset) == 3
 
         assert "train" in subset
         assert len(subset["train"]) == 100
         for s in subset["train"]:
-            assert s.key.startswith("CXR_png/MCUCXR_0")
+            assert s[0].startswith("CXR_png/MCUCXR_0")
 
         assert "validation" in subset
         assert len(subset["validation"]) == 25
         for s in subset["validation"]:
-            assert s.key.startswith("CXR_png/MCUCXR_0")
+            assert s[0].startswith("CXR_png/MCUCXR_0")
 
         assert "test" in subset
         assert len(subset["test"]) == 13
         for s in subset["test"]:
-            assert s.key.startswith("CXR_png/MCUCXR_0")
+            assert s[0].startswith("CXR_png/MCUCXR_0")
 
         # Check labels
         for s in subset["train"]:
-            assert s.label in [0.0, 1.0]
+            assert s[1] in [0.0, 1.0]
 
         for s in subset["validation"]:
-            assert s.label in [0.0, 1.0]
+            assert s[1] in [0.0, 1.0]
 
         for s in subset["test"]:
-            assert s.label in [0.0, 1.0]
+            assert s[1] in [0.0, 1.0]
 
 
 @pytest.mark.skip_if_rc_var_not_set("datadir.montgomery")
 def test_loading():
-    from ptbench.data.montgomery import dataset
+    import torch
+    import torchvision.transforms
+
+    from ptbench.data.datamodule import _DelayedLoadingDataset
 
     def _check_sample(s):
-        data = s.data
-        assert isinstance(data, dict)
-        assert len(data) == 2
-
-        assert "data" in data
-        assert data["data"].size in (
-            (4020, 4892),  # portrait
-            (4892, 4020),  # landscape
-            (512, 512),  # test database @ CI
-        )
-        assert data["data"].mode == "L"  # Check colors
+        data = s[0]
+        metadata = s[1]
+
+        assert isinstance(data, torch.Tensor)
+
+        assert data.size(0) == 1  # check single channel
+        assert data.size(1) == data.size(2)  # check square image
 
-        assert "label" in data
-        assert data["label"] in [0, 1]  # Check labels
+        assert (
+            torchvision.transforms.ToPILImage()(data).mode == "L"
+        )  # Check colors
+
+        assert "label" in metadata
+        assert metadata["label"] in [0, 1]  # Check labels
 
     limit = 30  # use this to limit testing to first images only, else None
 
-    subset = dataset.subsets("default")
-    for s in subset["train"][:limit]:
+    datamodule = importlib.import_module(
+        "ptbench.data.montgomery.default"
+    ).datamodule
+    subset = datamodule.database_split.subsets
+    raw_data_loader = datamodule.raw_data_loader
+
+    # we use the private dataset class directly so we can limit the number of
+    # samples to load
+    dataset = _DelayedLoadingDataset(subset["train"][:limit], raw_data_loader)
+
+    for s in dataset:
         _check_sample(s)
 
 
 @pytest.mark.skip_if_rc_var_not_set("datadir.montgomery")
 def test_check():
-    from ptbench.data.montgomery import dataset
+    from ptbench.data.split import check_database_split_loading
+
+    limit = 30  # use this to limit testing to first images only, else 0
+
+    # Default protocol
+    datamodule = importlib.import_module(
+        "ptbench.data.montgomery.default"
+    ).datamodule
+    database_split = datamodule.database_split
+    raw_data_loader = datamodule.raw_data_loader
+
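+    # check_database_split_loading tries loading up to ``limit`` samples from
+    # each subset of the split and returns the number of loading errors; the
+    # assertions below therefore expect a clean (zero-error) run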
+    assert (
+        check_database_split_loading(
+            database_split, raw_data_loader, limit=limit
+        )
+        == 0
+    )
+
+    # Folds
+    for f in range(10):
+        datamodule = importlib.import_module(
+            f"ptbench.data.montgomery.fold_{f}"
+        ).datamodule
+        database_split = datamodule.database_split
+        raw_data_loader = datamodule.raw_data_loader
+
+        assert (
+            check_database_split_loading(
+                database_split, raw_data_loader, limit=limit
+            )
+            == 0
+        )
 
-    assert dataset.check() == 0