diff --git a/src/ptbench/scripts/evaluate.py b/src/ptbench/scripts/evaluate.py
index 08be4ea68afa5c73481cf29a1a593bb81eeb373b..3f4c00f6cd9215d8a0ad2c9d6bad7abcd0213162 100644
--- a/src/ptbench/scripts/evaluate.py
+++ b/src/ptbench/scripts/evaluate.py
@@ -5,7 +5,6 @@
 import os
 
 from collections import defaultdict
-from typing import Union
 
 import click
 
@@ -13,7 +12,7 @@ from clapper.click import ConfigCommand, ResourceOption, verbosity_option
 from clapper.logging import setup
 from matplotlib.backends.backend_pdf import PdfPages
 
-from ..data.datamodule import CachingDataModule
+from ..data.datamodule import ConcatDataModule
 from ..data.typing import DataLoader
 from ..utils.plot import precision_recall_f1iso, roc_curve
 from ..utils.table import performance_table
@@ -22,7 +21,7 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
 
 
 def _validate_threshold(
-    threshold: Union[int, float, str], dataloader_dict: dict[str, DataLoader]
+    threshold: int | float | str, dataloader_dict: dict[str, DataLoader]
 ):
     """Validates the user threshold selection.
 
@@ -140,8 +139,8 @@ def _validate_threshold(
 def evaluate(
     output_folder: str,
     predictions_folder: str,
-    datamodule: CachingDataModule,
-    threshold: Union[int, float, str],
+    datamodule: ConcatDataModule,
+    threshold: int | float | str,
     steps: int,
     **_,
 ) -> None: