diff --git a/src/ptbench/data/datamodule.py b/src/ptbench/data/datamodule.py index c8c33dec3e0f417c838da720cc630bc04a59bd96..4c92476360b9895016560976706763f1098c74d3 100644 --- a/src/ptbench/data/datamodule.py +++ b/src/ptbench/data/datamodule.py @@ -3,6 +3,7 @@ # SPDX-License-Identifier: GPL-3.0-or-later import collections +import functools import logging import multiprocessing import sys @@ -27,6 +28,39 @@ from .typing import ( logger = logging.getLogger(__name__) +def _sample_size_bytes(s: Sample) -> int: + """Recurse into the sample and figures out its total occupance in bytes. + + Parameters + ---------- + + s + The sample to be analyzed + + + Returns + ------- + + size + The size in bytes occupied by this sample + """ + + def _tensor_size_bytes(t: torch.Tensor) -> int: + """Returns a tensor size in bytes.""" + return int(t.element_size() * torch.prod(torch.tensor(t.shape))) + + size = int(s[0].element_size() * torch.prod(torch.tensor(s[0].shape))) + size += sys.getsizeof(s[1]) + + # check each element - if it is a tensor, then adds its total space in + # bytes + for v in s[1].values(): + if isinstance(v, torch.Tensor): + size += _tensor_size_bytes(v) + + return size + + class _DelayedLoadingDataset(Dataset): """A list that loads its samples on demand. @@ -59,6 +93,15 @@ class _DelayedLoadingDataset(Dataset): self.loader = loader self.transform = torchvision.transforms.Compose(transforms) + # Tests loading and output tensor size + first_sample = self[0] + logger.info( + f"Delayed loading dataset (first tensor): " + f"{list(first_sample[0].shape)}@{first_sample[0].dtype}" + ) + sample_size_mb = _sample_size_bytes(first_sample) / (1024.0 * 1024.0) + logger.info(f"Estimated sample size: {sample_size_mb:.1f} Mb") + def labels(self) -> list[int]: """Returns the integer labels for all samples in the dataset.""" return [self.loader.label(k) for k in self.split] @@ -75,6 +118,39 @@ class _DelayedLoadingDataset(Dataset): yield self[x] +def _apply_loader_and_transforms( + info: typing.Any, + load: typing.Callable[[typing.Any], Sample], + model_transform: typing.Callable[[torch.Tensor], torch.Tensor], +) -> Sample: + """Local wrapper to apply raw-data loading and transformation in a single + step. + + Parameters + ---------- + + info + The sample information, as loaded from its split dictionary + + load + The raw-data loader function to use for loading the sample + + model_transform + A callable that will transform the loaded tensor into something + suitable for the model it will train. Typically, this will be a + composed transform. + + + Returns + ------- + + sample + The loaded and transformed sample. + """ + sample = load(info) + return model_transform(sample[0]), sample[1] + + class _CachedDataset(Dataset): """Basically, a list of preloaded samples. @@ -112,27 +188,41 @@ class _CachedDataset(Dataset): parallel: int = -1, transforms: typing.Sequence[Transform] = [], ): - self.transform = torchvision.transforms.Compose(transforms) + self.loader = functools.partial( + _apply_loader_and_transforms, + load=loader.sample, + model_transform=torchvision.transforms.Compose(transforms), + ) if parallel < 0: self.data = [ - loader.sample(k) for k in tqdm.tqdm(split, unit="sample") + self.loader(k) for k in tqdm.tqdm(split, unit="sample") ] else: instances = parallel or multiprocessing.cpu_count() logger.info(f"Caching dataset using {instances} processes...") with multiprocessing.Pool(instances) as p: self.data = list( - tqdm.tqdm(p.imap(loader.sample, split), total=len(split)) + tqdm.tqdm(p.imap(self.loader, split), total=len(split)) ) + # Estimates memory occupance + logger.info( + f"Cached dataset (first tensor): " + f"{list(self.data[0][0].shape)}@{self.data[0][0].dtype}" + ) + sample_size_mb = _sample_size_bytes(self.data[0]) / (1024.0 * 1024.0) + logger.info( + f"Estimated RAM occupance (sample / dataset): " + f"{sample_size_mb:.1f} / {(len(self.data)*sample_size_mb):.1f} Mb" + ) + def labels(self) -> list[int]: """Returns the integer labels for all samples in the dataset.""" return [k[1]["label"] for k in self.data] def __getitem__(self, key: int) -> Sample: - tensor, metadata = self.data[key] - return self.transform(tensor), metadata + return self.data[key] def __len__(self): return len(self.data)