diff --git a/src/mednet/libs/segmentation/config/data/avdrive/datamodule.py b/src/mednet/libs/segmentation/config/data/avdrive/datamodule.py
index be6adce106c70f54cc1dabfa7077c1789f7fe0cc..83156a2c03dbf2fadf9054b9d32d6c31fe0502f8 100644
--- a/src/mednet/libs/segmentation/config/data/avdrive/datamodule.py
+++ b/src/mednet/libs/segmentation/config/data/avdrive/datamodule.py
@@ -66,11 +66,11 @@ class SegmentationRawDataLoader(_SegmentationRawDataLoader):
             )
         )
 
-        tensor = tv_tensors.Image(crop_image_to_mask(image, mask))
-        target = tv_tensors.Image(crop_image_to_mask(target, mask))
+        image = tv_tensors.Image(crop_image_to_mask(image, mask))
+        target = tv_tensors.Mask(crop_image_to_mask(target, mask))
         mask = tv_tensors.Mask(crop_image_to_mask(mask, mask))
 
-        return tensor, dict(target=target, mask=mask, name=sample[0])  # type: ignore[arg-type]
+        return dict(image=image, target=target, mask=mask), dict(name=sample[0])  # type: ignore[arg-type]
 
 
 class DataModule(CachingDataModule):
diff --git a/src/mednet/libs/segmentation/config/data/chasedb1/datamodule.py b/src/mednet/libs/segmentation/config/data/chasedb1/datamodule.py
index 3dc63ed535f319204235316f92db0f6e1a1c2436..0cfab7ee7b66ed27f607559f4ff3bf199550b9ca 100644
--- a/src/mednet/libs/segmentation/config/data/chasedb1/datamodule.py
+++ b/src/mednet/libs/segmentation/config/data/chasedb1/datamodule.py
@@ -64,11 +64,11 @@ class SegmentationRawDataLoader(_SegmentationRawDataLoader):
             PIL.Image.open(self._pkg_path / sample[2]).convert(mode="1", dither=None)
         )
 
-        tensor = tv_tensors.Image(crop_image_to_mask(image, mask))
-        target = tv_tensors.Image(crop_image_to_mask(target, mask))
+        image = tv_tensors.Image(crop_image_to_mask(image, mask))
+        target = tv_tensors.Mask(crop_image_to_mask(target, mask))
         mask = tv_tensors.Mask(crop_image_to_mask(mask, mask))
 
-        return tensor, dict(target=target, mask=mask, name=sample[0])  # type: ignore[arg-type]
+        return dict(image=image, target=target, mask=mask), dict(name=sample[0])  # type: ignore[arg-type]
 
 
 class DataModule(CachingDataModule):
diff --git a/src/mednet/libs/segmentation/config/data/cxr8/datamodule.py b/src/mednet/libs/segmentation/config/data/cxr8/datamodule.py
index f1c0964839ed58508f629c229977593addc5411b..acc6c83bf0b3112a01d12cdebd8a615b4eabd5c1 100644
--- a/src/mednet/libs/segmentation/config/data/cxr8/datamodule.py
+++ b/src/mednet/libs/segmentation/config/data/cxr8/datamodule.py
@@ -91,11 +91,11 @@ class SegmentationRawDataLoader(_SegmentationRawDataLoader):
         target = np.where(target == 255, 0, target)
         target = to_tensor(PIL.Image.fromarray(np.array(target > 0)))
 
-        tensor = tv_tensors.Image(tensor)
-        target = tv_tensors.Image(target)
+        image = tv_tensors.Image(tensor)
+        target = tv_tensors.Mask(target)
         mask = tv_tensors.Mask(torch.ones_like(target))
 
-        return tensor, dict(target=target, mask=mask, name=sample[0])  # type: ignore[arg-type]
+        return dict(image=image, target=target, mask=mask), dict(name=sample[0])  # type: ignore[arg-type]
 
 
 class DataModule(CachingDataModule):
diff --git a/src/mednet/libs/segmentation/config/data/drhagis/datamodule.py b/src/mednet/libs/segmentation/config/data/drhagis/datamodule.py
index b7a13e818a17476e82520608e5f6b51522b93f45..153c0abe36023d4e57c147dbb1c43f35b42c816e 100644
--- a/src/mednet/libs/segmentation/config/data/drhagis/datamodule.py
+++ b/src/mednet/libs/segmentation/config/data/drhagis/datamodule.py
@@ -57,11 +57,11 @@ class SegmentationRawDataLoader(_SegmentationRawDataLoader):
             PIL.Image.open(self.datadir / sample[2]).convert(mode="1", dither=None)
         )
 
-        tensor = tv_tensors.Image(crop_image_to_mask(image, mask))
-        target = tv_tensors.Image(crop_image_to_mask(target, mask))
+        image = tv_tensors.Image(crop_image_to_mask(image, mask))
+        target = tv_tensors.Mask(crop_image_to_mask(target, mask))
         mask = tv_tensors.Mask(crop_image_to_mask(mask, mask))
 
-        return tensor, dict(target=target, mask=mask, name=sample[0])  # type: ignore[arg-type]
+        return dict(image=image, target=target, mask=mask), dict(name=sample[0])  # type: ignore[arg-type]
 
 
 class DataModule(CachingDataModule):
diff --git a/src/mednet/libs/segmentation/config/data/drionsdb/datamodule.py b/src/mednet/libs/segmentation/config/data/drionsdb/datamodule.py
index f5b9085de92c4782bee1426aab021299669b1513..41cc7fb8f02309bcb6d982e70e0d403bf5529ec6 100644
--- a/src/mednet/libs/segmentation/config/data/drionsdb/datamodule.py
+++ b/src/mednet/libs/segmentation/config/data/drionsdb/datamodule.py
@@ -86,10 +86,11 @@ class SegmentationRawDataLoader(_SegmentationRawDataLoader):
             PIL.Image.open(self._pkg_path / sample[2]).convert(mode="1", dither=None)
         )
 
-        tensor = tv_tensors.Image(crop_image_to_mask(image, mask))
-        target = tv_tensors.Image(crop_image_to_mask(target, mask))
+        image = tv_tensors.Image(crop_image_to_mask(image, mask))
+        target = tv_tensors.Mask(crop_image_to_mask(target, mask))
         mask = tv_tensors.Mask(crop_image_to_mask(mask, mask))
-        return tensor, dict(target=target, mask=mask, name=sample[0])  # type: ignore[arg-type]
+
+        return dict(image=image, target=target, mask=mask), dict(name=sample[0])  # type: ignore[arg-type]
 
 
 class DataModule(CachingDataModule):
diff --git a/src/mednet/libs/segmentation/config/data/drishtigs1/datamodule.py b/src/mednet/libs/segmentation/config/data/drishtigs1/datamodule.py
index 2a463df0d438c29a788caa4fe89b6e49e5649315..1c65e40dce56d6066105808c61a037211770a13f 100644
--- a/src/mednet/libs/segmentation/config/data/drishtigs1/datamodule.py
+++ b/src/mednet/libs/segmentation/config/data/drishtigs1/datamodule.py
@@ -83,11 +83,11 @@ class SegmentationRawDataLoader(_SegmentationRawDataLoader):
             PIL.Image.open(self._pkg_path / sample[2]).convert(mode="1", dither=None)
         )
 
-        tensor = tv_tensors.Image(crop_image_to_mask(image, mask))
-        target = tv_tensors.Image(crop_image_to_mask(target, mask))
+        image = tv_tensors.Image(crop_image_to_mask(image, mask))
+        target = tv_tensors.Mask(crop_image_to_mask(target, mask))
         mask = tv_tensors.Mask(crop_image_to_mask(mask, mask))
 
-        return tensor, dict(target=target, mask=mask, name=sample[0])  # type: ignore[arg-type]
+        return dict(image=image, target=target, mask=mask), dict(name=sample[0])  # type: ignore[arg-type]
 
 
 class DataModule(CachingDataModule):
diff --git a/src/mednet/libs/segmentation/config/data/hrf/datamodule.py b/src/mednet/libs/segmentation/config/data/hrf/datamodule.py
index 6227dd96e2e425fdc1d650e9a9587bd432f011e5..811ac24f9bc7244ffda037ef1c93b3fda8e8df2a 100644
--- a/src/mednet/libs/segmentation/config/data/hrf/datamodule.py
+++ b/src/mednet/libs/segmentation/config/data/hrf/datamodule.py
@@ -60,11 +60,11 @@ class SegmentationRawDataLoader(_SegmentationRawDataLoader):
             PIL.Image.open(self.datadir / sample[2]).convert(mode="1", dither=None)
         )
 
-        tensor = tv_tensors.Image(crop_image_to_mask(image, mask))
-        target = tv_tensors.Image(crop_image_to_mask(target, mask))
+        image = tv_tensors.Image(crop_image_to_mask(image, mask))
+        target = tv_tensors.Mask(crop_image_to_mask(target, mask))
         mask = tv_tensors.Mask(crop_image_to_mask(mask, mask))
 
-        return tensor, dict(target=target, mask=mask, name=sample[0])  # type: ignore[arg-type]
+        return dict(image=image, target=target, mask=mask), dict(name=sample[0])  # type: ignore[arg-type]
 
 
 class DataModule(CachingDataModule):
diff --git a/src/mednet/libs/segmentation/config/data/iostar/datamodule.py b/src/mednet/libs/segmentation/config/data/iostar/datamodule.py
index 6c0e20fc19b6f30e37ae6cd41722fd079a831997..ff67de9e272c44ec2c594a9edb5a11382417af84 100644
--- a/src/mednet/libs/segmentation/config/data/iostar/datamodule.py
+++ b/src/mednet/libs/segmentation/config/data/iostar/datamodule.py
@@ -60,11 +60,11 @@ class SegmentationRawDataLoader(_SegmentationRawDataLoader):
             PIL.Image.open(self.datadir / sample[2]).convert(mode="1", dither=None)
         )
 
-        tensor = tv_tensors.Image(crop_image_to_mask(image, mask))
-        target = tv_tensors.Image(crop_image_to_mask(target, mask))
+        image = tv_tensors.Image(crop_image_to_mask(image, mask))
+        target = tv_tensors.Mask(crop_image_to_mask(target, mask))
         mask = tv_tensors.Mask(crop_image_to_mask(mask, mask))
 
-        return tensor, dict(target=target, mask=mask, name=sample[0])  # type: ignore[arg-type]
+        return dict(image=image, target=target, mask=mask), dict(name=sample[0])  # type: ignore[arg-type]
 
 
 class DataModule(CachingDataModule):
diff --git a/src/mednet/libs/segmentation/config/data/jsrt/datamodule.py b/src/mednet/libs/segmentation/config/data/jsrt/datamodule.py
index fea4f9bd44a0919dc0e155dd469d5988ba791a6b..60f185b52a42974ddc29b8a247fe4139bc9b6162 100644
--- a/src/mednet/libs/segmentation/config/data/jsrt/datamodule.py
+++ b/src/mednet/libs/segmentation/config/data/jsrt/datamodule.py
@@ -97,11 +97,11 @@ class SegmentationRawDataLoader(_SegmentationRawDataLoader):
             ).float()
         )
 
-        tensor = tv_tensors.Image(image)
-        target = tv_tensors.Image(target)
+        image = tv_tensors.Image(image)
+        target = tv_tensors.Mask(target)
         mask = tv_tensors.Mask(torch.ones_like(target))
 
-        return tensor, dict(target=target, mask=mask, name=sample[0])  # type: ignore[arg-type]
+        return dict(image=image, target=target, mask=mask), dict(name=sample[0])  # type: ignore[arg-type]
 
 
 class DataModule(CachingDataModule):
