From 7c34000dfe3c7781af5202aad47a4a4f0e6e4b83 Mon Sep 17 00:00:00 2001
From: Andre Anjos <andre.dos.anjos@gmail.com>
Date: Thu, 27 Jul 2023 19:58:25 +0200
Subject: [PATCH] [scripts.evaluate] Correct datamodule typing

---
 src/ptbench/scripts/evaluate.py | 9 ++++-----
 1 file changed, 4 insertions(+), 5 deletions(-)

diff --git a/src/ptbench/scripts/evaluate.py b/src/ptbench/scripts/evaluate.py
index 08be4ea6..3f4c00f6 100644
--- a/src/ptbench/scripts/evaluate.py
+++ b/src/ptbench/scripts/evaluate.py
@@ -5,7 +5,6 @@
 import os
 
 from collections import defaultdict
-from typing import Union
 
 import click
 
@@ -13,7 +12,7 @@ from clapper.click import ConfigCommand, ResourceOption, verbosity_option
 from clapper.logging import setup
 from matplotlib.backends.backend_pdf import PdfPages
 
-from ..data.datamodule import CachingDataModule
+from ..data.datamodule import ConcatDataModule
 from ..data.typing import DataLoader
 from ..utils.plot import precision_recall_f1iso, roc_curve
 from ..utils.table import performance_table
@@ -22,7 +21,7 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
 
 
 def _validate_threshold(
-    threshold: Union[int, float, str], dataloader_dict: dict[str, DataLoader]
+    threshold: int | float | str, dataloader_dict: dict[str, DataLoader]
 ):
     """Validates the user threshold selection.
 
@@ -140,8 +139,8 @@ def _validate_threshold(
 def evaluate(
     output_folder: str,
     predictions_folder: str,
-    datamodule: CachingDataModule,
-    threshold: Union[int, float, str],
+    datamodule: ConcatDataModule,
+    threshold: int | float | str,
     steps: int,
     **_,
 ) -> None:
-- 
GitLab