diff --git a/src/ptbench/configs/datasets/shenzhen/default.py b/src/ptbench/configs/datasets/shenzhen/default.py
index 11c285bcc93380291f273e191281035462eb0d6f..6dc59200b2c751d7e304393a9b28abee18d40fa5 100644
--- a/src/ptbench/configs/datasets/shenzhen/default.py
+++ b/src/ptbench/configs/datasets/shenzhen/default.py
@@ -37,6 +37,7 @@ class DefaultModule(BaseDataModule):
         )
 
         self.cache_samples = cache_samples
+        self.has_setup_fit = False
 
     def setup(self, stage: str):
         if self.cache_samples:
@@ -56,12 +57,13 @@ class DefaultModule(BaseDataModule):
                 loader=samples_loader,
             )
 
-        (
-            self.train_dataset,
-            self.validation_dataset,
-            self.extra_validation_datasets,
-            self.predict_dataset,
-        ) = return_subsets(self.json_dataset, "default")
+        if not self.has_setup_fit and stage == "fit":
+            (
+                self.train_dataset,
+                self.validation_dataset,
+                self.extra_validation_datasets,
+            ) = return_subsets(self.json_dataset, "default", stage)
+            self.has_setup_fit = True
 
 
 datamodule = DefaultModule
diff --git a/src/ptbench/configs/datasets/shenzhen/rgb.py b/src/ptbench/configs/datasets/shenzhen/rgb.py
index 0ceb952e28c0a1181772565b425043a9a08cb646..2506e79da82008b636160f88f353080627ec9e00 100644
--- a/src/ptbench/configs/datasets/shenzhen/rgb.py
+++ b/src/ptbench/configs/datasets/shenzhen/rgb.py
@@ -37,6 +37,7 @@ class DefaultModule(BaseDataModule):
         )
 
         self.cache_samples = cache_samples
+        self.has_setup_fit = False
 
         self.post_transforms = [
             transforms.ToPILImage(),
@@ -63,12 +64,13 @@ class DefaultModule(BaseDataModule):
                 post_transforms=self.post_transforms,
             )
 
-        (
-            self.train_dataset,
-            self.validation_dataset,
-            self.extra_validation_datasets,
-            self.predict_dataset,
-        ) = return_subsets(self.json_dataset, "default")
+        if not self.has_setup_fit and stage == "fit":
+            (
+                self.train_dataset,
+                self.validation_dataset,
+                self.extra_validation_datasets,
+            ) = return_subsets(self.json_dataset, "default", stage)
+            self.has_setup_fit = True
 
 
 datamodule = DefaultModule
diff --git a/src/ptbench/data/__init__.py b/src/ptbench/data/__init__.py
index 259bab574ecf0f003486901d59a78a856e7ee37e..682d5d1ad8e88ae4f8dd72dccb888200bc3d161a 100644
--- a/src/ptbench/data/__init__.py
+++ b/src/ptbench/data/__init__.py
@@ -303,49 +303,54 @@ def get_positive_weights(dataset):
     return positive_weights
 
 
-def return_subsets(dataset, protocol):
-    train_dataset = None
-    validation_dataset = None
-    extra_validation_datasets = None
-    predict_dataset = None
+def return_subsets(dataset, protocol, stage):
+    train_set = None
+    valid_set = None
+    extra_valid_sets = None
 
     subsets = dataset.subsets(protocol)
 
-    if "train" in subsets.keys():
-        train_dataset = SampleListDataset(subsets["train"], [])
-    if "validation" in subsets.keys():
-        validation_dataset = SampleListDataset(subsets["validation"], [])
-    else:
-        logger.warning(
-            "No validation dataset found, using training set instead."
-        )
-        validation_dataset = train_dataset
-
-    if "__extra_valid__" in subsets.keys():
-        if not isinstance(subsets["__extra_valid__"], list):
-            raise RuntimeError(
-                f"If present, dataset['__extra_valid__'] must be a list, "
-                f"but you passed a {type(subsets['__extra_valid__'])}, "
-                f"which is invalid."
+    def get_train_subset():
+        if "train" in subsets.keys():
+            nonlocal train_set
+            train_set = SampleListDataset(subsets["train"], [])
+
+    def get_valid_subset():
+        if "validation" in subsets.keys():
+            nonlocal valid_set
+            valid_set = SampleListDataset(subsets["validation"], [])
+        else:
+            logger.warning(
+                "No validation dataset found, using training set instead."
             )
-        logger.info(
-            f"Found {len(subsets['__extra_valid__'])} extra validation "
-            f"set(s) to be tracked during training"
-        )
-        logger.info(
-            "Extra validation sets are NOT used for model checkpointing!"
-        )
-        extra_validation_datasets = SampleListDataset(
-            subsets["__extra_valid__"], []
-        )
-    else:
-        extra_validation_datasets = None
+            if train_set is None:
+                get_train_subset()
+
+            valid_set = train_set
+
+    def get_extra_valid_subset():
+        if "__extra_valid__" in subsets.keys():
+            if not isinstance(subsets["__extra_valid__"], list):
+                raise RuntimeError(
+                    f"If present, dataset['__extra_valid__'] must be a list, "
+                    f"but you passed a {type(subsets['__extra_valid__'])}, "
+                    f"which is invalid."
+                )
+            logger.info(
+                f"Found {len(subsets['__extra_valid__'])} extra validation "
+                f"set(s) to be tracked during training"
+            )
+            logger.info(
+                "Extra validation sets are NOT used for model checkpointing!"
+            )
+            nonlocal extra_valid_sets
+            extra_valid_sets = SampleListDataset(subsets["__extra_valid__"], [])
 
-    predict_dataset = subsets
+    if stage == "fit":
+        get_train_subset()
+        get_valid_subset()
+        get_extra_valid_subset()
 
-    return (
-        train_dataset,
-        validation_dataset,
-        extra_validation_datasets,
-        predict_dataset,
-    )
+        return train_set, valid_set, extra_valid_sets
+    else:
+        raise ValueError(f"Stage {stage} is unknown.")
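Note on the has_setup_fit guard added to both datamodules: Lightning-style trainers may call setup() more than once (once per stage, and once per device in distributed runs), so the flag makes the subset split idempotent. Below is a minimal standalone sketch of the pattern; FakeDataModule and fake_return_subsets are illustrative stand-ins, not ptbench API.

# Minimal sketch of the setup-once guard introduced in this patch.
# All names here are hypothetical; only the pattern matches the diff.

def fake_return_subsets(stage):
    # Mirrors the new contract of return_subsets: "fit" is the only
    # stage that produces subsets, anything else is rejected.
    if stage == "fit":
        return ["train-sample"], ["valid-sample"], None
    raise ValueError(f"Stage {stage} is unknown.")


class FakeDataModule:
    def __init__(self):
        self.has_setup_fit = False  # same guard as in the diff
        self.train_dataset = None
        self.validation_dataset = None
        self.extra_validation_datasets = None

    def setup(self, stage: str):
        # The guard keeps repeated setup("fit") calls from recomputing
        # the (potentially expensive) subset split.
        if not self.has_setup_fit and stage == "fit":
            (
                self.train_dataset,
                self.validation_dataset,
                self.extra_validation_datasets,
            ) = fake_return_subsets(stage)
            self.has_setup_fit = True


dm = FakeDataModule()
dm.setup("fit")  # performs the split
dm.setup("fit")  # no-op: has_setup_fit short-circuits the rebuild
assert dm.train_dataset == ["train-sample"]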
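The rewritten return_subsets builds its results through nested helpers that write back to the enclosing scope via nonlocal, which lets get_valid_subset lazily trigger get_train_subset for the fallback case. A tiny self-contained demonstration of that closure pattern follows (hypothetical names, plain lists instead of SampleListDataset):

# Demo of the nonlocal-closure pattern used by the new return_subsets:
# inner helpers rebind variables of the enclosing function scope.

def split(subsets, stage):
    train_set = None
    valid_set = None

    def get_train_subset():
        nonlocal train_set  # rebinds the enclosing variable, not a local
        if "train" in subsets:
            train_set = list(subsets["train"])

    def get_valid_subset():
        nonlocal valid_set
        if "validation" in subsets:
            valid_set = list(subsets["validation"])
        else:
            # Same fallback as the patch: when no validation subset
            # exists, lazily build and reuse the training set.
            if train_set is None:
                get_train_subset()
            valid_set = train_set

    if stage == "fit":
        get_train_subset()
        get_valid_subset()
        return train_set, valid_set
    raise ValueError(f"Stage {stage} is unknown.")


print(split({"train": [1, 2]}, "fit"))  # ([1, 2], [1, 2]) - validation fell back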