diff --git a/src/mednet/libs/common/data/datamodule.py b/src/mednet/libs/common/data/datamodule.py
index f0ed67b54431a9be5675b8aa3493752eb1dd1a6c..c2e9bea88026b513fa9ef741cf6313f71622fba1 100644
--- a/src/mednet/libs/common/data/datamodule.py
+++ b/src/mednet/libs/common/data/datamodule.py
@@ -15,7 +15,6 @@ import torch.backends
 import torch.utils.data
 import torchvision.transforms
 import tqdm
-from torchvision import tv_tensors
 
 from .typing import (
     ConcatDatabaseSplit,
@@ -31,6 +30,23 @@ from .typing import (
 logger = logging.getLogger(__name__)
 
 
+def _tensor_size_bytes(t: torch.Tensor) -> int:
+    """Return a tensor size in bytes.
+
+    Parameters
+    ----------
+    t
+        A torch Tensor.
+
+    Returns
+    -------
+    int
+        The size of the Tensor in bytes.
+    """
+
+    return int(t.element_size() * torch.prod(torch.tensor(t.shape)))
+
+
 def _sample_size_bytes(s: Sample) -> int:
     """Recurse into the sample and figures out its total occupance in bytes.
 
@@ -45,6 +61,33 @@ def _sample_size_bytes(s: Sample) -> int:
         The size in bytes occupied by this sample.
     """
 
+    size = sys.getsizeof(s[0])  # tensor metadata
+    size += int(s[0].element_size() * torch.prod(torch.tensor(s[0].shape)))
+    size += sys.getsizeof(s[1])
+
+    # check each element - if it is a tensor, then add its total space in
+    # bytes
+    for v in s[1].values():
+        if isinstance(v, torch.Tensor):
+            size += _tensor_size_bytes(v)
+
+    return size
+
+
+def _sample_dict_size_bytes(s: Sample) -> int:
+    """Recurse into the sample and figure out its total occupancy in bytes.
+
+    Parameters
+    ----------
+    s
+        The sample to be analyzed.
+
+    Returns
+    -------
+    int
+        The size in bytes occupied by this sample.
+    """
+
     def _tensor_size_bytes(t: torch.Tensor) -> int:
         """Return a tensor size in bytes.
 
@@ -62,42 +105,68 @@ def _sample_size_bytes(s: Sample) -> int:
         return int(t.element_size() * torch.prod(torch.tensor(t.shape)))
 
     size = sys.getsizeof(s[0])  # tensor metadata
-    size += int(s[0].element_size() * torch.prod(torch.tensor(s[0].shape)))
     size += sys.getsizeof(s[1])
 
     # check each element - if it is a tensor, then adds its total space in
     # bytes
-    for v in s[1].values():
-        if isinstance(v, torch.Tensor):
-            size += _tensor_size_bytes(v)
+    for s_ in s:
+        for v in s_.values():
+            if isinstance(v, torch.Tensor):
+                size += _tensor_size_bytes(v)
 
     return size
 
 
-def transform_tvtensors(data_dict, transforms):
-    """Apply transforms on dictionary elements which are of TVTensor type.
+def _estimate_data_footprint(dataset):
+    """Compute and log the estimated memory required to load samples in memory.
 
     Parameters
     ----------
-    data_dict
-        Dictionary possibly containing TVTensor elements to apply transforms to.
+    dataset
+        The dataset containing the samples to load.
+    """
+    first_sample = dataset[0]
+
+    logger.info("Delayed loading dataset (first tensor):")
+    if isinstance(first_sample[0], dict):
+        for k, v in first_sample[0].items():
+            logger.info(f"{k}: {list(v.shape)}@{v.dtype}")
+        sample_size_mb = _sample_dict_size_bytes(first_sample) / (1024.0 * 1024.0)
+        logger.info(f"Estimated sample size: {sample_size_mb:.1f} MB")
+
+    else:
+        logger.info(f"{list(first_sample[0].shape)}@{first_sample[0].dtype}")
+        sample_size_mb = _sample_size_bytes(first_sample) / (1024.0 * 1024.0)
+        logger.info(f"Estimated sample size: {sample_size_mb:.1f} MB")
+
+
+def transform_tensors(data, transforms):
+    """Apply transforms to a tensor or to dictionary elements of Tensor type.
+
+    Parameters
+    ----------
+    data
+        A tensor or dictionary with Tensor elements to apply transforms to.
 
     transforms
         The transforms to apply.
 
     Returns
     -------
-        A dictionary with transforms applied on TVTensor elements.
+        The transformed tensor, or a dictionary with transforms applied to its Tensor elements.
     """
-    transformed_dict = {}
+    if isinstance(data, dict):
+        transformed_dict = {}
 
-    for k, m in data_dict.items():
-        if isinstance(m, tv_tensors.TVTensor):
-            transformed_dict[k] = transforms(m)
-        else:
-            transformed_dict[k] = m
+        for k, m in data.items():
+            if isinstance(m, torch.Tensor):
+                transformed_dict[k] = transforms(m)
+            else:
+                transformed_dict[k] = m
+
+        return transformed_dict
 
-    return transformed_dict
+    return transforms(data)
 
 
 class _DelayedLoadingDataset(Dataset):
@@ -128,18 +197,11 @@ class _DelayedLoadingDataset(Dataset):
         self.loader = loader
         self.transform = torchvision.transforms.Compose(transforms)
 
-        # Tests loading and output tensor size
-        first_sample = self[0]
-        logger.info(
-            f"Delayed loading dataset (first tensor): "
-            f"{list(first_sample[0].shape)}@{first_sample[0].dtype}",
-        )
-        sample_size_mb = _sample_size_bytes(first_sample) / (1024.0 * 1024.0)
-        logger.info(f"Estimated sample size: {sample_size_mb:.1f} Mb")
+        _estimate_data_footprint(self)
 
     def __getitem__(self, key: int) -> Sample:
         tensor, metadata = self.loader.sample(self.raw_dataset[key])
-        return self.transform(tensor), transform_tvtensors(metadata, self.transform)
+        return transform_tensors(tensor, self.transform), metadata
 
     def __len__(self):
         return len(self.raw_dataset)
@@ -175,7 +237,7 @@ def _apply_loader_and_transforms(
     """
 
     sample = load(info)
-    return model_transform(sample[0]), transform_tvtensors(sample[1], model_transform)
+    return transform_tensors(sample[0], model_transform), sample[1]
 
 
 class _CachedDataset(Dataset):
@@ -229,16 +291,7 @@ class _CachedDataset(Dataset):
             ),
         )
 
-        # Estimates memory occupance
-        logger.info(
-            f"Cached dataset (first tensor): "
-            f"{list(self.data[0][0].shape)}@{self.data[0][0].dtype}",
-        )
-        sample_size_mb = _sample_size_bytes(self.data[0]) / (1024.0 * 1024.0)
-        logger.info(
-            f"Estimated RAM occupance (sample / dataset): "
-            f"{sample_size_mb:.1f} / {(len(self.data)*sample_size_mb):.1f} Mb",
-        )
+        _estimate_data_footprint(self)
 
     def targets(self) -> list[int | list[int]]:
         """Return the integer targets for all samples in the dataset.
diff --git a/src/mednet/libs/common/engine/callbacks.py b/src/mednet/libs/common/engine/callbacks.py
index 3d108a8d2c372478d7792de8d65a4b21e2ce3428..b030c8820f6a7ec748733ab9ba692f35f48f7d8f 100644
--- a/src/mednet/libs/common/engine/callbacks.py
+++ b/src/mednet/libs/common/engine/callbacks.py
@@ -195,13 +195,18 @@ class LoggingCallback(lightning.pytorch.Callback):
             The relative number of the batch.
""" + if isinstance(batch[0], dict): + batch_size = batch[0]["image"].shape[0] + else: + batch_size = batch[0].shape[0] + pl_module.log( "loss/train", outputs["loss"].item(), prog_bar=True, on_step=False, on_epoch=True, - batch_size=batch[0].shape[0], + batch_size=batch_size, ) def on_validation_epoch_start( @@ -367,12 +372,17 @@ class LoggingCallback(lightning.pytorch.Callback): else: key = f"loss/validation-{dataloader_idx}" + if isinstance(batch[0], dict): + batch_size = batch[0]["image"].shape[0] + else: + batch_size = batch[0].shape[0] + pl_module.log( key, outputs.item(), prog_bar=False, on_step=False, on_epoch=True, - batch_size=batch[0].shape[0], + batch_size=batch_size, add_dataloader_idx=False, ) diff --git a/src/mednet/libs/segmentation/models/normalizer.py b/src/mednet/libs/segmentation/models/normalizer.py index df630bd1d61c9b7e10a951a80e8f560248f9a477..5b3b7561c292f7520d4e37cc5fe5e2e4f3be9836 100644 --- a/src/mednet/libs/segmentation/models/normalizer.py +++ b/src/mednet/libs/segmentation/models/normalizer.py @@ -44,7 +44,7 @@ def make_z_normalizer( # Evaluates mean and standard deviation for batch in tqdm.tqdm(dataloader, unit="batch"): - data = batch[0] + data = batch[0]["image"] data = data.view(data.size(0), data.size(1), -1) num_images += data.size(0)