From 0218b1714febe63ace1eb2866a012b7320674963 Mon Sep 17 00:00:00 2001
From: dcarron <daniel.carron@idiap.ch>
Date: Mon, 13 May 2024 12:16:29 +0200
Subject: [PATCH] [datamodule] Apply model transforms on TVTensors

---
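The new transform_tvtensors() helper applies the model transform only to
metadata values that are TVTensors and passes every other value (such as the
sample name) through unchanged. A minimal sketch of the intended behaviour,
assuming the model transform is a torchvision v2 transform; the tensor shape,
Resize size and file name below are illustrative only:

    import torch
    from mednet.libs.common.data.datamodule import transform_tvtensors
    from torchvision import tv_tensors
    from torchvision.transforms import v2

    metadata = {
        "mask": tv_tensors.Mask(torch.ones(1, 584, 565)),  # dummy field-of-view mask
        "name": "training/images/21_training.tif",         # plain str, not a TVTensor
    }
    resize = v2.Resize(size=(512, 512))

    out = transform_tvtensors(metadata, resize)
    assert isinstance(out["mask"], tv_tensors.Mask)  # TVTensor subclass is preserved
    assert out["mask"].shape[-2:] == (512, 512)      # the resize was applied
    assert out["name"] == metadata["name"]           # non-tensor entries pass through
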
 src/mednet/libs/common/data/datamodule.py     | 35 +++++++++++++++-
 src/mednet/libs/common/models/typing.py       |  2 +-
 .../config/data/drive/datamodule.py           | 41 ++++++++-----------
 3 files changed, 50 insertions(+), 28 deletions(-)
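
In the DRIVE loader, sample() now receives a triple of path suffixes (image,
vessel label, field-of-view mask) instead of an (image path, integer label)
pair, and the loaded arrays are wrapped as tv_tensors.Image (image and label)
and tv_tensors.Mask (mask). Wrapping them lets torchvision v2 transforms
dispatch on the TVTensor subclass, so a geometric model transform can also be
applied to the label and mask, with masks receiving mask-appropriate handling
(e.g. nearest-neighbour resampling on resize). A self-contained sketch of that
dispatch behaviour, using synthetic PIL images in place of the real DRIVE
files:

    import PIL.Image
    from torchvision import tv_tensors
    from torchvision.transforms import v2
    from torchvision.transforms.functional import to_tensor

    # Synthetic stand-ins for a DRIVE fundus image and its binary mask.
    image = tv_tensors.Image(to_tensor(PIL.Image.new("RGB", (565, 584))))
    mask = tv_tensors.Mask(to_tensor(PIL.Image.new("1", (565, 584))))

    flip = v2.RandomHorizontalFlip(p=1.0)
    out_image, out_mask = flip(image, mask)  # one call transforms both inputs
    assert isinstance(out_image, tv_tensors.Image)
    assert isinstance(out_mask, tv_tensors.Mask)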

diff --git a/src/mednet/libs/common/data/datamodule.py b/src/mednet/libs/common/data/datamodule.py
index 3bfcb9f5..30fc26bd 100644
--- a/src/mednet/libs/common/data/datamodule.py
+++ b/src/mednet/libs/common/data/datamodule.py
@@ -15,6 +15,7 @@ import torch.backends
 import torch.utils.data
 import torchvision.transforms
 import tqdm
+from torchvision import tv_tensors
 
 from .typing import (
     ConcatDatabaseSplit,
@@ -73,6 +74,32 @@ def _sample_size_bytes(s: Sample) -> int:
     return size
 
 
+def transform_tvtensors(data_dict, transforms):
+    """Apply transforms on dictionary elements which are of TVTensor type.
+
+    Parameters
+    ----------
+    data_dict
+        Dictionary whose values may include TVTensor elements to transform.
+    transforms
+        The transforms to apply.
+
+    Returns
+    -------
+        A dictionary with the transforms applied to its TVTensor values.
+    """
+
+    transformed_dict = {}
+
+    for k, m in data_dict.items():
+        if isinstance(m, tv_tensors.TVTensor):
+            transformed_dict[k] = transforms(m)
+        else:
+            transformed_dict[k] = m
+
+    return transformed_dict
+
+
 class _DelayedLoadingDataset(Dataset):
     """A list that loads its samples on demand.
 
@@ -123,7 +150,9 @@ class _DelayedLoadingDataset(Dataset):
 
     def __getitem__(self, key: int) -> Sample:
         tensor, metadata = self.loader.sample(self.raw_dataset[key])
-        return self.transform(tensor), metadata
+        return self.transform(tensor), transform_tvtensors(
+            metadata, self.transform
+        )
 
     def __len__(self):
         return len(self.raw_dataset)
@@ -159,7 +188,9 @@ def _apply_loader_and_transforms(
     """
 
     sample = load(info)
-    return model_transform(sample[0]), sample[1]
+    return model_transform(sample[0]), transform_tvtensors(
+        sample[1], model_transform
+    )
 
 
 class _CachedDataset(Dataset):
diff --git a/src/mednet/libs/common/models/typing.py b/src/mednet/libs/common/models/typing.py
index 97a04b66..f53cac9c 100644
--- a/src/mednet/libs/common/models/typing.py
+++ b/src/mednet/libs/common/models/typing.py
@@ -6,4 +6,4 @@
 import typing
 
 Checkpoint: typing.TypeAlias = typing.MutableMapping[str, typing.Any]
-"""Definition of a lightning checkpoint."""
\ No newline at end of file
+"""Definition of a lightning checkpoint."""
diff --git a/src/mednet/libs/segmentation/config/data/drive/datamodule.py b/src/mednet/libs/segmentation/config/data/drive/datamodule.py
index ecab53d8..b61941a5 100644
--- a/src/mednet/libs/segmentation/config/data/drive/datamodule.py
+++ b/src/mednet/libs/segmentation/config/data/drive/datamodule.py
@@ -12,6 +12,7 @@ from mednet.libs.common.data.datamodule import CachingDataModule
 from mednet.libs.common.data.split import JSONDatabaseSplit
 from mednet.libs.common.data.typing import DatabaseSplit, Sample
 from mednet.libs.common.data.typing import RawDataLoader as _BaseRawDataLoader
+from torchvision import tv_tensors
 from torchvision.transforms.functional import to_tensor
 
 from ....utils.rc import load_rc
@@ -33,7 +34,7 @@ class RawDataLoader(_BaseRawDataLoader):
             CONFIGURATION_KEY_DATADIR, os.path.realpath(os.curdir)
         )
 
-    def sample(self, sample: tuple[str, int]) -> Sample:
+    def sample(self, sample: tuple[str, str, str]) -> Sample:
         """Load a single image sample from the disk.
 
         Parameters
@@ -51,33 +52,23 @@ class RawDataLoader(_BaseRawDataLoader):
         image = PIL.Image.open(Path(self.datadir) / str(sample[0])).convert(
             mode="RGB"
         )
-        tensor = to_tensor(image)
-        label = PIL.Image.open(Path(self.datadir) / str(sample[1])).convert(
-            mode="1", dither=None
+        tensor = tv_tensors.Image(to_tensor(image))
+        label = tv_tensors.Image(
+            to_tensor(
+                PIL.Image.open(Path(self.datadir) / str(sample[1])).convert(
+                    mode="1", dither=None
+                )
+            )
         )
-        mask = PIL.Image.open(Path(self.datadir) / str(sample[2])).convert(
-            mode="1", dither=None
+        mask = tv_tensors.Mask(
+            to_tensor(
+                PIL.Image.open(Path(self.datadir) / str(sample[2])).convert(
+                    mode="1", dither=None
+                )
+            )
         )
-        return tensor, dict(
-            label=to_tensor(label), mask=to_tensor(mask), name=sample[0]
-        )  # type: ignore[arg-type]
 
-    def label(self, sample: tuple[str, int]) -> int:
-        """Load a single image sample label from the disk.
-
-        Parameters
-        ----------
-        sample
-            A tuple containing the path suffix, within the dataset root folder,
-            where to find the image to be loaded, and an integer, representing the
-            sample label.
-
-        Returns
-        -------
-        int
-            The integer label associated with the sample.
-        """
-        return sample[1]
+        return tensor, dict(label=label, mask=mask, name=sample[0])  # type: ignore[arg-type]
 
 
 def make_split(basename: str) -> DatabaseSplit:
-- 
GitLab