Skip to content
Snippets Groups Projects
Commit a16730af authored by Daniel CARRON's avatar Daniel CARRON :b: Committed by André Anjos
Browse files

[datamodule] More generic sample size computation

parent 6f2b1a75
No related branches found
No related tags found
1 merge request!46Create common library
......@@ -30,62 +30,13 @@ from .typing import (
logger = logging.getLogger(__name__)
def _tensor_size_bytes(t: torch.Tensor) -> int:
"""Return a tensor size in bytes.
def _sample_size_bytes(dataset: Sample):
"""Recurse into the first sample of a dataset and figure out its total occupancy in bytes.
Parameters
----------
t
A torch Tensor.
Returns
-------
int
The size of the Tensor in bytes.
"""
return int(t.element_size() * torch.prod(torch.tensor(t.shape)))
def _sample_size_bytes(s: Sample) -> int:
"""Recurse into the sample and figures out its total occupance in bytes.
Parameters
----------
s
The sample to be analyzed.
Returns
-------
int
The size in bytes occupied by this sample.
"""
size = sys.getsizeof(s[0]) # tensor metadata
size += int(s[0].element_size() * torch.prod(torch.tensor(s[0].shape)))
size += sys.getsizeof(s[1])
# check each element - if it is a tensor, then adds its total space in
# bytes
for v in s[1].values():
if isinstance(v, torch.Tensor):
size += _tensor_size_bytes(v)
return size
def _sample_dict_size_bytes(s: Sample) -> int:
"""Recurse into the sample and figures out its total occupance in bytes.
Parameters
----------
s
The sample to be analyzed.
Returns
-------
int
The size in bytes occupied by this sample.
dataset
The dataset containing the samples to load.
"""
def _tensor_size_bytes(t: torch.Tensor) -> int:
......@@ -102,42 +53,43 @@ def _sample_dict_size_bytes(s: Sample) -> int:
The size of the Tensor in bytes.
"""
logger.info(f"{list(t.shape)}@{t.dtype}")
return int(t.element_size() * torch.prod(torch.tensor(t.shape)))
size = sys.getsizeof(s[0]) # tensor metadata
size += sys.getsizeof(s[1])
def _dict_size_bytes(d):
"""Return a dictionary size in bytes.
# check each element - if it is a tensor, then adds its total space in
# bytes
for s_ in s:
for v in s_.values():
Parameters
----------
d
A dictionary.
Returns
-------
int
The size of the dictionary in bytes.
"""
size = 0
for v in d.values():
if isinstance(v, torch.Tensor):
size += _tensor_size_bytes(v)
return size
return size
size = 0
def _estimate_data_footprint(dataset):
"""Compute the estimated memory required to load samples in memory.
Parameters
----------
dataset
The dataset containing the samples to load.
"""
first_sample = dataset[0]
logger.info("Delayed loading dataset (first tensor):")
if isinstance(first_sample[0], dict):
for k, v in first_sample[0].items():
logger.info(f"{k}: {list(v.shape)}@{v.dtype}")
sample_size_mb = _sample_dict_size_bytes(first_sample) / (1024.0 * 1024.0)
logger.info(f"Estimated sample size: {sample_size_mb:.1f} Mb")
for s in first_sample:
size += sys.getsizeof(s)
if isinstance(s, dict):
size += _dict_size_bytes(s)
else:
size += _tensor_size_bytes(s)
else:
logger.info(f"{list(first_sample[0].shape)}@{first_sample[0].dtype}")
sample_size_mb = _sample_size_bytes(first_sample) / (1024.0 * 1024.0)
logger.info(f"Estimated sample size: {sample_size_mb:.1f} Mb")
sample_size_mb = size / (1024.0 * 1024.0)
logger.info(f"Estimated sample size: {sample_size_mb:.1f} Mb")
def transform_tensors(data, transforms):
......@@ -197,7 +149,7 @@ class _DelayedLoadingDataset(Dataset):
self.loader = loader
self.transform = torchvision.transforms.Compose(transforms)
_estimate_data_footprint(self)
_sample_size_bytes(self)
def __getitem__(self, key: int) -> Sample:
tensor, metadata = self.loader.sample(self.raw_dataset[key])
......@@ -291,7 +243,7 @@ class _CachedDataset(Dataset):
),
)
_estimate_data_footprint(self)
_sample_size_bytes(self)
def targets(self) -> list[int | list[int]]:
"""Return the integer targets for all samples in the dataset.
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment