diff --git a/src/ptbench/data/datamodule.py b/src/ptbench/data/datamodule.py
index c8c33dec3e0f417c838da720cc630bc04a59bd96..4c92476360b9895016560976706763f1098c74d3 100644
--- a/src/ptbench/data/datamodule.py
+++ b/src/ptbench/data/datamodule.py
@@ -3,6 +3,7 @@
 # SPDX-License-Identifier: GPL-3.0-or-later
 
 import collections
+import functools
 import logging
 import multiprocessing
 import sys
@@ -27,6 +28,39 @@ from .typing import (
 logger = logging.getLogger(__name__)
 
 
+def _sample_size_bytes(s: Sample) -> int:
+    """Recurse into the sample and figures out its total occupance in bytes.
+
+    Parameters
+    ----------
+
+    s
+        The sample to be analyzed
+
+
+    Returns
+    -------
+
+    size
+        The size in bytes occupied by this sample
+    """
+
+    def _tensor_size_bytes(t: torch.Tensor) -> int:
+        """Returns a tensor size in bytes."""
+        return int(t.element_size() * torch.prod(torch.tensor(t.shape)))
+
+    size = int(s[0].element_size() * torch.prod(torch.tensor(s[0].shape)))
+    size += sys.getsizeof(s[1])
+
+    # check each element - if it is a tensor, then adds its total space in
+    # bytes
+    for v in s[1].values():
+        if isinstance(v, torch.Tensor):
+            size += _tensor_size_bytes(v)
+
+    return size
+
+
 class _DelayedLoadingDataset(Dataset):
     """A list that loads its samples on demand.
 
@@ -59,6 +93,15 @@ class _DelayedLoadingDataset(Dataset):
         self.loader = loader
         self.transform = torchvision.transforms.Compose(transforms)
 
+        # Tests loading and output tensor size
+        first_sample = self[0]
+        logger.info(
+            f"Delayed loading dataset (first tensor): "
+            f"{list(first_sample[0].shape)}@{first_sample[0].dtype}"
+        )
+        sample_size_mb = _sample_size_bytes(first_sample) / (1024.0 * 1024.0)
+        logger.info(f"Estimated sample size: {sample_size_mb:.1f} Mb")
+
     def labels(self) -> list[int]:
         """Returns the integer labels for all samples in the dataset."""
         return [self.loader.label(k) for k in self.split]
@@ -75,6 +118,39 @@ class _DelayedLoadingDataset(Dataset):
             yield self[x]
 
 
+def _apply_loader_and_transforms(
+    info: typing.Any,
+    load: typing.Callable[[typing.Any], Sample],
+    model_transform: typing.Callable[[torch.Tensor], torch.Tensor],
+) -> Sample:
+    """Local wrapper to apply raw-data loading and transformation in a single
+    step.
+
+    Parameters
+    ----------
+
+    info
+        The sample information, as loaded from its split dictionary
+
+    load
+        The raw-data loader function to use for loading the sample
+
+    model_transform
+        A callable that will transform the loaded tensor into something
+        suitable for the model it will train.  Typically, this will be a
+        composed transform.
+
+
+    Returns
+    -------
+
+    sample
+        The loaded and transformed sample.
+    """
+    sample = load(info)
+    return model_transform(sample[0]), sample[1]
+
+
 class _CachedDataset(Dataset):
     """Basically, a list of preloaded samples.
 
@@ -112,27 +188,41 @@ class _CachedDataset(Dataset):
         parallel: int = -1,
         transforms: typing.Sequence[Transform] = [],
     ):
-        self.transform = torchvision.transforms.Compose(transforms)
+        self.loader = functools.partial(
+            _apply_loader_and_transforms,
+            load=loader.sample,
+            model_transform=torchvision.transforms.Compose(transforms),
+        )
 
         if parallel < 0:
             self.data = [
-                loader.sample(k) for k in tqdm.tqdm(split, unit="sample")
+                self.loader(k) for k in tqdm.tqdm(split, unit="sample")
             ]
         else:
             instances = parallel or multiprocessing.cpu_count()
             logger.info(f"Caching dataset using {instances} processes...")
             with multiprocessing.Pool(instances) as p:
                 self.data = list(
-                    tqdm.tqdm(p.imap(loader.sample, split), total=len(split))
+                    tqdm.tqdm(p.imap(self.loader, split), total=len(split))
                 )
 
+        # Estimates memory occupance
+        logger.info(
+            f"Cached dataset (first tensor): "
+            f"{list(self.data[0][0].shape)}@{self.data[0][0].dtype}"
+        )
+        sample_size_mb = _sample_size_bytes(self.data[0]) / (1024.0 * 1024.0)
+        logger.info(
+            f"Estimated RAM occupance (sample / dataset): "
+            f"{sample_size_mb:.1f} / {(len(self.data)*sample_size_mb):.1f} Mb"
+        )
+
     def labels(self) -> list[int]:
         """Returns the integer labels for all samples in the dataset."""
         return [k[1]["label"] for k in self.data]
 
     def __getitem__(self, key: int) -> Sample:
-        tensor, metadata = self.data[key]
-        return self.transform(tensor), metadata
+        return self.data[key]
 
     def __len__(self):
         return len(self.data)