diff --git a/src/ptbench/data/__init__.py b/src/ptbench/data/__init__.py
index 8131b51bfa2fc0515eff966dedcad826cbd97ab8..259bab574ecf0f003486901d59a78a856e7ee37e 100644
--- a/src/ptbench/data/__init__.py
+++ b/src/ptbench/data/__init__.py
@@ -313,35 +313,35 @@ def return_subsets(dataset, protocol):
     if "train" in subsets.keys():
         train_dataset = SampleListDataset(subsets["train"], [])
 
-        if "validation" in subsets.keys():
-            validation_dataset = SampleListDataset(subsets["validation"], [])
-        else:
-            logger.warning(
-                "No validation dataset found, using training set instead."
-            )
-            validation_dataset = train_dataset
-
-        if "__extra_valid__" in subsets.keys():
-            if not isinstance(subsets["__extra_valid__"], list):
-                raise RuntimeError(
-                    f"If present, dataset['__extra_valid__'] must be a list, "
-                    f"but you passed a {type(subsets['__extra_valid__'])}, "
-                    f"which is invalid."
-                )
-            logger.info(
-                f"Found {len(subsets['__extra_valid__'])} extra validation "
-                f"set(s) to be tracked during training"
-            )
-            logger.info(
-                "Extra validation sets are NOT used for model checkpointing!"
-            )
-            extra_validation_datasets = SampleListDataset(
-                subsets["__extra_valid__"], []
+    if "validation" in subsets.keys():
+        validation_dataset = SampleListDataset(subsets["validation"], [])
+    else:
+        logger.warning(
+            "No validation dataset found, using training set instead."
+        )
+        validation_dataset = train_dataset
+
+    if "__extra_valid__" in subsets.keys():
+        if not isinstance(subsets["__extra_valid__"], list):
+            raise RuntimeError(
+                f"If present, dataset['__extra_valid__'] must be a list, "
+                f"but you passed a {type(subsets['__extra_valid__'])}, "
+                f"which is invalid."
             )
-        else:
-            extra_validation_datasets = None
+        logger.info(
+            f"Found {len(subsets['__extra_valid__'])} extra validation "
+            f"set(s) to be tracked during training"
+        )
+        logger.info(
+            "Extra validation sets are NOT used for model checkpointing!"
+        )
+        extra_validation_datasets = SampleListDataset(
+            subsets["__extra_valid__"], []
+        )
+    else:
+        extra_validation_datasets = None
 
-        predict_dataset = subsets
+    predict_dataset = subsets
 
     return (
         train_dataset,