Skip to content
Snippets Groups Projects
Commit 7c34000d authored by André Anjos's avatar André Anjos :speech_balloon:
Browse files

[scripts.evaluate] Correct datamodule typing

parent 79c79301
No related branches found
No related tags found
1 merge request!6Making use of LightningDataModule and simplification of data loading
......@@ -5,7 +5,6 @@
import os
from collections import defaultdict
from typing import Union
import click
......@@ -13,7 +12,7 @@ from clapper.click import ConfigCommand, ResourceOption, verbosity_option
from clapper.logging import setup
from matplotlib.backends.backend_pdf import PdfPages
from ..data.datamodule import CachingDataModule
from ..data.datamodule import ConcatDataModule
from ..data.typing import DataLoader
from ..utils.plot import precision_recall_f1iso, roc_curve
from ..utils.table import performance_table
......@@ -22,7 +21,7 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
def _validate_threshold(
threshold: Union[int, float, str], dataloader_dict: dict[str, DataLoader]
threshold: int | float | str, dataloader_dict: dict[str, DataLoader]
):
"""Validates the user threshold selection.
......@@ -140,8 +139,8 @@ def _validate_threshold(
def evaluate(
output_folder: str,
predictions_folder: str,
datamodule: CachingDataModule,
threshold: Union[int, float, str],
datamodule: ConcatDataModule,
threshold: int | float | str,
steps: int,
**_,
) -> None:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment