def transform_tvtensors(data_dict, transforms):
    """Apply transforms on dictionary elements which are of TVTensor type.

    Each TVTensor value is handed to ``transforms`` in its own, separate call;
    values of any other type are carried over untouched. The input dictionary
    is not modified.

    Parameters
    ----------
    data_dict
        Dictionary possibly containing TVTensor elements to apply transforms to.
    transforms
        The transforms to apply. Called with a single positional argument (one
        TVTensor value at a time).

    Returns
    -------
        A new dictionary with the same keys, in the same order, as
        ``data_dict``, with transforms applied on TVTensor elements.
    """

    # NOTE: because every entry goes through its own ``transforms(...)`` call,
    # a transform that draws random parameters per call will draw them
    # independently for each entry (and independently of any call the caller
    # makes on the main image tensor).
    return {
        key: (
            transforms(value)
            if isinstance(value, tv_tensors.TVTensor)
            else value
        )
        for key, value in data_dict.items()
    }
@@ -123,7 +150,9 @@ class _DelayedLoadingDataset(Dataset): def __getitem__(self, key: int) -> Sample: tensor, metadata = self.loader.sample(self.raw_dataset[key]) - return self.transform(tensor), metadata + return self.transform(tensor), transform_tvtensors( + metadata, self.transform + ) def __len__(self): return len(self.raw_dataset) @@ -159,7 +188,9 @@ def _apply_loader_and_transforms( """ sample = load(info) - return model_transform(sample[0]), sample[1] + return model_transform(sample[0]), transform_tvtensors( + sample[1], model_transform + ) class _CachedDataset(Dataset): diff --git a/src/mednet/libs/common/models/typing.py b/src/mednet/libs/common/models/typing.py index 97a04b6616724503e7690f6e8d33a40c315feed6..f53cac9ce1ffd61640e6e3f14c15c4ac364d201f 100644 --- a/src/mednet/libs/common/models/typing.py +++ b/src/mednet/libs/common/models/typing.py @@ -6,4 +6,4 @@ import typing Checkpoint: typing.TypeAlias = typing.MutableMapping[str, typing.Any] -"""Definition of a lightning checkpoint.""" \ No newline at end of file +"""Definition of a lightning checkpoint.""" diff --git a/src/mednet/libs/segmentation/config/data/drive/datamodule.py b/src/mednet/libs/segmentation/config/data/drive/datamodule.py index ecab53d8e9f3bbf598fb919ff97bd55c248cd6a4..b61941a59797ed80baf01f580cd6f92e50da5829 100644 --- a/src/mednet/libs/segmentation/config/data/drive/datamodule.py +++ b/src/mednet/libs/segmentation/config/data/drive/datamodule.py @@ -12,6 +12,7 @@ from mednet.libs.common.data.datamodule import CachingDataModule from mednet.libs.common.data.split import JSONDatabaseSplit from mednet.libs.common.data.typing import DatabaseSplit, Sample from mednet.libs.common.data.typing import RawDataLoader as _BaseRawDataLoader +from torchvision import tv_tensors from torchvision.transforms.functional import to_tensor from ....utils.rc import load_rc @@ -33,7 +34,7 @@ class RawDataLoader(_BaseRawDataLoader): CONFIGURATION_KEY_DATADIR, os.path.realpath(os.curdir) ) - def sample(self, 
sample: tuple[str, int]) -> Sample: + def sample(self, sample: tuple[str, str, str]) -> Sample: """Load a single image sample from the disk. Parameters @@ -51,33 +52,23 @@ class RawDataLoader(_BaseRawDataLoader): image = PIL.Image.open(Path(self.datadir) / str(sample[0])).convert( mode="RGB" ) - tensor = to_tensor(image) - label = PIL.Image.open(Path(self.datadir) / str(sample[1])).convert( - mode="1", dither=None + tensor = tv_tensors.Image(to_tensor(image)) + label = tv_tensors.Image( + to_tensor( + PIL.Image.open(Path(self.datadir) / str(sample[1])).convert( + mode="1", dither=None + ) + ) ) - mask = PIL.Image.open(Path(self.datadir) / str(sample[2])).convert( - mode="1", dither=None + mask = tv_tensors.Mask( + to_tensor( + PIL.Image.open(Path(self.datadir) / str(sample[2])).convert( + mode="1", dither=None + ) + ) ) - return tensor, dict( - label=to_tensor(label), mask=to_tensor(mask), name=sample[0] - ) # type: ignore[arg-type] - def label(self, sample: tuple[str, int]) -> int: - """Load a single image sample label from the disk. - - Parameters - ---------- - sample - A tuple containing the path suffix, within the dataset root folder, - where to find the image to be loaded, and an integer, representing the - sample label. - - Returns - ------- - int - The integer label associated with the sample. - """ - return sample[1] + return tensor, dict(label=label, mask=mask, name=sample[0]) # type: ignore[arg-type] def make_split(basename: str) -> DatabaseSplit: