From a16730af8be428608749ee60b8913a2b826108b0 Mon Sep 17 00:00:00 2001
From: dcarron <daniel.carron@idiap.ch>
Date: Wed, 26 Jun 2024 13:33:26 +0200
Subject: [PATCH] [datamodule] More generic sample size computation

---
 src/mednet/libs/common/data/datamodule.py | 112 +++++++-----------------
 1 file changed, 32 insertions(+), 80 deletions(-)

diff --git a/src/mednet/libs/common/data/datamodule.py b/src/mednet/libs/common/data/datamodule.py
index c2e9bea8..c8ffcffe 100644
--- a/src/mednet/libs/common/data/datamodule.py
+++ b/src/mednet/libs/common/data/datamodule.py
@@ -30,62 +30,13 @@ from .typing import (
 logger = logging.getLogger(__name__)


-def _tensor_size_bytes(t: torch.Tensor) -> int:
-    """Return a tensor size in bytes.
+def _sample_size_bytes(dataset: Dataset) -> None:
+    """Recurse into the first sample of a dataset and figure out its total occupancy in bytes.

     Parameters
     ----------
-    t
-        A torch Tensor.
-
-    Returns
-    -------
-    int
-        The size of the Tensor in bytes.
-    """
-
-    return int(t.element_size() * torch.prod(torch.tensor(t.shape)))
-
-
-def _sample_size_bytes(s: Sample) -> int:
-    """Recurse into the sample and figures out its total occupance in bytes.
-
-    Parameters
-    ----------
-    s
-        The sample to be analyzed.
-
-    Returns
-    -------
-    int
-        The size in bytes occupied by this sample.
-    """
-
-    size = sys.getsizeof(s[0])  # tensor metadata
-    size += int(s[0].element_size() * torch.prod(torch.tensor(s[0].shape)))
-    size += sys.getsizeof(s[1])
-
-    # check each element - if it is a tensor, then adds its total space in
-    # bytes
-    for v in s[1].values():
-        if isinstance(v, torch.Tensor):
-            size += _tensor_size_bytes(v)
-
-    return size
-
-
-def _sample_dict_size_bytes(s: Sample) -> int:
-    """Recurse into the sample and figures out its total occupance in bytes.
-
-    Parameters
-    ----------
-    s
-        The sample to be analyzed.
-
-    Returns
-    -------
-    int
-        The size in bytes occupied by this sample.
+    dataset
+        The dataset containing the samples to load.
     """

     def _tensor_size_bytes(t: torch.Tensor) -> int:
@@ -102,42 +53,43 @@ def _sample_dict_size_bytes(s: Sample) -> int:
             The size of the Tensor in bytes.
         """

+        logger.info(f"{list(t.shape)}@{t.dtype}")
         return int(t.element_size() * torch.prod(torch.tensor(t.shape)))

-    size = sys.getsizeof(s[0])  # tensor metadata
-    size += sys.getsizeof(s[1])
+    def _dict_size_bytes(d: dict) -> int:
+        """Return a dictionary size in bytes.

-    # check each element - if it is a tensor, then adds its total space in
-    # bytes
-    for s_ in s:
-        for v in s_.values():
+        Parameters
+        ----------
+        d
+            A dictionary.
+
+        Returns
+        -------
+        int
+            The size of the dictionary in bytes.
+        """
+
+        size = 0
+        for v in d.values():
             if isinstance(v, torch.Tensor):
                 size += _tensor_size_bytes(v)

-    return size
+        return size

+    size = 0

-def _estimate_data_footprint(dataset):
-    """Compute the estimated memory required to load samples in memory.
-
-    Parameters
-    ----------
-    dataset
-        The dataset containing the samples to load.
-    """
     first_sample = dataset[0]

-    logger.info("Delayed loading dataset (first tensor):")
-    if isinstance(first_sample[0], dict):
-        for k, v in first_sample[0].items():
-            logger.info(f"{k}: {list(v.shape)}@{v.dtype}")
-        sample_size_mb = _sample_dict_size_bytes(first_sample) / (1024.0 * 1024.0)
-        logger.info(f"Estimated sample size: {sample_size_mb:.1f} Mb")
+    for s in first_sample:
+        size += sys.getsizeof(s)
+        if isinstance(s, dict):
+            size += _dict_size_bytes(s)
+        else:
+            size += _tensor_size_bytes(s)

-    else:
-        logger.info(f"{list(first_sample[0].shape)}@{first_sample[0].dtype}")
-        sample_size_mb = _sample_size_bytes(first_sample) / (1024.0 * 1024.0)
-        logger.info(f"Estimated sample size: {sample_size_mb:.1f} Mb")
+    sample_size_mb = size / (1024.0 * 1024.0)
+    logger.info(f"Estimated sample size: {sample_size_mb:.1f} MB")


 def transform_tensors(data, transforms):
@@ -197,7 +149,7 @@ class _DelayedLoadingDataset(Dataset):
         self.loader = loader
         self.transform = torchvision.transforms.Compose(transforms)

-        _estimate_data_footprint(self)
+        _sample_size_bytes(self)

     def __getitem__(self, key: int) -> Sample:
         tensor, metadata = self.loader.sample(self.raw_dataset[key])
@@ -291,7 +243,7 @@ class _CachedDataset(Dataset):
             ),
         )

-        _estimate_data_footprint(self)
+        _sample_size_bytes(self)

     def targets(self) -> list[int | list[int]]:
         """Return the integer targets for all samples in the dataset.
--
GitLab