diff --git a/src/ptbench/scripts/dataset.py b/src/ptbench/scripts/dataset.py index 7319cb546c09fe71785f5b5cffdba17ece1f9c57..5174abc8827df6e2f13bab7d5dbf48741482eeac 100644 --- a/src/ptbench/scripts/dataset.py +++ b/src/ptbench/scripts/dataset.py @@ -4,6 +4,7 @@ from __future__ import annotations +import importlib.metadata import importlib.resources import os @@ -12,6 +13,8 @@ import click from clapper.click import AliasedGroup, verbosity_option from clapper.logging import setup +from ..data.split import check_database_split_loading + logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") @@ -149,7 +152,35 @@ def check(dataset, limit): errors = 0 for k in to_check.keys(): click.echo(f'Checking "{k}" dataset...') - module = importlib.import_module(f"...data.{k}", __name__) - errors += module.dataset.check(limit) + + # Gathering protocols for the dataset + entrypoints = [ + i + for i in importlib.metadata.entry_points( + group="ptbench.config" + ).names + if i == k + ] + protocols_modules = sorted( + [ + importlib.metadata.entry_points(group="ptbench.config")[ + i + ].module + for i in entrypoints + ] + ) + + for protocol in protocols_modules: + datamodule = importlib.import_module(protocol).datamodule + + database_split = datamodule.database_split + raw_data_loader = datamodule.raw_data_loader + + logger.info(f"Checking protocol {protocol}") + + errors += check_database_split_loading( + database_split.subsets, raw_data_loader, limit=limit + ) + if not errors: click.echo("No errors reported")