Commit 0218b171 authored by Daniel CARRON, committed by André Anjos

[datamodule] Apply model transforms on TVTensors

parent 59aefc94
1 merge request: !46 Create common library
@@ -15,6 +15,7 @@ import torch.backends
 import torch.utils.data
 import torchvision.transforms
 import tqdm
+from torchvision import tv_tensors

 from .typing import (
     ConcatDatabaseSplit,
@@ -73,6 +74,32 @@ def _sample_size_bytes(s: Sample) -> int:
     return size

+def transform_tvtensors(data_dict, transforms):
+    """Apply transforms on dictionary elements which are of TVTensor type.
+
+    Parameters
+    ----------
+    data_dict
+        Dictionary possibly containing TVTensor elements to apply transforms to.
+    transforms
+        The transforms to apply.
+
+    Returns
+    -------
+        A dictionary with transforms applied on TVTensor elements.
+    """
+    transformed_dict = {}
+    for k, m in data_dict.items():
+        if isinstance(m, tv_tensors.TVTensor):
+            transformed_dict[k] = transforms(m)
+        else:
+            transformed_dict[k] = m
+    return transformed_dict
+
+
 class _DelayedLoadingDataset(Dataset):
     """A list that loads its samples on demand.
@@ -123,7 +150,9 @@ class _DelayedLoadingDataset(Dataset):
     def __getitem__(self, key: int) -> Sample:
         tensor, metadata = self.loader.sample(self.raw_dataset[key])
-        return self.transform(tensor), metadata
+        return self.transform(tensor), transform_tvtensors(
+            metadata, self.transform
+        )

     def __len__(self):
         return len(self.raw_dataset)
@@ -159,7 +188,9 @@ def _apply_loader_and_transforms(
     """

     sample = load(info)
-    return model_transform(sample[0]), sample[1]
+    return model_transform(sample[0]), transform_tvtensors(
+        sample[1], model_transform
+    )


 class _CachedDataset(Dataset):
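For context (not part of this commit): the new helper passes every TVTensor value of the metadata dictionary through the model transforms and returns plain entries, such as strings, unchanged. A minimal, self-contained sketch of that behaviour, assuming torchvision >= 0.16 (where torchvision.tv_tensors and the transforms.v2 API are available); the sample dictionary and sizes below are made up for demonstration:

# Illustrative sketch only -- mirrors the transform_tvtensors() helper above.
import torch
from torchvision import tv_tensors
from torchvision.transforms import v2


def transform_tvtensors(data_dict, transforms):
    """Apply ``transforms`` to TVTensor values; pass everything else through."""
    return {
        k: transforms(m) if isinstance(m, tv_tensors.TVTensor) else m
        for k, m in data_dict.items()
    }


metadata = {
    "mask": tv_tensors.Mask(torch.ones(1, 64, 64, dtype=torch.uint8)),
    "name": "sample-0001",  # plain metadata: returned unchanged
}
resize = v2.Resize(size=(32, 32))

out = transform_tvtensors(metadata, resize)
print(type(out["mask"]).__name__, tuple(out["mask"].shape))  # Mask (1, 32, 32)
print(out["name"])                                           # sample-0001

Wrapping metadata this way is what lets _DelayedLoadingDataset.__getitem__() and _apply_loader_and_transforms() reuse the same model transform for both the input tensor and the ground-truth entries without special-casing dictionary keys.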
@@ -6,4 +6,4 @@
 import typing

 Checkpoint: typing.TypeAlias = typing.MutableMapping[str, typing.Any]
-"""Definition of a lightning checkpoint."""
\ No newline at end of file
+"""Definition of a lightning checkpoint."""
@@ -12,6 +12,7 @@ from mednet.libs.common.data.datamodule import CachingDataModule
 from mednet.libs.common.data.split import JSONDatabaseSplit
 from mednet.libs.common.data.typing import DatabaseSplit, Sample
 from mednet.libs.common.data.typing import RawDataLoader as _BaseRawDataLoader
+from torchvision import tv_tensors
 from torchvision.transforms.functional import to_tensor

 from ....utils.rc import load_rc
@@ -33,7 +34,7 @@ class RawDataLoader(_BaseRawDataLoader):
             CONFIGURATION_KEY_DATADIR, os.path.realpath(os.curdir)
         )

-    def sample(self, sample: tuple[str, int]) -> Sample:
+    def sample(self, sample: tuple[str, str, str]) -> Sample:
         """Load a single image sample from the disk.

         Parameters
@@ -51,33 +52,23 @@ class RawDataLoader(_BaseRawDataLoader):
         image = PIL.Image.open(Path(self.datadir) / str(sample[0])).convert(
             mode="RGB"
         )
-        tensor = to_tensor(image)
-        label = PIL.Image.open(Path(self.datadir) / str(sample[1])).convert(
-            mode="1", dither=None
-        )
-        mask = PIL.Image.open(Path(self.datadir) / str(sample[2])).convert(
-            mode="1", dither=None
-        )
-        return tensor, dict(
-            label=to_tensor(label), mask=to_tensor(mask), name=sample[0]
-        )  # type: ignore[arg-type]
-
-    def label(self, sample: tuple[str, int]) -> int:
-        """Load a single image sample label from the disk.
-
-        Parameters
-        ----------
-        sample
-            A tuple containing the path suffix, within the dataset root folder,
-            where to find the image to be loaded, and an integer, representing the
-            sample label.
-
-        Returns
-        -------
-        int
-            The integer label associated with the sample.
-        """
-        return sample[1]
+        tensor = tv_tensors.Image(to_tensor(image))
+        label = tv_tensors.Image(
+            to_tensor(
+                PIL.Image.open(Path(self.datadir) / str(sample[1])).convert(
+                    mode="1", dither=None
+                )
+            )
+        )
+        mask = tv_tensors.Mask(
+            to_tensor(
+                PIL.Image.open(Path(self.datadir) / str(sample[2])).convert(
+                    mode="1", dither=None
+                )
+            )
+        )
+
+        return tensor, dict(label=label, mask=mask, name=sample[0])  # type: ignore[arg-type]


 def make_split(basename: str) -> DatabaseSplit:
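For context (not part of this commit): wrapping the loaded arrays in tv_tensors.Image and tv_tensors.Mask matters because torchvision's v2 transforms dispatch on the TVTensor subclass, so one pipeline can be applied to image, label and mask while handling each appropriately; for example, v2.Resize interpolates an Image bilinearly but resizes a Mask with nearest-neighbour so it stays binary. A small sketch of that behaviour, again assuming torchvision >= 0.16; the tensors below are synthetic stand-ins for what RawDataLoader.sample() returns after this change:

# Illustrative sketch only -- synthetic image and mask, not real dataset files.
import torch
from torchvision import tv_tensors
from torchvision.transforms import v2

image = tv_tensors.Image(torch.rand(3, 64, 64))                # RGB input
mask = tv_tensors.Mask((torch.rand(1, 64, 64) > 0.5).float())  # binary mask

resize = v2.Resize(size=(32, 32), antialias=True)

resized_image = resize(image)  # bilinear, antialiased
resized_mask = resize(mask)    # nearest-neighbour: values stay in {0, 1}

print(type(resized_image).__name__, tuple(resized_image.shape))  # Image (3, 32, 32)
print(type(resized_mask).__name__, tuple(resized_mask.shape))    # Mask (1, 32, 32)
print(resized_mask.unique())                                     # only 0. and 1. survive

Had the mask stayed a plain tensor, a transform pipeline applied uniformly to all entries could silently interpolate it and break its binary nature, which is exactly what the TVTensor wrapping avoids.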