From 00f74e43f54fdf54500b415145939fea2b9a317c Mon Sep 17 00:00:00 2001
From: dcarron <daniel.carron@idiap.ch>
Date: Wed, 26 Jul 2023 15:58:48 +0200
Subject: [PATCH] Fix ptbench datamodule list

---
 src/ptbench/scripts/datamodule.py | 28 +++++++++++++++-------------
 1 file changed, 15 insertions(+), 13 deletions(-)

diff --git a/src/ptbench/scripts/datamodule.py b/src/ptbench/scripts/datamodule.py
index c16b6288..66681ea7 100644
--- a/src/ptbench/scripts/datamodule.py
+++ b/src/ptbench/scripts/datamodule.py
@@ -17,14 +17,14 @@ from ..data.split import check_database_split_loading
 logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
 
 
-def _get_installed_protocols() -> dict[str, str]:
-    """Returns a list of all installed protocols.
+def _get_installed_datamodules() -> dict[str, str]:
+    """Returns a list of all installed datamodules.
 
     Returns
     -------
 
-    protocols:
-        List of protocols.
+    datamodules:
+        List of datamodules.
     """
     entrypoints = sorted(
         [
@@ -35,18 +35,20 @@ def _get_installed_protocols() -> dict[str, str]:
         ]
     )
 
-    protocols = [
+    datamodules = [
         importlib.metadata.entry_points(group="ptbench.config")[
             entrypoint
         ].module
         for entrypoint in entrypoints
     ]
 
-    protocols_dict = {
-        entrypoints[i]: protocols[i] for i in range(len(entrypoints))
+    datamodules_dict = {
+        entrypoints[i]: datamodules[i]
+        for i in range(len(entrypoints))
+        if datamodules[i].split(".")[1] == "data"
     }
 
-    return protocols_dict
+    return datamodules_dict
 
 
 @click.group(cls=AliasedGroup)
@@ -74,20 +76,20 @@ def datamodule() -> None:
           This setting **is** case-sensitive.
 
 \b
-    2. List all raw datasets supported (and configured):
+    2. List all raw datamodule supported (and configured):
 
        .. code:: sh
 
-          $ ptbench dataset list
+          $ ptbench datamodule list
 
 """,
 )
 @verbosity_option(logger=logger, expose_value=False)
 def list():
-    """Lists all supported and configured datasets."""
-    installed = _get_installed_protocols()
+    """Lists all supported and configured datamodules."""
+    installed = _get_installed_datamodules()
 
-    click.echo("Available protocols:")
+    click.echo("Available datamodules:")
     for k, v in installed.items():
         click.echo(f'- {k}: "{v}"')
 
-- 
GitLab