From 00f74e43f54fdf54500b415145939fea2b9a317c Mon Sep 17 00:00:00 2001 From: dcarron <daniel.carron@idiap.ch> Date: Wed, 26 Jul 2023 15:58:48 +0200 Subject: [PATCH] Fix ptbench datamodule list --- src/ptbench/scripts/datamodule.py | 28 +++++++++++++++------------- 1 file changed, 15 insertions(+), 13 deletions(-) diff --git a/src/ptbench/scripts/datamodule.py b/src/ptbench/scripts/datamodule.py index c16b6288..66681ea7 100644 --- a/src/ptbench/scripts/datamodule.py +++ b/src/ptbench/scripts/datamodule.py @@ -17,14 +17,14 @@ from ..data.split import check_database_split_loading logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") -def _get_installed_protocols() -> dict[str, str]: - """Returns a list of all installed protocols. +def _get_installed_datamodules() -> dict[str, str]: + """Returns a list of all installed datamodules. Returns ------- - protocols: - List of protocols. + datamodules: + List of datamodules. """ entrypoints = sorted( [ @@ -35,18 +35,20 @@ def _get_installed_protocols() -> dict[str, str]: ] ) - protocols = [ + datamodules = [ importlib.metadata.entry_points(group="ptbench.config")[ entrypoint ].module for entrypoint in entrypoints ] - protocols_dict = { - entrypoints[i]: protocols[i] for i in range(len(entrypoints)) + datamodules_dict = { + entrypoints[i]: datamodules[i] + for i in range(len(entrypoints)) + if datamodules[i].split(".")[1] == "data" } - return protocols_dict + return datamodules_dict @click.group(cls=AliasedGroup) @@ -74,20 +76,20 @@ def datamodule() -> None: This setting **is** case-sensitive. \b - 2. List all raw datasets supported (and configured): + 2. List all raw datamodule supported (and configured): .. code:: sh - $ ptbench dataset list + $ ptbench datamodule list """, ) @verbosity_option(logger=logger, expose_value=False) def list(): - """Lists all supported and configured datasets.""" - installed = _get_installed_protocols() + """Lists all supported and configured datamodules.""" + installed = _get_installed_datamodules() - click.echo("Available protocols:") + click.echo("Available datamodules:") for k, v in installed.items(): click.echo(f'- {k}: "{v}"') -- GitLab