Commit 1ab24d4b authored by André Anjos

[data.datamodule] Fixes cached dataset model transform application; Improves logging and documentation; Estimates loaded sample sizes and RAM occupancy
parent ba8c1eb0
1 merge request: !6 Making use of LightningDataModule and simplification of data loading
@@ -3,6 +3,7 @@
# SPDX-License-Identifier: GPL-3.0-or-later
import collections
import functools
import logging
import multiprocessing
import sys
@@ -27,6 +28,39 @@ from .typing import (
logger = logging.getLogger(__name__)
def _sample_size_bytes(s: Sample) -> int:
    """Recurses into the sample and figures out its total occupancy in bytes.

    Parameters
    ----------
    s
        The sample to be analyzed.

    Returns
    -------
    size
        The size in bytes occupied by this sample.
    """

    def _tensor_size_bytes(t: torch.Tensor) -> int:
        """Returns a tensor size in bytes."""
        return int(t.element_size() * torch.prod(torch.tensor(t.shape)))

    size = _tensor_size_bytes(s[0])
    size += sys.getsizeof(s[1])

    # check each metadata value - if it is a tensor, add its size in bytes
    for v in s[1].values():
        if isinstance(v, torch.Tensor):
            size += _tensor_size_bytes(v)

    return size
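# Illustrative sketch (not part of the commit): assuming the Sample layout
# used above - a (tensor, metadata-dict) pair - the helper can be exercised
# like this:
#
#   s = (torch.zeros(3, 224, 224, dtype=torch.float32), {"label": 1})
#   _sample_size_bytes(s)  # 3*224*224*4 = 602112 bytes for the tensor,
#                          # plus the dict overhead and any tensor values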
class _DelayedLoadingDataset(Dataset):
    """A dataset that loads its samples on demand.
@@ -59,6 +93,15 @@ class _DelayedLoadingDataset(Dataset):
        self.loader = loader
        self.transform = torchvision.transforms.Compose(transforms)

        # Test loading and log the output tensor size
        first_sample = self[0]
        logger.info(
            f"Delayed loading dataset (first tensor): "
            f"{list(first_sample[0].shape)}@{first_sample[0].dtype}"
        )
        sample_size_mb = _sample_size_bytes(first_sample) / (1024.0 * 1024.0)
        logger.info(f"Estimated sample size: {sample_size_mb:.1f} MB")
    def labels(self) -> list[int]:
        """Returns the integer labels for all samples in the dataset."""
        return [self.loader.label(k) for k in self.split]
@@ -75,6 +118,39 @@ class _DelayedLoadingDataset(Dataset):
            yield self[x]
def _apply_loader_and_transforms(
    info: typing.Any,
    load: typing.Callable[[typing.Any], Sample],
    model_transform: typing.Callable[[torch.Tensor], torch.Tensor],
) -> Sample:
    """Local wrapper that applies raw-data loading and transformation in a
    single step.

    Parameters
    ----------
    info
        The sample information, as loaded from its split dictionary.
    load
        The raw-data loader function to use for loading the sample.
    model_transform
        A callable that transforms the loaded tensor into something suitable
        for the model being trained. Typically, this will be a composed
        transform.

    Returns
    -------
    sample
        The loaded and transformed sample.
    """
    sample = load(info)
    return model_transform(sample[0]), sample[1]
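# Illustrative sketch (not part of the commit): binding the loader and the
# model transform with functools.partial - as _CachedDataset does below -
# yields a single picklable callable (assuming the underlying loader itself
# pickles), which is what multiprocessing.Pool.imap needs to fan sample
# loading out to worker processes:
#
#   fn = functools.partial(
#       _apply_loader_and_transforms,
#       load=loader.sample,
#       model_transform=torchvision.transforms.Compose(transforms),
#   )
#   tensor, metadata = fn(split[0])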
class _CachedDataset(Dataset):
    """A dataset whose samples are all preloaded into memory.
@@ -112,27 +188,41 @@ class _CachedDataset(Dataset):
        parallel: int = -1,
        transforms: typing.Sequence[Transform] = [],
    ):
-       self.transform = torchvision.transforms.Compose(transforms)
+       self.loader = functools.partial(
+           _apply_loader_and_transforms,
+           load=loader.sample,
+           model_transform=torchvision.transforms.Compose(transforms),
+       )
        if parallel < 0:
            self.data = [
-               loader.sample(k) for k in tqdm.tqdm(split, unit="sample")
+               self.loader(k) for k in tqdm.tqdm(split, unit="sample")
            ]
        else:
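            # parallel == 0 means one worker process per CPU core; any
            # positive value sets the number of worker processes explicitly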
            instances = parallel or multiprocessing.cpu_count()
            logger.info(f"Caching dataset using {instances} processes...")
            with multiprocessing.Pool(instances) as p:
                self.data = list(
-                   tqdm.tqdm(p.imap(loader.sample, split), total=len(split))
+                   tqdm.tqdm(p.imap(self.loader, split), total=len(split))
                )
        # Estimate memory occupancy
        logger.info(
            f"Cached dataset (first tensor): "
            f"{list(self.data[0][0].shape)}@{self.data[0][0].dtype}"
        )
        sample_size_mb = _sample_size_bytes(self.data[0]) / (1024.0 * 1024.0)
        logger.info(
            f"Estimated RAM occupancy (sample / dataset): "
            f"{sample_size_mb:.1f} / {(len(self.data)*sample_size_mb):.1f} MB"
        )
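        # Illustrative arithmetic (not part of the commit): a 3x224x224
        # float32 tensor occupies 3*224*224*4 bytes, or about 0.6 MB, so
        # caching a hypothetical 10000-sample dataset of such tensors would
        # need roughly 5.6 GB of RAM.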
    def labels(self) -> list[int]:
        """Returns the integer labels for all samples in the dataset."""
        return [k[1]["label"] for k in self.data]
    def __getitem__(self, key: int) -> Sample:
-       tensor, metadata = self.data[key]
-       return self.transform(tensor), metadata
+       return self.data[key]

    def __len__(self):
        return len(self.data)
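# Hypothetical usage sketch (not part of the commit; the `split` and `loader`
# argument names and order are assumed from the __init__ body above):
#
#   dataset = _CachedDataset(split, loader, parallel=4, transforms=[...])
#   tensor, metadata = dataset[0]  # already loaded and transformed
#   print(len(dataset), dataset.labels()[:10])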