From 6d819c3d790b4ce675a1f01dfaca9d4e6188e16a Mon Sep 17 00:00:00 2001
From: dcarron <daniel.carron@idiap.ch>
Date: Tue, 25 Jul 2023 16:43:23 +0200
Subject: [PATCH] Rename dataset script to datamodule and update

---
 src/ptbench/scripts/cli.py        |   6 +-
 src/ptbench/scripts/datamodule.py | 149 ++++++++++++++++++++++++
 src/ptbench/scripts/dataset.py    | 186 ------------------------------
 3 files changed, 153 insertions(+), 188 deletions(-)
 create mode 100644 src/ptbench/scripts/datamodule.py
 delete mode 100644 src/ptbench/scripts/dataset.py

diff --git a/src/ptbench/scripts/cli.py b/src/ptbench/scripts/cli.py
index 1be33f6e..479d6b6b 100644
--- a/src/ptbench/scripts/cli.py
+++ b/src/ptbench/scripts/cli.py
@@ -10,8 +10,9 @@ from . import (
     aggregpred,
     compare,
     config,
-    dataset,
+    datamodule,
     evaluate,
+    experiment,
     predict,
     predtojson,
     train,
@@ -31,8 +32,9 @@ def cli():
 cli.add_command(aggregpred.aggregpred)
 cli.add_command(compare.compare)
 cli.add_command(config.config)
-cli.add_command(dataset.dataset)
+cli.add_command(datamodule.datamodule)
 cli.add_command(evaluate.evaluate)
+cli.add_command(experiment.experiment)
 cli.add_command(predict.predict)
 cli.add_command(predtojson.predtojson)
 cli.add_command(train.train)
diff --git a/src/ptbench/scripts/datamodule.py b/src/ptbench/scripts/datamodule.py
new file mode 100644
index 00000000..6cadcc72
--- /dev/null
+++ b/src/ptbench/scripts/datamodule.py
@@ -0,0 +1,149 @@
+# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
+#
+# SPDX-License-Identifier: GPL-3.0-or-later
+
+from __future__ import annotations
+
+import importlib.metadata
+import importlib.resources
+
+import click
+
+from clapper.click import AliasedGroup, verbosity_option
+from clapper.logging import setup
+
+from ..data.split import check_database_split_loading
+
+logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
+
+
+def _get_installed_protocols() -> dict[str, str]:
+    """Returns a dict of all installed protocols.
+
+    Returns
+    -------
+
+    protocols:
+        Dictionary mapping protocol (entry-point) names to their module paths.
+    """
+    entrypoints = sorted(
+        [
+            entrypoint
+            for entrypoint in importlib.metadata.entry_points(
+                group="ptbench.config"
+            ).names
+        ]
+    )
+
+    protocols = [
+        importlib.metadata.entry_points(group="ptbench.config")[
+            entrypoint
+        ].module
+        for entrypoint in entrypoints
+    ]
+
+    protocols_dict = {
+        entrypoints[i]: protocols[i] for i in range(len(entrypoints))
+    }
+
+    return protocols_dict
+
+
+@click.group(cls=AliasedGroup)
+def datamodule() -> None:
+    """Commands for listing and verifying datamodules."""
+    pass
+
+
+@datamodule.command(
+    epilog="""Examples:
+
+\b
+    1. To install a dataset, set up its data directory ("datadir").  For
+       example, to setup access to Montgomery files you downloaded locally at
+       the directory "/path/to/montgomery/files", edit the RC file (typically
+       ``$HOME/.config/ptbench.toml``), and add a line like the following:
+
+       .. code:: toml
+
+          [datadir]
+          montgomery = "/path/to/montgomery/files"
+
+       .. note::
+
+          This setting **is** case-sensitive.
+
+\b
+    2. List all raw datasets supported (and configured):
+
+       .. code:: sh
+
+          $ ptbench datamodule list
+
+""",
+)
+@verbosity_option(logger=logger, expose_value=False)
+def list():
+    """Lists all installed datamodule protocols."""
+    installed = _get_installed_protocols()
+
+    click.echo("Available protocols:")
+    for k, v in installed.items():
+        click.echo(f'- {k}: "{v}"')
+
+
+@datamodule.command(
+    epilog="""Examples:
+
+    1. Check if all files of the Shenzhen dataset can be loaded:
+
+       .. code:: sh
+
+          ptbench datamodule check -vv shenzhen
+
+    2. Check if all files of multiple installed datasets can be loaded:
+
+       .. code:: sh
+
+          ptbench datamodule check -vv montgomery shenzhen
+
+""",
+)
+@click.argument(
+    "protocols",
+    nargs=-1,
+)
+@click.option(
+    "--limit",
+    "-l",
+    help="Limit check to the first N samples in each dataset, making the "
+    "check sensibly faster.  Set it to zero to check everything.",
+    required=True,
+    type=click.IntRange(0),
+    default=0,
+)
+@verbosity_option(logger=logger, expose_value=False)
+def check(protocols, limit):
+    """Checks file access on one or more datamodules."""
+    import importlib
+
+    errors = 0
+    for protocol in protocols:
+        try:
+            module = importlib.metadata.entry_points(group="ptbench.config")[
+                protocol
+            ].module
+        except KeyError:
+            raise Exception(f"Could not find protocol {protocol}")
+
+        datamodule = importlib.import_module(module).datamodule
+
+        database_split = datamodule.database_split
+        raw_data_loader = datamodule.raw_data_loader
+
+        errors += check_database_split_loading(
+            database_split, raw_data_loader, limit=limit
+        )
+
+    if not errors:
+        click.echo("No errors reported")
diff --git a/src/ptbench/scripts/dataset.py b/src/ptbench/scripts/dataset.py
deleted file mode 100644
index 844311b5..00000000
--- a/src/ptbench/scripts/dataset.py
+++ /dev/null
@@ -1,186 +0,0 @@
-# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
-#
-# SPDX-License-Identifier: GPL-3.0-or-later
-
-from __future__ import annotations
-
-import importlib.metadata
-import importlib.resources
-import os
-
-import click
-
-from clapper.click import AliasedGroup, verbosity_option
-from clapper.logging import setup
-
-from ..data.split import check_database_split_loading
-
-logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
-
-
-def _get_supported_datasets():
-    """Returns a list of supported dataset names."""
-    basedir = importlib.resources.files(__name__.split(".", 1)[0]).joinpath(
-        "data"
-    )
-
-    retval = []
-    for candidate in basedir.iterdir():
-        if candidate.is_dir() and "__init__.py" in os.listdir(str(candidate)):
-            retval.append(candidate.name)
-
-    return retval
-
-
-def _get_installed_datasets() -> dict[str, str]:
-    """Returns a list of installed datasets as regular expressions.
-
-    * group(0): the name of the key for the dataset directory
-    * group("name"): the short name for the dataset
-    """
-    from ..utils.rc import load_rc
-
-    return dict(load_rc().get("datadir", {}))
-
-
-@click.group(cls=AliasedGroup)
-def dataset() -> None:
-    """Commands for listing and verifying datasets."""
-    pass
-
-
-@dataset.command(
-    epilog="""Examples:
-
-\b
-    1. To install a dataset, set up its data directory ("datadir").  For
-       example, to setup access to Montgomery files you downloaded locally at
-       the directory "/path/to/montgomery/files", edit the RC file (typically
-       ``$HOME/.config/ptbench.toml``), and add a line like the following:
-
-       .. code:: toml
-
-          [datadir]
-          montgomery = "/path/to/montgomery/files"
-
-       .. note::
-
-          This setting **is** case-sensitive.
-
-\b
-    2. List all raw datasets supported (and configured):
-
-       .. code:: sh
-
-          $ ptbench dataset list
-
-""",
-)
-@verbosity_option(logger=logger, expose_value=False)
-def list():
-    """Lists all supported and configured datasets."""
-    supported = _get_supported_datasets()
-    installed = _get_installed_datasets()
-
-    click.echo("Supported datasets:")
-    for k in supported:
-        if k in installed:
-            click.echo(f'- {k}: "{installed[k]}"')
-        else:
-            click.echo(f"* {k}: datadir.{k} (not set)")
-
-
-@dataset.command(
-    epilog="""Examples:
-
-    1. Check if all files of the Montgomery dataset can be loaded:
-
-       .. code:: sh
-
-          ptbench dataset check -vv montgomery
-
-    2. Check if all files of multiple installed datasets can be loaded:
-
-       .. code:: sh
-
-          ptbench dataset check -vv montgomery shenzhen
-
-    3. Check if all files of all installed datasets can be loaded:
-
-       .. code:: sh
-
-          ptbench dataset check
-
-""",
-)
-@click.argument(
-    "dataset",
-    nargs=-1,
-)
-@click.option(
-    "--limit",
-    "-l",
-    help="Limit check to the first N samples in each dataset, making the "
-    "check sensibly faster.  Set it to zero to check everything.",
-    required=True,
-    type=click.IntRange(0),
-    default=0,
-)
-@verbosity_option(logger=logger, expose_value=False)
-def check(dataset, limit):
-    """Checks file access on one or more datasets."""
-    import importlib
-
-    to_check = _get_installed_datasets()
-
-    if dataset:  # check only some
-        delete = [k for k in to_check.keys() if k not in dataset]
-        for k in delete:
-            del to_check[k]
-
-    if not to_check:
-        click.secho(
-            "WARNING: No configured datasets matching specifications",
-            fg="yellow",
-            bold=True,
-        )
-        click.echo(
-            "Try ptbench dataset list --help to get help in "
-            "configuring a dataset"
-        )
-    else:
-        errors = 0
-        for k in to_check.keys():
-            click.echo(f'Checking "{k}" dataset...')
-
-            # Gathering protocols for the dataset
-            entrypoints = [
-                i
-                for i in importlib.metadata.entry_points(
-                    group="ptbench.config"
-                ).names
-                if i == k
-            ]
-            protocols_modules = sorted(
-                [
-                    importlib.metadata.entry_points(group="ptbench.config")[
-                        i
-                    ].module
-                    for i in entrypoints
-                ]
-            )
-
-            for protocol in protocols_modules:
-                datamodule = importlib.import_module(protocol).datamodule
-
-                database_split = datamodule.database_split
-                raw_data_loader = datamodule.raw_data_loader
-
-                logger.info(f"Checking protocol {protocol}")
-
-                errors += check_database_split_loading(
-                    database_split._subsets, raw_data_loader, limit=limit
-                )
-
-        if not errors:
-            click.echo("No errors reported")
-- 
GitLab