diff --git a/src/mednet/libs/segmentation/config/data/montgomery/datamodule.py b/src/mednet/libs/segmentation/config/data/montgomery/datamodule.py
index d239fbfd226cd32654801f351ab4c0dea0ea42ce..5792d941328fae7f18238bc73a4133b4cd2c0e1e 100644
--- a/src/mednet/libs/segmentation/config/data/montgomery/datamodule.py
+++ b/src/mednet/libs/segmentation/config/data/montgomery/datamodule.py
@@ -56,10 +56,10 @@ class SegmentationRawDataLoader(_SegmentationRawDataLoader):
         """
 
         image = PIL.Image.open(self.datadir / sample[0]).convert(mode="RGB")
-        tensor = tv_tensors.Image(to_tensor(image))
+        image = tv_tensors.Image(to_tensor(image))
 
         # Combine left and right lung masks into a single tensor
-        target = tv_tensors.Image(
+        target = tv_tensors.Mask(
             to_tensor(
                 np.ma.mask_or(
                     np.asarray(
@@ -78,7 +78,7 @@ class SegmentationRawDataLoader(_SegmentationRawDataLoader):
 
         mask = tv_tensors.Mask(torch.ones_like(target))
 
-        return tensor, dict(target=target, mask=mask, name=sample[0])  # type: ignore[arg-type]
+        return dict(image=image, target=target, mask=mask), dict(name=sample[0])  # type: ignore[arg-type]
 
 
 class DataModule(CachingDataModule):
diff --git a/src/mednet/libs/segmentation/config/data/refuge/datamodule.py b/src/mednet/libs/segmentation/config/data/refuge/datamodule.py
index 520a8a85f1656753336c5598cba44614f06bd952..a7bcd24d161df05c2c104571340b30271793a7c7 100644
--- a/src/mednet/libs/segmentation/config/data/refuge/datamodule.py
+++ b/src/mednet/libs/segmentation/config/data/refuge/datamodule.py
@@ -88,11 +88,11 @@ class SegmentationRawDataLoader(_SegmentationRawDataLoader):
             PIL.Image.open(self._pkg_path / sample[2]).convert(mode="1", dither=None)
         )
 
-        tensor = tv_tensors.Image(crop_image_to_mask(image, mask))
-        target = tv_tensors.Image(crop_image_to_mask(target, mask))
+        image = tv_tensors.Image(crop_image_to_mask(image, mask))
+        target = tv_tensors.Mask(crop_image_to_mask(target, mask))
         mask = tv_tensors.Mask(crop_image_to_mask(mask, mask))
 
-        return tensor, dict(target=target, mask=mask, name=sample[0])  # type: ignore[arg-type]
+        return dict(image=image, target=target, mask=mask), dict(name=sample[0])  # type: ignore[arg-type]
 
 
 class DataModule(CachingDataModule):
diff --git a/src/mednet/libs/segmentation/config/data/rimoner3/datamodule.py b/src/mednet/libs/segmentation/config/data/rimoner3/datamodule.py
index 1046406ef82768a1d8a5e863d5cd6a1d15ea3396..e4e9e5815c83c93afe95ed331c2489fa90544101 100644
--- a/src/mednet/libs/segmentation/config/data/rimoner3/datamodule.py
+++ b/src/mednet/libs/segmentation/config/data/rimoner3/datamodule.py
@@ -75,11 +75,11 @@ class SegmentationRawDataLoader(_SegmentationRawDataLoader):
             )
         )
 
-        tensor = tv_tensors.Image(image)
-        target = tv_tensors.Image(target)
+        image = tv_tensors.Image(image)
+        target = tv_tensors.Mask(target)
         mask = tv_tensors.Mask(mask)
 
-        return tensor, dict(target=target, mask=mask, name=sample[0])  # type: ignore[arg-type]
+        return dict(image=image, target=target, mask=mask), dict(name=sample[0])  # type: ignore[arg-type]
 
 
 class DataModule(CachingDataModule):
diff --git a/src/mednet/libs/segmentation/config/data/shenzhen/datamodule.py b/src/mednet/libs/segmentation/config/data/shenzhen/datamodule.py
index ade62cd9bb937239252b3c555477f3b1e49602e0..f0be322abb3a91607580c9bd686f844d9d824984 100644
--- a/src/mednet/libs/segmentation/config/data/shenzhen/datamodule.py
+++ b/src/mednet/libs/segmentation/config/data/shenzhen/datamodule.py
@@ -58,11 +58,11 @@ class SegmentationRawDataLoader(_SegmentationRawDataLoader):
 
         mask = torch.ones_like(target)
 
-        tensor = tv_tensors.Image(crop_image_to_mask(image, mask))
-        target = tv_tensors.Image(crop_image_to_mask(target, mask))
+        image = tv_tensors.Image(crop_image_to_mask(image, mask))
+        target = tv_tensors.Mask(crop_image_to_mask(target, mask))
         mask = tv_tensors.Mask(crop_image_to_mask(mask, mask))
 
-        return tensor, dict(target=target, mask=mask, name=sample[0])  # type: ignore[arg-type]
+        return dict(image=image, target=target, mask=mask), dict(name=sample[0])  # type: ignore[arg-type]
 
 
 class DataModule(CachingDataModule):
diff --git a/src/mednet/libs/segmentation/config/data/stare/datamodule.py b/src/mednet/libs/segmentation/config/data/stare/datamodule.py
index 14a352ab41ffc1d4c58bb16d240caac94c673805..9c8247823eb51532a8ef242e97a9d5fc4c12b96f 100644
--- a/src/mednet/libs/segmentation/config/data/stare/datamodule.py
+++ b/src/mednet/libs/segmentation/config/data/stare/datamodule.py
@@ -64,11 +64,11 @@ class SegmentationRawDataLoader(_SegmentationRawDataLoader):
             PIL.Image.open(self._pkg_path / sample[2]).convert(mode="1", dither=None)
         )
 
-        tensor = tv_tensors.Image(crop_image_to_mask(image, mask))
-        target = tv_tensors.Image(crop_image_to_mask(target, mask))
+        image = tv_tensors.Image(crop_image_to_mask(image, mask))
+        target = tv_tensors.Mask(crop_image_to_mask(target, mask))
         mask = tv_tensors.Mask(crop_image_to_mask(mask, mask))
 
-        return tensor, dict(target=target, mask=mask, name=sample[0])  # type: ignore[arg-type]
+        return dict(image=image, target=target, mask=mask), dict(name=sample[0])  # type: ignore[arg-type]
 
 
 class DataModule(CachingDataModule):