diff --git a/src/mednet/libs/common/data/datamodule.py b/src/mednet/libs/common/data/datamodule.py
index c2e9bea88026b513fa9ef741cf6313f71622fba1..c8ffcffed7ec2c5fc7265d6a509f355e5e25f37c 100644
--- a/src/mednet/libs/common/data/datamodule.py
+++ b/src/mednet/libs/common/data/datamodule.py
@@ -30,62 +30,13 @@ from .typing import (
 logger = logging.getLogger(__name__)
 
 
-def _tensor_size_bytes(t: torch.Tensor) -> int:
-    """Return a tensor size in bytes.
+def _sample_size_bytes(dataset: Dataset) -> None:
+    """Recurse into the first sample of a dataset and estimate its total memory footprint in bytes.
 
     Parameters
     ----------
-    t
-        A torch Tensor.
-
-    Returns
-    -------
-    int
-        The size of the Tensor in bytes.
-    """
-
-    return int(t.element_size() * torch.prod(torch.tensor(t.shape)))
-
-
-def _sample_size_bytes(s: Sample) -> int:
-    """Recurse into the sample and figures out its total occupance in bytes.
-
-    Parameters
-    ----------
-    s
-        The sample to be analyzed.
-
-    Returns
-    -------
-    int
-        The size in bytes occupied by this sample.
-    """
-
-    size = sys.getsizeof(s[0])  # tensor metadata
-    size += int(s[0].element_size() * torch.prod(torch.tensor(s[0].shape)))
-    size += sys.getsizeof(s[1])
-
-    # check each element - if it is a tensor, then adds its total space in
-    # bytes
-    for v in s[1].values():
-        if isinstance(v, torch.Tensor):
-            size += _tensor_size_bytes(v)
-
-    return size
-
-
-def _sample_dict_size_bytes(s: Sample) -> int:
-    """Recurse into the sample and figures out its total occupance in bytes.
-
-    Parameters
-    ----------
-    s
-        The sample to be analyzed.
-
-    Returns
-    -------
-    int
-        The size in bytes occupied by this sample.
+    dataset
+        The dataset whose first sample will be loaded and measured.
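+
+    Notes
+    -----
+    The estimated size is only logged at INFO level; nothing is returned.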
     """
 
     def _tensor_size_bytes(t: torch.Tensor) -> int:
@@ -102,42 +53,43 @@ def _sample_dict_size_bytes(s: Sample) -> int:
             The size of the Tensor in bytes.
         """
 
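+        # log the shape and dtype of every tensor as it is measured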
+        logger.info(f"{list(t.shape)}@{t.dtype}")
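+        # payload size = bytes per element times the number of elements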
         return int(t.element_size() * torch.prod(torch.tensor(t.shape)))
 
-    size = sys.getsizeof(s[0])  # tensor metadata
-    size += sys.getsizeof(s[1])
+    def _dict_size_bytes(d: dict) -> int:
+        """Return the total size in bytes of the tensors in a dictionary.
 
-    # check each element - if it is a tensor, then adds its total space in
-    # bytes
-    for s_ in s:
-        for v in s_.values():
+        Parameters
+        ----------
+        d
+            A dictionary.
+
+        Returns
+        -------
+        int
+            The combined size in bytes of all tensor values in the dictionary.
+        """
+
+        size = 0
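+        # only torch.Tensor values are counted; other metadata values are ignored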
+        for v in d.values():
             if isinstance(v, torch.Tensor):
                 size += _tensor_size_bytes(v)
 
-    return size
+        return size
 
+    size = 0
 
-def _estimate_data_footprint(dataset):
-    """Compute the estimated memory required to load samples in memory.
-
-    Parameters
-    ----------
-    dataset
-        The dataset containing the samples to load.
-    """
     first_sample = dataset[0]
 
-    logger.info("Delayed loading dataset (first tensor):")
-    if isinstance(first_sample[0], dict):
-        for k, v in first_sample[0].items():
-            logger.info(f"{k}: {list(v.shape)}@{v.dtype}")
-        sample_size_mb = _sample_dict_size_bytes(first_sample) / (1024.0 * 1024.0)
-        logger.info(f"Estimated sample size: {sample_size_mb:.1f} Mb")
+    for s in first_sample:
+        size += sys.getsizeof(s)
+        if isinstance(s, dict):
+            size += _dict_size_bytes(s)
+        else:
+            size += _tensor_size_bytes(s)
 
-    else:
-        logger.info(f"{list(first_sample[0].shape)}@{first_sample[0].dtype}")
-        sample_size_mb = _sample_size_bytes(first_sample) / (1024.0 * 1024.0)
-        logger.info(f"Estimated sample size: {sample_size_mb:.1f} Mb")
+    sample_size_mb = size / (1024.0 * 1024.0)
+    logger.info(f"Estimated sample size: {sample_size_mb:.1f} Mb")
 
 
 def transform_tensors(data, transforms):
@@ -197,7 +149,7 @@ class _DelayedLoadingDataset(Dataset):
         self.loader = loader
         self.transform = torchvision.transforms.Compose(transforms)
 
-        _estimate_data_footprint(self)
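+        # log an estimate of the memory footprint of a single sample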
+        _sample_size_bytes(self)
 
     def __getitem__(self, key: int) -> Sample:
         tensor, metadata = self.loader.sample(self.raw_dataset[key])
@@ -291,7 +243,7 @@ class _CachedDataset(Dataset):
                     ),
                 )
 
-        _estimate_data_footprint(self)
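+        # log an estimate of the memory footprint of a single sample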
+        _sample_size_bytes(self)
 
     def targets(self) -> list[int | list[int]]:
         """Return the integer targets for all samples in the dataset.