Commit 0218b171 authored by Daniel CARRON, committed by André Anjos
[datamodule] Apply model transforms on TVTensors

parent 59aefc94
Merge request !46: Create common library
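Note on the rationale (my reading of the change, not part of the original commit message): torchvision's v2 transforms dispatch on TVTensor subclasses, so once labels and masks are wrapped as tv_tensors.Image / tv_tensors.Mask, the same model transform that resizes the input image can also be applied to them with type-appropriate behaviour (for example, nearest-neighbour interpolation for masks). A minimal sketch of that dispatch, assuming torchvision >= 0.16 and the transforms.v2 API:

import torch
from torchvision import tv_tensors
from torchvision.transforms import v2

image = tv_tensors.Image(torch.rand(3, 512, 512))
mask = tv_tensors.Mask(torch.zeros(1, 512, 512, dtype=torch.uint8))

resize = v2.Resize(size=(256, 256), antialias=True)

# The TVTensor subclass is preserved and drives how the transform behaves
# (bilinear + antialias for Image, nearest-neighbour for Mask).
print(type(resize(image)), resize(image).shape)  # Image, torch.Size([3, 256, 256])
print(type(resize(mask)), resize(mask).shape)    # Mask, torch.Size([1, 256, 256])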
@@ -15,6 +15,7 @@ import torch.backends
 import torch.utils.data
 import torchvision.transforms
 import tqdm
+from torchvision import tv_tensors
 
 from .typing import (
     ConcatDatabaseSplit,
@@ -73,6 +74,32 @@ def _sample_size_bytes(s: Sample) -> int:
     return size
 
 
+def transform_tvtensors(data_dict, transforms):
+    """Apply transforms on dictionary elements which are of TVTensor type.
+
+    Parameters
+    ----------
+    data_dict
+        Dictionary possibly containing TVTensor elements to apply transforms to.
+    transforms
+        The transforms to apply.
+
+    Returns
+    -------
+    A dictionary with transforms applied on TVTensor elements.
+    """
+    transformed_dict = {}
+    for k, m in data_dict.items():
+        if isinstance(m, tv_tensors.TVTensor):
+            transformed_dict[k] = transforms(m)
+        else:
+            transformed_dict[k] = m
+    return transformed_dict
+
+
 class _DelayedLoadingDataset(Dataset):
     """A list that loads its samples on demand.
@@ -123,7 +150,9 @@ class _DelayedLoadingDataset(Dataset):
     def __getitem__(self, key: int) -> Sample:
         tensor, metadata = self.loader.sample(self.raw_dataset[key])
-        return self.transform(tensor), metadata
+        return self.transform(tensor), transform_tvtensors(
+            metadata, self.transform
+        )
 
     def __len__(self):
         return len(self.raw_dataset)
@@ -159,7 +188,9 @@ def _apply_loader_and_transforms(
     """
     sample = load(info)
-    return model_transform(sample[0]), sample[1]
+    return model_transform(sample[0]), transform_tvtensors(
+        sample[1], model_transform
+    )
 
 
 class _CachedDataset(Dataset):
@@ -6,4 +6,4 @@
 import typing
 
 Checkpoint: typing.TypeAlias = typing.MutableMapping[str, typing.Any]
-"""Definition of a lightning checkpoint."""
\ No newline at end of file
+"""Definition of a lightning checkpoint."""
@@ -12,6 +12,7 @@ from mednet.libs.common.data.datamodule import CachingDataModule
 from mednet.libs.common.data.split import JSONDatabaseSplit
 from mednet.libs.common.data.typing import DatabaseSplit, Sample
 from mednet.libs.common.data.typing import RawDataLoader as _BaseRawDataLoader
+from torchvision import tv_tensors
 from torchvision.transforms.functional import to_tensor
 
 from ....utils.rc import load_rc
@@ -33,7 +34,7 @@ class RawDataLoader(_BaseRawDataLoader):
             CONFIGURATION_KEY_DATADIR, os.path.realpath(os.curdir)
         )
 
-    def sample(self, sample: tuple[str, int]) -> Sample:
+    def sample(self, sample: tuple[str, str, str]) -> Sample:
        """Load a single image sample from the disk.
 
        Parameters
@@ -51,33 +52,23 @@ class RawDataLoader(_BaseRawDataLoader):
         image = PIL.Image.open(Path(self.datadir) / str(sample[0])).convert(
             mode="RGB"
         )
-        tensor = to_tensor(image)
-        label = PIL.Image.open(Path(self.datadir) / str(sample[1])).convert(
-            mode="1", dither=None
+        tensor = tv_tensors.Image(to_tensor(image))
+        label = tv_tensors.Image(
+            to_tensor(
+                PIL.Image.open(Path(self.datadir) / str(sample[1])).convert(
+                    mode="1", dither=None
+                )
+            )
         )
-        mask = PIL.Image.open(Path(self.datadir) / str(sample[2])).convert(
-            mode="1", dither=None
+        mask = tv_tensors.Mask(
+            to_tensor(
+                PIL.Image.open(Path(self.datadir) / str(sample[2])).convert(
+                    mode="1", dither=None
+                )
+            )
         )
-        return tensor, dict(
-            label=to_tensor(label), mask=to_tensor(mask), name=sample[0]
-        )  # type: ignore[arg-type]
-
-    def label(self, sample: tuple[str, int]) -> int:
-        """Load a single image sample label from the disk.
-
-        Parameters
-        ----------
-        sample
-            A tuple containing the path suffix, within the dataset root folder,
-            where to find the image to be loaded, and an integer, representing the
-            sample label.
-
-        Returns
-        -------
-        int
-            The integer label associated with the sample.
-        """
-        return sample[1]
+        return tensor, dict(label=label, mask=mask, name=sample[0])  # type: ignore[arg-type]
 
 
 def make_split(basename: str) -> DatabaseSplit:
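End-to-end, the reworked loader now takes three path suffixes per sample (image, label, mask) and returns TVTensor-wrapped data, so a model transform routed through transform_tvtensors keeps all three geometrically aligned. A hedged sketch of that flow (the paths and the transform are illustrative only; RawDataLoader and transform_tvtensors are assumed to be importable from the modules changed above):

from torchvision.transforms import v2

loader = RawDataLoader()
tensor, metadata = loader.sample(
    ("images/0001.png", "labels/0001.png", "masks/0001.png")  # hypothetical paths
)
# tensor            -> tv_tensors.Image, shape (3, H, W)
# metadata["label"] -> tv_tensors.Image, shape (1, H, W)
# metadata["mask"]  -> tv_tensors.Mask,  shape (1, H, W)
# metadata["name"]  -> "images/0001.png"

model_transform = v2.Resize(size=(256, 256), antialias=True)
transformed = model_transform(tensor), transform_tvtensors(metadata, model_transform)
# Image, label and mask are resized together; the "name" string is untouched.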