From f142d9681042bff18a188cfb0966022dbe5c91c9 Mon Sep 17 00:00:00 2001
From: dcarron <daniel.carron@idiap.ch>
Date: Wed, 26 Jun 2024 10:21:05 +0200
Subject: [PATCH] [segmentation.datamodule] Make sample() return two dicts

---
 .../libs/segmentation/config/data/avdrive/datamodule.py    | 6 +++---
 .../libs/segmentation/config/data/chasedb1/datamodule.py   | 6 +++---
 .../libs/segmentation/config/data/cxr8/datamodule.py       | 6 +++---
 .../libs/segmentation/config/data/drhagis/datamodule.py    | 6 +++---
 .../libs/segmentation/config/data/drionsdb/datamodule.py   | 7 ++++---
 .../libs/segmentation/config/data/drishtigs1/datamodule.py | 6 +++---
 src/mednet/libs/segmentation/config/data/hrf/datamodule.py | 6 +++---
 .../libs/segmentation/config/data/iostar/datamodule.py     | 6 +++---
 .../libs/segmentation/config/data/jsrt/datamodule.py       | 6 +++---
 .../libs/segmentation/config/data/montgomery/datamodule.py | 6 +++---
 .../libs/segmentation/config/data/refuge/datamodule.py     | 6 +++---
 .../libs/segmentation/config/data/rimoner3/datamodule.py   | 6 +++---
 .../libs/segmentation/config/data/shenzhen/datamodule.py   | 6 +++---
 .../libs/segmentation/config/data/stare/datamodule.py      | 6 +++---
 14 files changed, 43 insertions(+), 42 deletions(-)

diff --git a/src/mednet/libs/segmentation/config/data/avdrive/datamodule.py b/src/mednet/libs/segmentation/config/data/avdrive/datamodule.py
index be6adce1..83156a2c 100644
--- a/src/mednet/libs/segmentation/config/data/avdrive/datamodule.py
+++ b/src/mednet/libs/segmentation/config/data/avdrive/datamodule.py
@@ -66,11 +66,11 @@ class SegmentationRawDataLoader(_SegmentationRawDataLoader):
             )
         )
 
-        tensor = tv_tensors.Image(crop_image_to_mask(image, mask))
-        target = tv_tensors.Image(crop_image_to_mask(target, mask))
+        image = tv_tensors.Image(crop_image_to_mask(image, mask))
+        target = tv_tensors.Mask(crop_image_to_mask(target, mask))
         mask = tv_tensors.Mask(crop_image_to_mask(mask, mask))
 
-        return tensor, dict(target=target, mask=mask, name=sample[0])  # type: ignore[arg-type]
+        return dict(image=image, target=target, mask=mask), dict(name=sample[0])  # type: ignore[arg-type]
 
 
 class DataModule(CachingDataModule):
diff --git a/src/mednet/libs/segmentation/config/data/chasedb1/datamodule.py b/src/mednet/libs/segmentation/config/data/chasedb1/datamodule.py
index 3dc63ed5..0cfab7ee 100644
--- a/src/mednet/libs/segmentation/config/data/chasedb1/datamodule.py
+++ b/src/mednet/libs/segmentation/config/data/chasedb1/datamodule.py
@@ -64,11 +64,11 @@ class SegmentationRawDataLoader(_SegmentationRawDataLoader):
             PIL.Image.open(self._pkg_path / sample[2]).convert(mode="1", dither=None)
         )
 
-        tensor = tv_tensors.Image(crop_image_to_mask(image, mask))
-        target = tv_tensors.Image(crop_image_to_mask(target, mask))
+        image = tv_tensors.Image(crop_image_to_mask(image, mask))
+        target = tv_tensors.Mask(crop_image_to_mask(target, mask))
         mask = tv_tensors.Mask(crop_image_to_mask(mask, mask))
 
-        return tensor, dict(target=target, mask=mask, name=sample[0])  # type: ignore[arg-type]
+        return dict(image=image, target=target, mask=mask), dict(name=sample[0])  # type: ignore[arg-type]
 
 
 class DataModule(CachingDataModule):
diff --git a/src/mednet/libs/segmentation/config/data/cxr8/datamodule.py b/src/mednet/libs/segmentation/config/data/cxr8/datamodule.py
index f1c09648..acc6c83b 100644
--- a/src/mednet/libs/segmentation/config/data/cxr8/datamodule.py
+++ b/src/mednet/libs/segmentation/config/data/cxr8/datamodule.py
@@ -91,11 +91,11 @@ class SegmentationRawDataLoader(_SegmentationRawDataLoader):
         target = np.where(target == 255, 0, target)
         target = to_tensor(PIL.Image.fromarray(np.array(target > 0)))
 
