diff --git a/src/ptbench/data/__init__.py b/src/ptbench/data/__init__.py index 8131b51bfa2fc0515eff966dedcad826cbd97ab8..259bab574ecf0f003486901d59a78a856e7ee37e 100644 --- a/src/ptbench/data/__init__.py +++ b/src/ptbench/data/__init__.py @@ -313,35 +313,35 @@ def return_subsets(dataset, protocol): if "train" in subsets.keys(): train_dataset = SampleListDataset(subsets["train"], []) - if "validation" in subsets.keys(): - validation_dataset = SampleListDataset(subsets["validation"], []) - else: - logger.warning( - "No validation dataset found, using training set instead." - ) - validation_dataset = train_dataset - - if "__extra_valid__" in subsets.keys(): - if not isinstance(subsets["__extra_valid__"], list): - raise RuntimeError( - f"If present, dataset['__extra_valid__'] must be a list, " - f"but you passed a {type(subsets['__extra_valid__'])}, " - f"which is invalid." - ) - logger.info( - f"Found {len(subsets['__extra_valid__'])} extra validation " - f"set(s) to be tracked during training" - ) - logger.info( - "Extra validation sets are NOT used for model checkpointing!" - ) - extra_validation_datasets = SampleListDataset( - subsets["__extra_valid__"], [] + if "validation" in subsets.keys(): + validation_dataset = SampleListDataset(subsets["validation"], []) + else: + logger.warning( + "No validation dataset found, using training set instead." + ) + validation_dataset = train_dataset + + if "__extra_valid__" in subsets.keys(): + if not isinstance(subsets["__extra_valid__"], list): + raise RuntimeError( + f"If present, dataset['__extra_valid__'] must be a list, " + f"but you passed a {type(subsets['__extra_valid__'])}, " + f"which is invalid." ) - else: - extra_validation_datasets = None + logger.info( + f"Found {len(subsets['__extra_valid__'])} extra validation " + f"set(s) to be tracked during training" + ) + logger.info( + "Extra validation sets are NOT used for model checkpointing!" + ) + extra_validation_datasets = SampleListDataset( + subsets["__extra_valid__"], [] + ) + else: + extra_validation_datasets = None - predict_dataset = subsets + predict_dataset = subsets return ( train_dataset,