From 129a17a1067d4a2f02ae0b13e70fc5d7439412cd Mon Sep 17 00:00:00 2001
From: dcarron <daniel.carron@idiap.ch>
Date: Thu, 25 Apr 2024 10:00:37 +0200
Subject: [PATCH] [script] Move config and database scripts to common lib

---
 .../libs/classification/scripts/config.py     |  98 +------------
 .../libs/classification/scripts/database.py   |  74 +---------
 src/mednet/libs/common/scripts/config.py      | 115 ++++++++++++++++
 src/mednet/libs/common/scripts/database.py    |  98 +++++++++++++
 src/mednet/libs/segmentation/scripts/cli.py   |   8 +-
 .../libs/segmentation/scripts/config.py       | 107 +++++++++++++++
 .../libs/segmentation/scripts/database.py     | 129 ++++++++++++++++++
 7 files changed, 463 insertions(+), 166 deletions(-)
 create mode 100644 src/mednet/libs/common/scripts/config.py
 create mode 100644 src/mednet/libs/common/scripts/database.py
 create mode 100644 src/mednet/libs/segmentation/scripts/config.py
 create mode 100644 src/mednet/libs/segmentation/scripts/database.py

diff --git a/src/mednet/libs/classification/scripts/config.py b/src/mednet/libs/classification/scripts/config.py
index 3b48a197..9a41ea83 100644
--- a/src/mednet/libs/classification/scripts/config.py
+++ b/src/mednet/libs/classification/scripts/config.py
@@ -2,14 +2,12 @@
 #
 # SPDX-License-Identifier: GPL-3.0-or-later
 
-import importlib.metadata
-import inspect
-import pathlib
-import typing
-
 import click
 from clapper.click import AliasedGroup, verbosity_option
 from clapper.logging import setup
+from mednet.libs.common.scripts.config import copy as copy_
+from mednet.libs.common.scripts.config import describe as describe_
+from mednet.libs.common.scripts.config import list_ as list__
 
 logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
 
@@ -45,54 +43,7 @@ def config():
 @verbosity_option(logger=logger)
 def list_(verbose) -> None:  # numpydoc ignore=PR01
     """List configuration files installed."""
