diff --git a/src/ptbench/data/datamodule.py b/src/ptbench/data/datamodule.py index 4c92476360b9895016560976706763f1098c74d3..bb2dcdda8ad100d91f6bbd35abd8bcb07e41a334 100644 --- a/src/ptbench/data/datamodule.py +++ b/src/ptbench/data/datamodule.py @@ -49,7 +49,8 @@ def _sample_size_bytes(s: Sample) -> int: """Returns a tensor size in bytes.""" return int(t.element_size() * torch.prod(torch.tensor(t.shape))) - size = int(s[0].element_size() * torch.prod(torch.tensor(s[0].shape))) + size = sys.getsizeof(s[0]) # tensor metadata + size += int(s[0].element_size() * torch.prod(torch.tensor(s[0].shape))) size += sys.getsizeof(s[1]) # check each element - if it is a tensor, then adds its total space in