From eb966eddb17c23e8455aa71042b6dc54fa1f5328 Mon Sep 17 00:00:00 2001 From: dcarron <daniel.carron@idiap.ch> Date: Tue, 18 Jul 2023 17:12:20 +0200 Subject: [PATCH] Update dataset script to check splits --- src/ptbench/scripts/dataset.py | 35 ++++++++++++++++++++++++++++++++-- 1 file changed, 33 insertions(+), 2 deletions(-) diff --git a/src/ptbench/scripts/dataset.py b/src/ptbench/scripts/dataset.py index 7319cb54..5174abc8 100644 --- a/src/ptbench/scripts/dataset.py +++ b/src/ptbench/scripts/dataset.py @@ -4,6 +4,7 @@ from __future__ import annotations +import importlib.metadata import importlib.resources import os @@ -12,6 +13,8 @@ import click from clapper.click import AliasedGroup, verbosity_option from clapper.logging import setup +from ..data.split import check_database_split_loading + logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") @@ -149,7 +152,35 @@ def check(dataset, limit): errors = 0 for k in to_check.keys(): click.echo(f'Checking "{k}" dataset...') - module = importlib.import_module(f"...data.{k}", __name__) - errors += module.dataset.check(limit) + + # Gathering protocols for the dataset + entrypoints = [ + i + for i in importlib.metadata.entry_points( + group="ptbench.config" + ).names + if i == k + ] + protocols_modules = sorted( + [ + importlib.metadata.entry_points(group="ptbench.config")[ + i + ].module + for i in entrypoints + ] + ) + + for protocol in protocols_modules: + datamodule = importlib.import_module(protocol).datamodule + + database_split = datamodule.database_split + raw_data_loader = datamodule.raw_data_loader + + logger.info(f"Checking protocol {protocol}") + + errors += check_database_split_loading( + database_split.subsets, raw_data_loader, limit=limit + ) + if not errors: click.echo("No errors reported") -- GitLab