Skip to content
Snippets Groups Projects
Commit 00f74e43 authored by Daniel CARRON's avatar Daniel CARRON :b:
Browse files

Fix ptbench datamodule list

parent cf28368b
No related branches found
No related tags found
No related merge requests found
Pipeline #76515 failed
......@@ -17,14 +17,14 @@ from ..data.split import check_database_split_loading
logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
def _get_installed_protocols() -> dict[str, str]:
"""Returns a list of all installed protocols.
def _get_installed_datamodules() -> dict[str, str]:
"""Returns a list of all installed datamodules.
Returns
-------
protocols:
List of protocols.
datamodules:
List of datamodules.
"""
entrypoints = sorted(
[
......@@ -35,18 +35,20 @@ def _get_installed_protocols() -> dict[str, str]:
]
)
protocols = [
datamodules = [
importlib.metadata.entry_points(group="ptbench.config")[
entrypoint
].module
for entrypoint in entrypoints
]
protocols_dict = {
entrypoints[i]: protocols[i] for i in range(len(entrypoints))
datamodules_dict = {
entrypoints[i]: datamodules[i]
for i in range(len(entrypoints))
if datamodules[i].split(".")[1] == "data"
}
return protocols_dict
return datamodules_dict
@click.group(cls=AliasedGroup)
......@@ -74,20 +76,20 @@ def datamodule() -> None:
This setting **is** case-sensitive.
\b
2. List all raw datasets supported (and configured):
2. List all raw datamodule supported (and configured):
.. code:: sh
$ ptbench dataset list
$ ptbench datamodule list
""",
)
@verbosity_option(logger=logger, expose_value=False)
def list():
"""Lists all supported and configured datasets."""
installed = _get_installed_protocols()
"""Lists all supported and configured datamodules."""
installed = _get_installed_datamodules()
click.echo("Available protocols:")
click.echo("Available datamodules:")
for k, v in installed.items():
click.echo(f'- {k}: "{v}"')
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment