From 7c34000dfe3c7781af5202aad47a4a4f0e6e4b83 Mon Sep 17 00:00:00 2001 From: Andre Anjos <andre.dos.anjos@gmail.com> Date: Thu, 27 Jul 2023 19:58:25 +0200 Subject: [PATCH] [scripts.evaluate] Correct datamodule typing --- src/ptbench/scripts/evaluate.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/src/ptbench/scripts/evaluate.py b/src/ptbench/scripts/evaluate.py index 08be4ea6..3f4c00f6 100644 --- a/src/ptbench/scripts/evaluate.py +++ b/src/ptbench/scripts/evaluate.py @@ -5,7 +5,6 @@ import os from collections import defaultdict -from typing import Union import click @@ -13,7 +12,7 @@ from clapper.click import ConfigCommand, ResourceOption, verbosity_option from clapper.logging import setup from matplotlib.backends.backend_pdf import PdfPages -from ..data.datamodule import CachingDataModule +from ..data.datamodule import ConcatDataModule from ..data.typing import DataLoader from ..utils.plot import precision_recall_f1iso, roc_curve from ..utils.table import performance_table @@ -22,7 +21,7 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") def _validate_threshold( - threshold: Union[int, float, str], dataloader_dict: dict[str, DataLoader] + threshold: int | float | str, dataloader_dict: dict[str, DataLoader] ): """Validates the user threshold selection. @@ -140,8 +139,8 @@ def _validate_threshold( def evaluate( output_folder: str, predictions_folder: str, - datamodule: CachingDataModule, - threshold: Union[int, float, str], + datamodule: ConcatDataModule, + threshold: int | float | str, steps: int, **_, ) -> None: -- GitLab