-        tensor = tv_tensors.Image(tensor)
-        target = tv_tensors.Image(target)
+        image = tv_tensors.Image(tensor)
+        target = tv_tensors.Mask(target)
         mask = tv_tensors.Mask(torch.ones_like(target))
 
-        return tensor, dict(target=target, mask=mask, name=sample[0])  # type: ignore[arg-type]
+        return dict(image=image, target=target, mask=mask), dict(name=sample[0])  # type: ignore[arg-type]
 
 
 class DataModule(CachingDataModule):
diff --git a/src/mednet/libs/segmentation/config/data/drhagis/datamodule.py b/src/mednet/libs/segmentation/config/data/drhagis/datamodule.py
index b7a13e81..153c0abe 100644
--- a/src/mednet/libs/segmentation/config/data/drhagis/datamodule.py
+++ b/src/mednet/libs/segmentation/config/data/drhagis/datamodule.py
@@ -57,11 +57,11 @@ class SegmentationRawDataLoader(_SegmentationRawDataLoader):
             PIL.Image.open(self.datadir / sample[2]).convert(mode="1", dither=None)
         )
 
-        tensor = tv_tensors.Image(crop_image_to_mask(image, mask))
-        target = tv_tensors.Image(crop_image_to_mask(target, mask))
+        image = tv_tensors.Image(crop_image_to_mask(image, mask))
+        target = tv_tensors.Mask(crop_image_to_mask(target, mask))
         mask = tv_tensors.Mask(crop_image_to_mask(mask, mask))
 
-        return tensor, dict(target=target, mask=mask, name=sample[0])  # type: ignore[arg-type]
+        return dict(image=image, target=target, mask=mask), dict(name=sample[0])  # type: ignore[arg-type]
 
 
 class DataModule(CachingDataModule):
diff --git a/src/mednet/libs/segmentation/config/data/drionsdb/datamodule.py b/src/mednet/libs/segmentation/config/data/drionsdb/datamodule.py
index f5b9085d..41cc7fb8 100644
--- a/src/mednet/libs/segmentation/config/data/drionsdb/datamodule.py
+++ b/src/mednet/libs/segmentation/config/data/drionsdb/datamodule.py
@@ -86,10 +86,11 @@ class SegmentationRawDataLoader(_SegmentationRawDataLoader):
             PIL.Image.open(self._pkg_path / sample[2]).convert(mode="1", dither=None)
         )
 
-        tensor = tv_tensors.Image(crop_image_to_mask(image, mask))
-        target = tv_tensors.Image(crop_image_to_mask(target, mask))
+        image = tv_tensors.Image(crop_image_to_mask(image, mask))
+        target = tv_tensors.Mask(crop_image_to_mask(target, mask))
         mask = tv_tensors.Mask(crop_image_to_mask(mask, mask))
-        return tensor, dict(target=target, mask=mask, name=sample[0])  # type: ignore[arg-type]
+
+        return dict(image=image, target=target, mask=mask), dict(name=sample[0])  # type: ignore[arg-type]
 
 
 class DataModule(CachingDataModule):
diff --git a/src/mednet/libs/segmentation/config/data/drishtigs1/datamodule.py b/src/mednet/libs/segmentation/config/data/drishtigs1/datamodule.py
index 2a463df0..1c65e40d 100644
--- a/src/mednet/libs/segmentation/config/data/drishtigs1/datamodule.py
+++ b/src/mednet/libs/segmentation/config/data/drishtigs1/datamodule.py
@@ -83,11 +83,11 @@ class SegmentationRawDataLoader(_SegmentationRawDataLoader):
             PIL.Image.open(self._pkg_path / sample[2]).convert(mode="1", dither=None)
         )
 
-        tensor = tv_tensors.Image(crop_image_to_mask(image, mask))
-        target = tv_tensors.Image(crop_image_to_mask(target, mask))
+        image = tv_tensors.Image(crop_image_to_mask(image, mask))
+        target = tv_tensors.Mask(crop_image_to_mask(target, mask))
         mask = tv_tensors.Mask(crop_image_to_mask(mask, mask))
 
-        return tensor, dict(target=target, mask=mask, name=sample[0])  # type: ignore[arg-type]
+        return dict(image=image, target=target, mask=mask), dict(name=sample[0])  # type: ignore[arg-type]
 
 
 class DataModule(CachingDataModule):
diff --git a/src/mednet/libs/segmentation/config/data/hrf/datamodule.py b/src/mednet/libs/segmentation/config/data/hrf/datamodule.py
index 6227dd96..811ac24f 100644
--- a/src/mednet/libs/segmentation/config/data/hrf/datamodule.py
+++ b/src/mednet/libs/segmentation/config/data/hrf/datamodule.py
@@ -60,11 +60,11 @@ class SegmentationRawDataLoader(_SegmentationRawDataLoader):
             PIL.Image.open(self.datadir / sample[2]).convert(mode="1", dither=None)
         )
 
-        tensor = tv_tensors.Image(crop_image_to_mask(image, mask))
-        target = tv_tensors.Image(crop_image_to_mask(target, mask))
+        image = tv_tensors.Image(crop_image_to_mask(image, mask))
+        target = tv_tensors.Mask(crop_image_to_mask(target, mask))
         mask = tv_tensors.Mask(crop_image_to_mask(mask, mask))
 
-        return tensor, dict(target=target, mask=mask, name=sample[0])  # type: ignore[arg-type]
+        return dict(image=image, target=target, mask=mask), dict(name=sample[0])  # type: ignore[arg-type]
 
 
 class DataModule(CachingDataModule):
diff --git a/src/mednet/libs/segmentation/config/data/iostar/datamodule.py b/src/mednet/libs/segmentation/config/data/iostar/datamodule.py
index 6c0e20fc..ff67de9e 100644
--- a/src/mednet/libs/segmentation/config/data/iostar/datamodule.py
+++ b/src/mednet/libs/segmentation/config/data/iostar/datamodule.py
@@ -60,11 +60,11 @@ class SegmentationRawDataLoader(_SegmentationRawDataLoader):
             PIL.Image.open(self.datadir / sample[2]).convert(mode="1", dither=None)
         )
 
-        tensor = tv_tensors.Image(crop_image_to_mask(image, mask))
-        target = tv_tensors.Image(crop_image_to_mask(target, mask))
+        image = tv_tensors.Image(crop_image_to_mask(image, mask))
+        target = tv_tensors.Mask(crop_image_to_mask(target, mask))
         mask = tv_tensors.Mask(crop_image_to_mask(mask, mask))
 
-        return tensor, dict(target=target, mask=mask, name=sample[0])  # type: ignore[arg-type]
+        return dict(image=image, target=target, mask=mask), dict(name=sample[0])  # type: ignore[arg-type]
 
 
 class DataModule(CachingDataModule):
diff --git a/src/mednet/libs/segmentation/config/data/jsrt/datamodule.py b/src/mednet/libs/segmentation/config/data/jsrt/datamodule.py
index fea4f9bd..60f185b5 100644
--- a/src/mednet/libs/segmentation/config/data/jsrt/datamodule.py
+++ b/src/mednet/libs/segmentation/config/data/jsrt/datamodule.py
@@ -97,11 +97,11 @@ class SegmentationRawDataLoader(_SegmentationRawDataLoader):
             ).float()
         )
 
-        tensor = tv_tensors.Image(image)
-        target = tv_tensors.Image(target)
+        image = tv_tensors.Image(image)
+        target = tv_tensors.Mask(target)
         mask = tv_tensors.Mask(torch.ones_like(target))
 
-        return tensor, dict(target=target, mask=mask, name=sample[0])  # type: ignore[arg-type]
+        return dict(image=image, target=target, mask=mask), dict(name=sample[0])  # type: ignore[arg-type]
 
 
 class DataModule(CachingDataModule):
diff --git a/src/mednet/libs/segmentation/config/data/montgomery/datamodule.py b/src/mednet/libs/segmentation/config/data/montgomery/datamodule.py
index d239fbfd..5792d941 100644
--- a/src/mednet/libs/segmentation/config/data/montgomery/datamodule.py
+++ b/src/mednet/libs/segmentation/config/data/montgomery/datamodule.py
@@ -56,10 +56,10 @@ class SegmentationRawDataLoader(_SegmentationRawDataLoader):
         """
 
         image = PIL.Image.open(self.datadir / sample[0]).convert(mode="RGB")
-        tensor = tv_tensors.Image(to_tensor(image))
+        image = tv_tensors.Image(to_tensor(image))
 
         # Combine left and right lung masks into a single tensor
-        target = tv_tensors.Image(
+        target = tv_tensors.Mask(
             to_tensor(
                 np.ma.mask_or(
                     np.asarray(
@@ -78,7 +78,7 @@ class SegmentationRawDataLoader(_SegmentationRawDataLoader):
 
         mask = tv_tensors.Mask(torch.ones_like(target))
 
-        return tensor, dict(target=target, mask=mask, name=sample[0])  # type: ignore[arg-type]
+        return dict(image=image, target=target, mask=mask), dict(name=sample[0])  # type: ignore[arg-type]
 
 
 class DataModule(CachingDataModule):
diff --git a/src/mednet/libs/segmentation/config/data/refuge/datamodule.py b/src/mednet/libs/segmentation/config/data/refuge/datamodule.py
index 520a8a85..a7bcd24d 100644
--- a/src/mednet/libs/segmentation/config/data/refuge/datamodule.py
+++ b/src/mednet/libs/segmentation/config/data/refuge/datamodule.py
@@ -88,11 +88,11 @@ class SegmentationRawDataLoader(_SegmentationRawDataLoader):
             PIL.Image.open(self._pkg_path / sample[2]).convert(mode="1", dither=None)
         )
 
-        tensor = tv_tensors.Image(crop_image_to_mask(image, mask))
-        target = tv_tensors.Image(crop_image_to_mask(target, mask))
+        image = tv_tensors.Image(crop_image_to_mask(image, mask))
+        target = tv_tensors.Mask(crop_image_to_mask(target, mask))
         mask = tv_tensors.Mask(crop_image_to_mask(mask, mask))
 
-        return tensor, dict(target=target, mask=mask, name=sample[0])  # type: ignore[arg-type]
+        return dict(image=image, target=target, mask=mask), dict(name=sample[0])  # type: ignore[arg-type]
 
 
 class DataModule(CachingDataModule):
diff --git a/src/mednet/libs/segmentation/config/data/rimoner3/datamodule.py b/src/mednet/libs/segmentation/config/data/rimoner3/datamodule.py
index 1046406e..e4e9e581 100644
--- a/src/mednet/libs/segmentation/config/data/rimoner3/datamodule.py
+++ b/src/mednet/libs/segmentation/config/data/rimoner3/datamodule.py
@@ -75,11 +75,11 @@ class SegmentationRawDataLoader(_SegmentationRawDataLoader):
             )
         )
 
-        tensor = tv_tensors.Image(image)
-        target = tv_tensors.Image(target)
+        image = tv_tensors.Image(image)
+        target = tv_tensors.Mask(target)
         mask = tv_tensors.Mask(mask)
 
-        return tensor, dict(target=target, mask=mask, name=sample[0])  # type: ignore[arg-type]
+        return dict(image=image, target=target, mask=mask), dict(name=sample[0])  # type: ignore[arg-type]
 
 
 class DataModule(CachingDataModule):
diff --git a/src/mednet/libs/segmentation/config/data/shenzhen/datamodule.py b/src/mednet/libs/segmentation/config/data/shenzhen/datamodule.py
index ade62cd9..f0be322a 100644
--- a/src/mednet/libs/segmentation/config/data/shenzhen/datamodule.py
+++ b/src/mednet/libs/segmentation/config/data/shenzhen/datamodule.py
@@ -58,11 +58,11 @@ class SegmentationRawDataLoader(_SegmentationRawDataLoader):
 
         mask = torch.ones_like(target)
 
-        tensor = tv_tensors.Image(crop_image_to_mask(image, mask))
-        target = tv_tensors.Image(crop_image_to_mask(target, mask))
+        image = tv_tensors.Image(crop_image_to_mask(image, mask))
+        target = tv_tensors.Mask(crop_image_to_mask(target, mask))
         mask = tv_tensors.Mask(crop_image_to_mask(mask, mask))
 
-        return tensor, dict(target=target, mask=mask, name=sample[0])  # type: ignore[arg-type]
+        return dict(image=image, target=target, mask=mask), dict(name=sample[0])  # type: ignore[arg-type]
 
 
 class DataModule(CachingDataModule):
diff --git a/src/mednet/libs/segmentation/config/data/stare/datamodule.py b/src/mednet/libs/segmentation/config/data/stare/datamodule.py
index 14a352ab..9c824782 100644
--- a/src/mednet/libs/segmentation/config/data/stare/datamodule.py
+++ b/src/mednet/libs/segmentation/config/data/stare/datamodule.py
@@ -64,11 +64,11 @@ class SegmentationRawDataLoader(_SegmentationRawDataLoader):
             PIL.Image.open(self._pkg_path / sample[2]).convert(mode="1", dither=None)
         )
 
-        tensor = tv_tensors.Image(crop_image_to_mask(image, mask))
-        target = tv_tensors.Image(crop_image_to_mask(target, mask))
+        image = tv_tensors.Image(crop_image_to_mask(image, mask))
+        target = tv_tensors.Mask(crop_image_to_mask(target, mask))
         mask = tv_tensors.Mask(crop_image_to_mask(mask, mask))
 
-        return tensor, dict(target=target, mask=mask, name=sample[0])  # type: ignore[arg-type]
+        return dict(image=image, target=target, mask=mask), dict(name=sample[0])  # type: ignore[arg-type]
 
 
 class DataModule(CachingDataModule):
-- 
GitLab