diff --git a/src/mednet/libs/common/data/datamodule.py b/src/mednet/libs/common/data/datamodule.py
index f0ed67b54431a9be5675b8aa3493752eb1dd1a6c..c2e9bea88026b513fa9ef741cf6313f71622fba1 100644
--- a/src/mednet/libs/common/data/datamodule.py
+++ b/src/mednet/libs/common/data/datamodule.py
@@ -15,7 +15,6 @@ import torch.backends
 import torch.utils.data
 import torchvision.transforms
 import tqdm
-from torchvision import tv_tensors
 
 from .typing import (
     ConcatDatabaseSplit,
@@ -31,6 +30,23 @@ from .typing import (
 logger = logging.getLogger(__name__)
 
 
+def _tensor_size_bytes(t: torch.Tensor) -> int:
+    """Return a tensor size in bytes.
+
+    Parameters
+    ----------
+    t
+        A torch Tensor.
+
+    Returns
+    -------
+    int
+        The size of the Tensor in bytes.
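+
+    Examples
+    --------
+    A quick sanity check (each ``float32`` element occupies 4 bytes, so a
+    2x3 tensor occupies 24 bytes):
+
+    >>> _tensor_size_bytes(torch.zeros(2, 3, dtype=torch.float32))
+    24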
+    """
+
+    return int(t.element_size() * torch.prod(torch.tensor(t.shape)))
+
+
 def _sample_size_bytes(s: Sample) -> int:
     """Recurse into the sample and figures out its total occupance in bytes.
 
@@ -45,6 +61,33 @@ def _sample_size_bytes(s: Sample) -> int:
         The size in bytes occupied by this sample.
     """
 
+    size = sys.getsizeof(s[0])  # tensor metadata
+    size += _tensor_size_bytes(s[0])  # tensor storage
+    size += sys.getsizeof(s[1])
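+    # note: sys.getsizeof() only measures the dictionary object itself, not
+    # the objects it references - tensor values are accounted for below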
+
+    # check each metadata element - if it is a tensor, add its total storage
+    # space in bytes
+    for v in s[1].values():
+        if isinstance(v, torch.Tensor):
+            size += _tensor_size_bytes(v)
+
+    return size
+
+
+def _sample_dict_size_bytes(s: Sample) -> int:
+    """Recurse into the sample and figures out its total occupance in bytes.
+
+    Parameters
+    ----------
+    s
+        The sample to be analyzed, as a tuple of dictionaries.
+
+    Returns
+    -------
+    int
+        The size in bytes occupied by this sample.
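+
+    Notes
+    -----
+    Unlike ``_sample_size_bytes``, this variant expects every element of
+    ``s`` to be a dictionary (e.g. ``({"image": ..., "target": ...},
+    {"name": ...})``); only values that are tensors contribute their
+    storage to the total.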
+    """
+
     def _tensor_size_bytes(t: torch.Tensor) -> int:
         """Return a tensor size in bytes.
 
@@ -62,42 +105,68 @@ def _sample_size_bytes(s: Sample) -> int:
         return int(t.element_size() * torch.prod(torch.tensor(t.shape)))
 
-    size = sys.getsizeof(s[0])  # tensor metadata
+    size = sys.getsizeof(s[0])  # dictionary overhead
-    size += int(s[0].element_size() * torch.prod(torch.tensor(s[0].shape)))
     size += sys.getsizeof(s[1])
 
-    # check each element - if it is a tensor, then adds its total space in
-    # bytes
-    for v in s[1].values():
-        if isinstance(v, torch.Tensor):
-            size += _tensor_size_bytes(v)
+    # check each element of each dictionary - if it is a tensor, add its
+    # total storage space in bytes
+    for s_ in s:
+        for v in s_.values():
+            if isinstance(v, torch.Tensor):
+                size += _tensor_size_bytes(v)
 
     return size
 
 
-def transform_tvtensors(data_dict, transforms):
-    """Apply transforms on dictionary elements which are of TVTensor type.
+def _estimate_data_footprint(dataset: Dataset) -> None:
+    """Estimate and log the in-memory footprint of dataset samples.
 
     Parameters
     ----------
-    data_dict
-        Dictionary possibly containing TVTensor elements to apply transforms to.
+    dataset
+        The dataset containing the samples to load.
+    """
+    first_sample = dataset[0]
+
+    logger.info("Delayed loading dataset (first tensor):")
+    if isinstance(first_sample[0], dict):
+        for k, v in first_sample[0].items():
+            logger.info(f"{k}: {list(v.shape)}@{v.dtype}")
+        sample_size_mb = _sample_dict_size_bytes(first_sample) / (1024.0 * 1024.0)
+        logger.info(f"Estimated sample size: {sample_size_mb:.1f} Mb")
+
+    else:
+        logger.info(f"{list(first_sample[0].shape)}@{first_sample[0].dtype}")
+        sample_size_mb = _sample_size_bytes(first_sample) / (1024.0 * 1024.0)
+        logger.info(f"Estimated sample size: {sample_size_mb:.1f} Mb")
+
+
+def transform_tensors(data, transforms):
+    """Apply transforms on dictionary elements which are of Tensor type.
+
+    Parameters
+    ----------
+    data
+        A Tensor, or a dictionary possibly containing Tensor values, to
+        apply transforms to.
     transforms
         The transforms to apply.
 
     Returns
     -------
-        A dictionary with transforms applied on TVTensor elements.
+        A dictionary with transforms applied to its Tensor values, or the
+        result of applying the transforms directly to ``data``.
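+
+    Examples
+    --------
+    A minimal sketch with a stand-in transform (tensor values are
+    transformed, everything else passes through unchanged):
+
+    >>> double = lambda t: t * 2
+    >>> transform_tensors({"image": torch.ones(1), "name": "s1"}, double)
+    {'image': tensor([2.]), 'name': 's1'}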
     """
 
-    transformed_dict = {}
+    if isinstance(data, dict):
+        transformed_dict = {}
 
-    for k, m in data_dict.items():
-        if isinstance(m, tv_tensors.TVTensor):
-            transformed_dict[k] = transforms(m)
-        else:
-            transformed_dict[k] = m
+        for k, m in data.items():
+            if isinstance(m, torch.Tensor):
+                transformed_dict[k] = transforms(m)
+            else:
+                transformed_dict[k] = m
+
+        return transformed_dict
 
-    return transformed_dict
+    return transforms(data)
 
 
 class _DelayedLoadingDataset(Dataset):
@@ -128,18 +197,11 @@ class _DelayedLoadingDataset(Dataset):
         self.loader = loader
         self.transform = torchvision.transforms.Compose(transforms)
 
-        # Tests loading and output tensor size
-        first_sample = self[0]
-        logger.info(
-            f"Delayed loading dataset (first tensor): "
-            f"{list(first_sample[0].shape)}@{first_sample[0].dtype}",
-        )
-        sample_size_mb = _sample_size_bytes(first_sample) / (1024.0 * 1024.0)
-        logger.info(f"Estimated sample size: {sample_size_mb:.1f} Mb")
+        _estimate_data_footprint(self)
 
     def __getitem__(self, key: int) -> Sample:
         tensor, metadata = self.loader.sample(self.raw_dataset[key])
-        return self.transform(tensor), transform_tvtensors(metadata, self.transform)
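+        # transforms are applied to the data tensor(s) only; metadata is
+        # returned unchanged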
+        return transform_tensors(tensor, self.transform), metadata
 
     def __len__(self):
         return len(self.raw_dataset)
@@ -175,7 +237,7 @@ def _apply_loader_and_transforms(
     """
 
     sample = load(info)
-    return model_transform(sample[0]), transform_tvtensors(sample[1], model_transform)
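+    # as in _DelayedLoadingDataset.__getitem__, only the data part of the
+    # sample is transformed; the metadata dictionary is passed through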
+    return transform_tensors(sample[0], model_transform), sample[1]
 
 
 class _CachedDataset(Dataset):
@@ -229,16 +291,7 @@ class _CachedDataset(Dataset):
                     ),
                 )
 
-        # Estimates memory occupance
-        logger.info(
-            f"Cached dataset (first tensor): "
-            f"{list(self.data[0][0].shape)}@{self.data[0][0].dtype}",
-        )
-        sample_size_mb = _sample_size_bytes(self.data[0]) / (1024.0 * 1024.0)
-        logger.info(
-            f"Estimated RAM occupance (sample / dataset): "
-            f"{sample_size_mb:.1f} / {(len(self.data)*sample_size_mb):.1f} Mb",
-        )
+        _estimate_data_footprint(self)
 
     def targets(self) -> list[int | list[int]]:
         """Return the integer targets for all samples in the dataset.
diff --git a/src/mednet/libs/common/engine/callbacks.py b/src/mednet/libs/common/engine/callbacks.py
index 3d108a8d2c372478d7792de8d65a4b21e2ce3428..b030c8820f6a7ec748733ab9ba692f35f48f7d8f 100644
--- a/src/mednet/libs/common/engine/callbacks.py
+++ b/src/mednet/libs/common/engine/callbacks.py
@@ -195,13 +195,18 @@ class LoggingCallback(lightning.pytorch.Callback):
             The relative number of the batch.
         """
 
+        if isinstance(batch[0], dict):
+            batch_size = batch[0]["image"].shape[0]
+        else:
+            batch_size = batch[0].shape[0]
+
         pl_module.log(
             "loss/train",
             outputs["loss"].item(),
             prog_bar=True,
             on_step=False,
             on_epoch=True,
-            batch_size=batch[0].shape[0],
+            batch_size=batch_size,
         )
 
     def on_validation_epoch_start(
@@ -367,12 +372,17 @@ class LoggingCallback(lightning.pytorch.Callback):
         else:
             key = f"loss/validation-{dataloader_idx}"
 
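+        # as above, support both dictionary-based and plain tensor batches
+        # when determining the batch size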
+        if isinstance(batch[0], dict):
+            batch_size = batch[0]["image"].shape[0]
+        else:
+            batch_size = batch[0].shape[0]
+
         pl_module.log(
             key,
             outputs.item(),
             prog_bar=False,
             on_step=False,
             on_epoch=True,
-            batch_size=batch[0].shape[0],
+            batch_size=batch_size,
             add_dataloader_idx=False,
         )
diff --git a/src/mednet/libs/segmentation/models/normalizer.py b/src/mednet/libs/segmentation/models/normalizer.py
index df630bd1d61c9b7e10a951a80e8f560248f9a477..5b3b7561c292f7520d4e37cc5fe5e2e4f3be9836 100644
--- a/src/mednet/libs/segmentation/models/normalizer.py
+++ b/src/mednet/libs/segmentation/models/normalizer.py
@@ -44,7 +44,7 @@ def make_z_normalizer(
 
     # Evaluates mean and standard deviation
     for batch in tqdm.tqdm(dataloader, unit="batch"):
-        data = batch[0]
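+        # samples are now dictionaries; the input image is stored under the
+        # "image" key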
+        data = batch[0]["image"]
         data = data.view(data.size(0), data.size(1), -1)
 
         num_images += data.size(0)