-    entry_points = importlib.metadata.entry_points().select(
-        group="mednet.libs.classification.config",
-    )
-    entry_point_dict = {k.name: k for k in entry_points}
-
-    # all potential modules with configuration resources
-    modules = {k.module.rsplit(".", 1)[0] for k in entry_point_dict.values()}
-
-    # sort data entries by originating module
-    entry_points_by_module: dict[str, dict[str, typing.Any]] = {}
-    for k in modules:
-        entry_points_by_module[k] = {}
-        for name, ep in entry_point_dict.items():
-            if ep.module.rsplit(".", 1)[0] == k:
-                entry_points_by_module[k][name] = ep
-
-    for config_type in sorted(entry_points_by_module):
-        # calculates the longest config name so we offset the printing
-        longest_name_length = max(
-            len(k) for k in entry_points_by_module[config_type].keys()
-        )
-
-        # set-up printing options
-        print_string = "  %%-%ds   %%s" % (longest_name_length,)
-        # 79 - 4 spaces = 75 (see string above)
-        description_leftover = 75 - longest_name_length
-
-        click.echo(f"module: {config_type}")
-        for name in sorted(entry_points_by_module[config_type]):
-            ep = entry_point_dict[name]
-
-            if verbose >= 1:
-                module = ep.load()
-                doc = inspect.getdoc(module)
-                if doc is not None:
-                    summary = doc.split("\n\n")[0]
-                else:
-                    summary = "<DOCSTRING NOT AVAILABLE>"
-            else:
-                summary = ""
-
-            summary = (
-                (summary[: (description_leftover - 3)] + "...")
-                if len(summary) > (description_leftover - 3)
-                else summary
-            )
-
-            click.echo(print_string % (name, summary))
+    list__("mednet.libs.classification.config", verbose)
 
 
 @config.command(
@@ -124,29 +75,7 @@ def list_(verbose) -> None:  # numpydoc ignore=PR01
 @verbosity_option(logger=logger)
 def describe(name, verbose) -> None:  # numpydoc ignore=PR01
     """Describe a specific configuration file."""
-    entry_points = importlib.metadata.entry_points().select(
-        group="mednet.libs.classification.config",
-    )
-    entry_point_dict = {k.name: k for k in entry_points}
-
-    for k in name:
-        if k not in entry_point_dict:
-            logger.error("Cannot find configuration resource '%s'", k)
-            continue
-        ep = entry_point_dict[k]
-        click.echo(f"Configuration: {ep.name}")
-        click.echo(f"Python Module: {ep.module}")
-        click.echo("")
-        mod = ep.load()
-
-        if verbose >= 1:
-            fname = inspect.getfile(mod)
-            click.echo("Contents:")
-            with pathlib.Path(fname).open() as f:
-                click.echo(f.read())
-        else:  # only output documentation
-            click.echo("Documentation:")
-            click.echo(inspect.getdoc(mod))
+    describe_(name, "mednet.libs.classification.config", verbose)
 
 
 @config.command(
@@ -175,19 +104,4 @@ def describe(name, verbose) -> None:  # numpydoc ignore=PR01
 @verbosity_option(logger=logger, expose_value=False)
 def copy(source, destination) -> None:  # numpydoc ignore=PR01
     """Copy a specific configuration resource so it can be modified locally."""
-    import shutil
-
-    entry_points = importlib.metadata.entry_points().select(
-        group="mednet.libs.classification.config",
-    )
-    entry_point_dict = {k.name: k for k in entry_points}
-
-    if source not in entry_point_dict:
-        logger.error("Cannot find configuration resource '%s'", source)
-        return
-
-    ep = entry_point_dict[source]
-    mod = ep.load()
-    src_name = inspect.getfile(mod)
-    logger.info(f"cp {src_name} -> {destination}")
-    shutil.copyfile(src_name, destination)
+    copy_(source, destination, "mednet.libs.classification.config")
diff --git a/src/mednet/libs/classification/scripts/database.py b/src/mednet/libs/classification/scripts/database.py
index d918ad6d..b76f2e69 100644
--- a/src/mednet/libs/classification/scripts/database.py
+++ b/src/mednet/libs/classification/scripts/database.py
@@ -5,6 +5,8 @@
 import click
 from clapper.click import AliasedGroup, verbosity_option
 from clapper.logging import setup
+from mednet.libs.common.scripts.database import check as check_
+from mednet.libs.common.scripts.database import list_ as list__
 
 logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
 
@@ -93,18 +95,8 @@ def database() -> None:
 @verbosity_option(logger=logger, expose_value=False)
 def list_():
     """List all supported and configured databases."""
-    config = _get_raw_databases()
 
-    click.echo("Available databases:")
-    for k, v in config.items():
-        if "datadir" not in v:
-            # this database does not have a "datadir"
-            continue
-
-        if v["datadir"] is not None:
-            click.secho(f'- {k} ({v["module"]}): "{v["datadir"]}"', fg="green")
-        else:
-            click.echo(f'- {k} ({v["module"]}): NOT installed')
+    list__(_get_raw_databases())
 
 
 @database.command(
@@ -135,62 +127,4 @@ def list_():
 @verbosity_option(logger=logger, expose_value=False)
 def check(split, limit):  # numpydoc ignore=PR01
     """Check file access on one or more DataModules."""
-    import importlib.metadata
-    import sys
-
-    click.secho(f"Checking split `{split}`...", fg="yellow")
-    try:
-        module = importlib.metadata.entry_points(
-            group="mednet.libs.classification.config"
-        )[split].module
-    except KeyError:
-        raise Exception(f"Could not find database split `{split}`")
-
-    datamodule = importlib.import_module(module).datamodule
-
-    datamodule.model_transforms = []  # should be done before setup()
-    datamodule.batch_size = 1  # ensure one sample is loaded at a time
-    datamodule.setup("predict")  # sets up all datasets
-
-    loaders = datamodule.predict_dataloader()
-
-    errors = 0
-    for k, loader in loaders.items():
-        if limit == 0:
-            click.secho(
-                f"Checking all samples of dataset `{k}` at split `{split}`...",
-                fg="yellow",
-            )
-            loader_limit = sys.maxsize
-        else:
-            click.secho(
-                f"Checking first {limit} samples of dataset "
-                f"`{k}` at split `{split}`...",
-                fg="yellow",
-            )
-            loader_limit = limit
-        # the for loop will trigger raw data loading (ie. user code), protect
-        # it
-        try:
-            for i, batch in enumerate(loader):
-                if loader_limit == 0:
-                    break
-                logger.info(
-                    f"{batch[1]['name'][0]}: "
-                    f"{[s for s in batch[0][0].shape]}@{batch[0][0].dtype}",
-                )
-                loader_limit -= 1
-        except Exception:
-            logger.exception(f"Unable to load batch {i} in dataset {k}")
-            errors += 1
-
-    if not errors:
-        click.secho(
-            f"OK! No errors were reported for database split `{split}`.",
-            fg="green",
-        )
-    else:
-        click.secho(
-            f"Found {errors} errors loading DataModule `{split}`.",
-            fg="red",
-        )
+    check_("mednet.libs.classification.config", split, limit)
diff --git a/src/mednet/libs/common/scripts/config.py b/src/mednet/libs/common/scripts/config.py
new file mode 100644
index 00000000..a813e90f
--- /dev/null
+++ b/src/mednet/libs/common/scripts/config.py
@@ -0,0 +1,115 @@
+# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
+#
+# SPDX-License-Identifier: GPL-3.0-or-later
+
+import importlib.metadata
+import inspect
+import pathlib
+import typing
+
+import click
+from clapper.logging import setup
+
+logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
+
+
+def list_(entry_point_group, verbose) -> None:  # numpydoc ignore=PR01
+    """List configuration files installed."""
+
+    entry_points = importlib.metadata.entry_points().select(
+        group=entry_point_group,
+    )
+    entry_point_dict = {k.name: k for k in entry_points}
+
+    # all potential modules with configuration resources
+    modules = {k.module.rsplit(".", 1)[0] for k in entry_point_dict.values()}
+
+    # sort data entries by originating module
+    entry_points_by_module: dict[str, dict[str, typing.Any]] = {}
+    for k in modules:
+        entry_points_by_module[k] = {}
+        for name, ep in entry_point_dict.items():
+            if ep.module.rsplit(".", 1)[0] == k:
+                entry_points_by_module[k][name] = ep
+
+    for config_type in sorted(entry_points_by_module):
+        # calculates the longest config name so we offset the printing
+        longest_name_length = max(
+            len(k) for k in entry_points_by_module[config_type].keys()
+        )
+
+        # set-up printing options
+        print_string = "  %%-%ds   %%s" % (longest_name_length,)
+        # 79 - 4 spaces = 75 (see string above)
+        description_leftover = 75 - longest_name_length
+
+        click.echo(f"module: {config_type}")
+        for name in sorted(entry_points_by_module[config_type]):
+            ep = entry_point_dict[name]
+
+            if verbose >= 1:
+                module = ep.load()
+                doc = inspect.getdoc(module)
+                if doc is not None:
+                    summary = doc.split("\n\n")[0]
+                else:
+                    summary = "<DOCSTRING NOT AVAILABLE>"
+            else:
+                summary = ""
+
+            summary = (
+                (summary[: (description_leftover - 3)] + "...")
+                if len(summary) > (description_leftover - 3)
+                else summary
+            )
+
+            click.echo(print_string % (name, summary))
+
+
+def describe(name, entry_point_group, verbose) -> None:  # numpydoc ignore=PR01
+    """Describe a specific configuration file."""
+    entry_points = importlib.metadata.entry_points().select(
+        group=entry_point_group,
+    )
+    entry_point_dict = {k.name: k for k in entry_points}
+
+    for k in name:
+        if k not in entry_point_dict:
+            logger.error("Cannot find configuration resource '%s'", k)
+            continue
+        ep = entry_point_dict[k]
+        click.echo(f"Configuration: {ep.name}")
+        click.echo(f"Python Module: {ep.module}")
+        click.echo("")
+        mod = ep.load()
+
+        if verbose >= 1:
+            fname = inspect.getfile(mod)
+            click.echo("Contents:")
+            with pathlib.Path(fname).open() as f:
+                click.echo(f.read())
+        else:  # only output documentation
+            click.echo("Documentation:")
+            click.echo(inspect.getdoc(mod))
+
+
+def copy(
+    source, destination, entry_point_group
+) -> None:  # numpydoc ignore=PR01
+    """Copy a specific configuration resource so it can be modified locally."""
+    import shutil
+
+    entry_points = importlib.metadata.entry_points().select(
+        group=entry_point_group,
+    )
+    entry_point_dict = {k.name: k for k in entry_points}
+
+    if source not in entry_point_dict:
+        logger.error("Cannot find configuration resource '%s'", source)
+        return
+
+    ep = entry_point_dict[source]
+    mod = ep.load()
+    src_name = inspect.getfile(mod)
+    logger.info(f"cp {src_name} -> {destination}")
+    shutil.copyfile(src_name, destination)
diff --git a/src/mednet/libs/common/scripts/database.py b/src/mednet/libs/common/scripts/database.py
new file mode 100644
index 00000000..755e6afd
--- /dev/null
+++ b/src/mednet/libs/common/scripts/database.py
@@ -0,0 +1,98 @@
+# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
+#
+# SPDX-License-Identifier: GPL-3.0-or-later
+
+import click
+from clapper.logging import setup
+
+logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
+
+
+def list_(config: dict[str, dict[str, str]]) -> None:
+    """List all supported and configured databases.
+
+    Parameters
+    ----------
+    config
+        Dictionary where keys are database names, and values are dictionaries
+        containing two string keys:
+
+        * ``module``: the full Pythonic module name (e.g.
+        ``mednet.libs.classification.data.montgomery``).
+        * ``datadir``: points to the user-configured data directory for the
+        current dataset, if set, or ``None`` otherwise.
+    """
+
+    click.echo("Available databases:")
+    for k, v in config.items():
+        if "datadir" not in v:
+            # this database does not have a "datadir"
+            continue
+
+        if v["datadir"] is not None:
+            click.secho(f'- {k} ({v["module"]}): "{v["datadir"]}"', fg="green")
+        else:
+            click.echo(f'- {k} ({v["module"]}): NOT installed')
+
+
+def check(entry_point_group, split, limit):  # numpydoc ignore=PR01
+    """Check file access on one or more DataModules."""
+    import importlib.metadata
+    import sys
+
+    click.secho(f"Checking split `{split}`...", fg="yellow")
+    try:
+        module = importlib.metadata.entry_points(group=entry_point_group)[
+            split
+        ].module
+    except KeyError:
+        raise Exception(f"Could not find database split `{split}`")
+
+    datamodule = importlib.import_module(module).datamodule
+
+    datamodule.model_transforms = []  # should be done before setup()
+    datamodule.batch_size = 1  # ensure one sample is loaded at a time
+    datamodule.setup("predict")  # sets up all datasets
+
+    loaders = datamodule.predict_dataloader()
+
+    errors = 0
+    for k, loader in loaders.items():
+        if limit == 0:
+            click.secho(
+                f"Checking all samples of dataset `{k}` at split `{split}`...",
+                fg="yellow",
+            )
+            loader_limit = sys.maxsize
+        else:
+            click.secho(
+                f"Checking first {limit} samples of dataset "
+                f"`{k}` at split `{split}`...",
+                fg="yellow",
+            )
+            loader_limit = limit
+        # the for loop will trigger raw data loading (ie. user code), protect
+        # it
+        try:
+            for i, batch in enumerate(loader):
+                if loader_limit == 0:
+                    break
+                logger.info(
+                    f"{batch[1]['name'][0]}: "
+                    f"{[s for s in batch[0][0].shape]}@{batch[0][0].dtype}",
+                )
+                loader_limit -= 1
+        except Exception:
+            logger.exception(f"Unable to load batch {i} in dataset {k}")
+            errors += 1
+
+    if not errors:
+        click.secho(
+            f"OK! No errors were reported for database split `{split}`.",
+            fg="green",
+        )
+    else:
+        click.secho(
+            f"Found {errors} errors loading DataModule `{split}`.",
+            fg="red",
+        )
diff --git a/src/mednet/libs/segmentation/scripts/cli.py b/src/mednet/libs/segmentation/scripts/cli.py
index add5fdac..82aaed41 100644
--- a/src/mednet/libs/segmentation/scripts/cli.py
+++ b/src/mednet/libs/segmentation/scripts/cli.py
@@ -8,8 +8,8 @@ from clapper.click import AliasedGroup
 from . import (
     # analyze,
     # compare,
-    # config,
-    # dataset,
+    config,
+    database,
     # evaluate,
     # experiment,
     # mkmask,
@@ -30,8 +30,8 @@ def segmentation():
 
 # segmentation.add_command(analyze.analyze)
 # segmentation.add_command(compare.compare)
-# segmentation.add_command(config.config)
-# segmentation.add_command(dataset.dataset)
+segmentation.add_command(config.config)
+segmentation.add_command(database.database)
 # segmentation.add_command(evaluate.evaluate)
 # segmentation.add_command(experiment.experiment)
 # segmentation.add_command(mkmask.mkmask)
diff --git a/src/mednet/libs/segmentation/scripts/config.py b/src/mednet/libs/segmentation/scripts/config.py
new file mode 100644
index 00000000..b900ba86
--- /dev/null
+++ b/src/mednet/libs/segmentation/scripts/config.py
@@ -0,0 +1,107 @@
+# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
+#
+# SPDX-License-Identifier: GPL-3.0-or-later
+
+import click
+from clapper.click import AliasedGroup, verbosity_option
+from clapper.logging import setup
+from mednet.libs.common.scripts.config import copy as copy_
+from mednet.libs.common.scripts.config import describe as describe_
+from mednet.libs.common.scripts.config import list_ as list__
+
+logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
+
+
+@click.group(cls=AliasedGroup)
+def config():
+    """Command for listing, describing and copying configuration resources."""
+    pass
+
+
+@config.command(
+    name="list",
+    epilog="""Examples:
+
+\b
+  1. Lists all configuration resources (type: mednet.libs.classification.config) installed:
+
+     .. code:: sh
+
+        mednet config list
+
+
+\b
+  2. Lists all configuration resources and their descriptions (notice this may
+     be slow as it needs to load all modules once):
+
+     .. code:: sh
+
+        mednet config list -v
+
+""",
+)
+@verbosity_option(logger=logger)
+def list_(verbose) -> None:  # numpydoc ignore=PR01
+    """List configuration files installed."""
+    list__("mednet.libs.segmentation.config", verbose)
+
+
+@config.command(
+    epilog="""Examples:
+
+\b
+  1. Describe the Montgomery dataset configuration:
+
+     .. code:: sh
+
+        mednet config describe montgomery
+
+
+\b
+  2. Describe the Montgomery dataset configuration and lists its
+     contents:
+
+     .. code:: sh
+
+        mednet config describe montgomery -v
+
+""",
+)
+@click.argument(
+    "name",
+    required=True,
+    nargs=-1,
+)
+@verbosity_option(logger=logger)
+def describe(name, verbose) -> None:  # numpydoc ignore=PR01
+    """Describe a specific configuration file."""
+    describe_(name, "mednet.libs.segmentation.config", verbose)
+
+
+@config.command(
+    epilog="""Examples:
+
+\b
+  1. Make a copy of one of the stock configuration files locally, so it can be
+     adapted:
+
+     .. code:: sh
+
+        $ mednet config copy montgomery -vvv newdataset.py
+
+""",
+)
+@click.argument(
+    "source",
+    required=True,
+    nargs=1,
+)
+@click.argument(
+    "destination",
+    required=True,
+    nargs=1,
+)
+@verbosity_option(logger=logger, expose_value=False)
+def copy(source, destination) -> None:  # numpydoc ignore=PR01
+    """Copy a specific configuration resource so it can be modified locally."""
+    copy_(source, destination, "mednet.libs.segmentation.config")
diff --git a/src/mednet/libs/segmentation/scripts/database.py b/src/mednet/libs/segmentation/scripts/database.py
new file mode 100644
index 00000000..2d89b4f5
--- /dev/null
+++ b/src/mednet/libs/segmentation/scripts/database.py
@@ -0,0 +1,129 @@
+# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
+#
+# SPDX-License-Identifier: GPL-3.0-or-later
+
+import click
+from clapper.click import AliasedGroup, verbosity_option
+from clapper.logging import setup
+from mednet.libs.common.scripts.database import check as check_
+from mednet.libs.common.scripts.database import list_ as list__
+
+logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
+
+
+def _get_raw_databases() -> dict[str, dict[str, str]]:
+    """Return a list of all supported (raw) databases.
+
+    Returns
+    -------
+    dict[str, dict[str, str]]
+        Dictionary where keys are database names, and values are dictionaries
+        containing two string keys:
+
+        * ``module``: the full Pythonic module name (e.g.
+        ``mednet.libs.classification.data.montgomery``).
+        * ``datadir``: points to the user-configured data directory for the
+        current dataset, if set, or ``None`` otherwise.
+    """
+
+    import importlib
+    import pkgutil
+
+    from ..config import data
+    from ..utils.rc import load_rc
+
+    user_configuration = load_rc()
+
+    retval = {}
+    for k in pkgutil.iter_modules(data.__path__):
+        for j in pkgutil.iter_modules(
+            [next(iter(data.__path__)) + f"/{k.name}"],
+        ):
+            if j.name == "datamodule":
+                # this is a submodule that can read raw data files
+                module = importlib.import_module(
+                    f".{j.name}",
+                    data.__package__ + f".{k.name}",
+                )
+                if hasattr(module, "CONFIGURATION_KEY_DATADIR"):
+                    retval[k.name] = dict(
+                        module=module.__name__.rsplit(".", 1)[0],
+                        datadir=user_configuration.get(
+                            module.CONFIGURATION_KEY_DATADIR,
+                        ),
+                    )
+                else:
+                    retval[k.name] = dict(module=module.__name__)
+
+    return retval
+
+
+@click.group(cls=AliasedGroup)
+def database() -> None:
+    """Command for listing and verifying databases installed."""
+    pass
+
+
+@database.command(
+    name="list",
+    epilog="""Examples:
+
+\b
+    1. To install a database, set up its data directory ("datadir").  For
+       example, to setup access to Montgomery files you downloaded locally at
+       the directory "/path/to/montgomery/files", edit the RC file (typically
+       ``$HOME/.config/mednet.libs.classification.toml``), and add a line like the following:
+
+       .. code:: toml
+
+          [datadir]
+          montgomery = "/path/to/montgomery/files"
+
+       .. note::
+
+          This setting **is** case-sensitive.
+
+\b
+    2. List all raw databases supported (and configured):
+
+       .. code:: sh
+
+          $ mednet database list
+
+""",
+)
+@verbosity_option(logger=logger, expose_value=False)
+def list_():
+    """List all supported and configured databases."""
+    list__(_get_raw_databases())
+
+
+@database.command(
+    epilog="""Examples:
+
+    1. Check if all files from the split 'montgomery-f0' of the Montgomery
+       database can be loaded:
+
+       .. code:: sh
+
+          mednet datamodule check -vv montgomery-f0
+
+""",
+)
+@click.argument(
+    "split",
+    nargs=1,
+)
+@click.option(
+    "--limit",
+    "-l",
+    help="Limit check to the first N samples in each split dataset, making the "
+    "check sensibly faster.  Set it to zero (default) to check everything.",
+    required=True,
+    type=click.IntRange(0),
+    default=0,
+)
+@verbosity_option(logger=logger, expose_value=False)
+def check(split, limit):  # numpydoc ignore=PR01
+    """Check file access on one or more DataModules."""
+    check_("mednet.libs.segmentation.config", split, limit)
-- 
GitLab