Skip to content
Snippets Groups Projects
Commit 00f74e43 authored by Daniel CARRON's avatar Daniel CARRON :b:
Browse files

Fix ptbench datamodule list

parent cf28368b
No related branches found
No related tags found
No related merge requests found
Pipeline #76515 failed
...@@ -17,14 +17,14 @@ from ..data.split import check_database_split_loading ...@@ -17,14 +17,14 @@ from ..data.split import check_database_split_loading
logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
def _get_installed_protocols() -> dict[str, str]: def _get_installed_datamodules() -> dict[str, str]:
"""Returns a list of all installed protocols. """Returns a list of all installed datamodules.
Returns Returns
------- -------
protocols: datamodules:
List of protocols. List of datamodules.
""" """
entrypoints = sorted( entrypoints = sorted(
[ [
...@@ -35,18 +35,20 @@ def _get_installed_protocols() -> dict[str, str]: ...@@ -35,18 +35,20 @@ def _get_installed_protocols() -> dict[str, str]:
] ]
) )
protocols = [ datamodules = [
importlib.metadata.entry_points(group="ptbench.config")[ importlib.metadata.entry_points(group="ptbench.config")[
entrypoint entrypoint
].module ].module
for entrypoint in entrypoints for entrypoint in entrypoints
] ]
protocols_dict = { datamodules_dict = {
entrypoints[i]: protocols[i] for i in range(len(entrypoints)) entrypoints[i]: datamodules[i]
for i in range(len(entrypoints))
if datamodules[i].split(".")[1] == "data"
} }
return protocols_dict return datamodules_dict
@click.group(cls=AliasedGroup) @click.group(cls=AliasedGroup)
...@@ -74,20 +76,20 @@ def datamodule() -> None: ...@@ -74,20 +76,20 @@ def datamodule() -> None:
This setting **is** case-sensitive. This setting **is** case-sensitive.
\b \b
2. List all raw datasets supported (and configured): 2. List all raw datamodule supported (and configured):
.. code:: sh .. code:: sh
$ ptbench dataset list $ ptbench datamodule list
""", """,
) )
@verbosity_option(logger=logger, expose_value=False) @verbosity_option(logger=logger, expose_value=False)
def list(): def list():
"""Lists all supported and configured datasets.""" """Lists all supported and configured datamodules."""
installed = _get_installed_protocols() installed = _get_installed_datamodules()
click.echo("Available protocols:") click.echo("Available datamodules:")
for k, v in installed.items(): for k, v in installed.items():
click.echo(f'- {k}: "{v}"') click.echo(f'- {k}: "{v}"')
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment