From b61397ae1d7ca2c7f202e66f6ec256a3bfbdbfbf Mon Sep 17 00:00:00 2001 From: dcarron <daniel.carron@idiap.ch> Date: Wed, 14 Jun 2023 11:00:46 +0200 Subject: [PATCH] Setup datamodule only once per stage, only get needed subsets --- .../configs/datasets/shenzhen/default.py | 14 +-- src/ptbench/configs/datasets/shenzhen/rgb.py | 14 +-- src/ptbench/data/__init__.py | 85 ++++++++++--------- 3 files changed, 61 insertions(+), 52 deletions(-) diff --git a/src/ptbench/configs/datasets/shenzhen/default.py b/src/ptbench/configs/datasets/shenzhen/default.py index 11c285bc..6dc59200 100644 --- a/src/ptbench/configs/datasets/shenzhen/default.py +++ b/src/ptbench/configs/datasets/shenzhen/default.py @@ -37,6 +37,7 @@ class DefaultModule(BaseDataModule): ) self.cache_samples = cache_samples + self.has_setup_fit = False def setup(self, stage: str): if self.cache_samples: @@ -56,12 +57,13 @@ class DefaultModule(BaseDataModule): loader=samples_loader, ) - ( - self.train_dataset, - self.validation_dataset, - self.extra_validation_datasets, - self.predict_dataset, - ) = return_subsets(self.json_dataset, "default") + if not self.has_setup_fit and stage == "fit": + ( + self.train_dataset, + self.validation_dataset, + self.extra_validation_datasets, + ) = return_subsets(self.json_dataset, "default", stage) + self.has_setup_fit = True datamodule = DefaultModule diff --git a/src/ptbench/configs/datasets/shenzhen/rgb.py b/src/ptbench/configs/datasets/shenzhen/rgb.py index 0ceb952e..2506e79d 100644 --- a/src/ptbench/configs/datasets/shenzhen/rgb.py +++ b/src/ptbench/configs/datasets/shenzhen/rgb.py @@ -37,6 +37,7 @@ class DefaultModule(BaseDataModule): ) self.cache_samples = cache_samples + self.has_setup_fit = False self.post_transforms = [ transforms.ToPILImage(), @@ -63,12 +64,13 @@ class DefaultModule(BaseDataModule): post_transforms=self.post_transforms, ) - ( - self.train_dataset, - self.validation_dataset, - self.extra_validation_datasets, - self.predict_dataset, - ) = return_subsets(self.json_dataset, "default") + if not self.has_setup_fit and stage == "fit": + ( + self.train_dataset, + self.validation_dataset, + self.extra_validation_datasets, + ) = return_subsets(self.json_dataset, "default", stage) + self.has_setup_fit = True datamodule = DefaultModule diff --git a/src/ptbench/data/__init__.py b/src/ptbench/data/__init__.py index 259bab57..682d5d1a 100644 --- a/src/ptbench/data/__init__.py +++ b/src/ptbench/data/__init__.py @@ -303,49 +303,54 @@ def get_positive_weights(dataset): return positive_weights -def return_subsets(dataset, protocol): - train_dataset = None - validation_dataset = None - extra_validation_datasets = None - predict_dataset = None +def return_subsets(dataset, protocol, stage): + train_set = None + valid_set = None + extra_valid_sets = None subsets = dataset.subsets(protocol) - if "train" in subsets.keys(): - train_dataset = SampleListDataset(subsets["train"], []) - if "validation" in subsets.keys(): - validation_dataset = SampleListDataset(subsets["validation"], []) - else: - logger.warning( - "No validation dataset found, using training set instead." - ) - validation_dataset = train_dataset - - if "__extra_valid__" in subsets.keys(): - if not isinstance(subsets["__extra_valid__"], list): - raise RuntimeError( - f"If present, dataset['__extra_valid__'] must be a list, " - f"but you passed a {type(subsets['__extra_valid__'])}, " - f"which is invalid." + def get_train_subset(): + if "train" in subsets.keys(): + nonlocal train_set + train_set = SampleListDataset(subsets["train"], []) + + def get_valid_subset(): + if "validation" in subsets.keys(): + nonlocal valid_set + valid_set = SampleListDataset(subsets["validation"], []) + else: + logger.warning( + "No validation dataset found, using training set instead." ) - logger.info( - f"Found {len(subsets['__extra_valid__'])} extra validation " - f"set(s) to be tracked during training" - ) - logger.info( - "Extra validation sets are NOT used for model checkpointing!" - ) - extra_validation_datasets = SampleListDataset( - subsets["__extra_valid__"], [] - ) - else: - extra_validation_datasets = None + if train_set is None: + get_train_subset() + + valid_set = train_set + + def get_extra_valid_subset(): + if "__extra_valid__" in subsets.keys(): + if not isinstance(subsets["__extra_valid__"], list): + raise RuntimeError( + f"If present, dataset['__extra_valid__'] must be a list, " + f"but you passed a {type(subsets['__extra_valid__'])}, " + f"which is invalid." + ) + logger.info( + f"Found {len(subsets['__extra_valid__'])} extra validation " + f"set(s) to be tracked during training" + ) + logger.info( + "Extra validation sets are NOT used for model checkpointing!" + ) + nonlocal extra_valid_sets + extra_valid_sets = SampleListDataset(subsets["__extra_valid__"], []) - predict_dataset = subsets + if stage == "fit": + get_train_subset() + get_valid_subset() + get_extra_valid_subset() - return ( - train_dataset, - validation_dataset, - extra_validation_datasets, - predict_dataset, - ) + return train_set, valid_set, extra_valid_sets + else: + raise ValueError(f"Stage {stage} is unknown.") -- GitLab