From 42f897493d1569cea3c1be87ff7bf05fc46e415e Mon Sep 17 00:00:00 2001
From: Andre Anjos <andre.dos.anjos@gmail.com>
Date: Tue, 8 Aug 2023 08:35:45 +0200
Subject: [PATCH] [cli] Refactor config and database (old datamodule) commands

---
 src/ptbench/scripts/cli.py        |   4 +-
 src/ptbench/scripts/config.py     |  11 +-
 src/ptbench/scripts/database.py   | 195 ++++++++++++++++++++++++++++++
 src/ptbench/scripts/datamodule.py | 152 -----------------------
 tests/test_cli.py                 |  25 ++--
 5 files changed, 210 insertions(+), 177 deletions(-)
 create mode 100644 src/ptbench/scripts/database.py
 delete mode 100644 src/ptbench/scripts/datamodule.py

diff --git a/src/ptbench/scripts/cli.py b/src/ptbench/scripts/cli.py
index 479d6b6b..a84c9d5b 100644
--- a/src/ptbench/scripts/cli.py
+++ b/src/ptbench/scripts/cli.py
@@ -10,7 +10,7 @@ from . import (
     aggregpred,
     compare,
     config,
-    datamodule,
+    database,
     evaluate,
     experiment,
     predict,
@@ -32,7 +32,7 @@ def cli():
 cli.add_command(aggregpred.aggregpred)
 cli.add_command(compare.compare)
 cli.add_command(config.config)
-cli.add_command(datamodule.datamodule)
+cli.add_command(database.database)
 cli.add_command(evaluate.evaluate)
 cli.add_command(experiment.experiment)
 cli.add_command(predict.predict)
diff --git a/src/ptbench/scripts/config.py b/src/ptbench/scripts/config.py
index a2d49dbf..77a0b6ed 100644
--- a/src/ptbench/scripts/config.py
+++ b/src/ptbench/scripts/config.py
@@ -51,22 +51,15 @@ def list(verbose) -> None:
     )
     entry_point_dict = {k.name: k for k in entry_points}
 
-    # all modules with configuration resources
+    # all potential modules with configuration resources
     modules = {k.module.rsplit(".", 1)[0] for k in entry_point_dict.values()}
-    keep_modules: set[str] = set()
-    for k in sorted(modules):
-        if k not in keep_modules and not any(
-            k.startswith(element) for element in keep_modules
-        ):
-            keep_modules.add(k)
-    modules = keep_modules
 
     # sort data entries by originating module
     entry_points_by_module: dict[str, dict[str, typing.Any]] = {}
     for k in modules:
         entry_points_by_module[k] = {}
         for name, ep in entry_point_dict.items():
-            if ep.module.startswith(k):
+            if ep.module.rsplit(".", 1)[0] == k:
                 entry_points_by_module[k][name] = ep
 
     for config_type in sorted(entry_points_by_module):
diff --git a/src/ptbench/scripts/database.py b/src/ptbench/scripts/database.py
new file mode 100644
index 00000000..e73595a3
--- /dev/null
+++ b/src/ptbench/scripts/database.py
@@ -0,0 +1,195 @@
+# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
+#
+# SPDX-License-Identifier: GPL-3.0-or-later
+
+from __future__ import annotations
+
+import click
+
+from clapper.click import AliasedGroup, verbosity_option
+from clapper.logging import setup
+
+logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
+
+
+def _get_raw_databases() -> dict[str, dict[str, str]]:
+    """Returns a dictionary of all supported (raw) databases.
+
+    Returns
+    -------
+    d
+        Dictionary where keys are database names, and values are dictionaries
+        containing two string keys:
+
+        * ``module``: the full Pythonic module name (e.g.
+          ``ptbench.data.montgomery``)
+        * ``datadir``: points to the user-configured data directory for the
+          current dataset, if set, or ``None`` otherwise.
+    """
+    import importlib
+    import pkgutil
+
+    from ..config import data
+    from ..utils.rc import load_rc
+
+    user_configuration = load_rc()
+
+    retval = {}
+    for k in pkgutil.iter_modules(data.__path__):
+        for j in pkgutil.iter_modules(
+            [next(iter(data.__path__)) + f"/{k.name}"]
+        ):
+            if j.name == "datamodule":
+                # this is a submodule that can read raw data files
+                module = importlib.import_module(
+                    f".{j.name}", data.__package__ + f".{k.name}"
+                )
+                if hasattr(module, "CONFIGURATION_KEY_DATADIR"):
+                    retval[k.name] = dict(
+                        module=module.__name__.rsplit(".", 1)[0],
+                        datadir=user_configuration.get(
+                            module.CONFIGURATION_KEY_DATADIR
+                        ),
+                    )
+                else:
+                    retval[k.name] = dict(module=module.__name__)
+
+    return retval
+
+
+@click.group(cls=AliasedGroup)
+def database() -> None:
+    """Commands for listing and verifying databases installed."""
+    pass
+
+
+@database.command(
+    epilog="""Examples:
+
+\b
+    1. To install a database, set up its data directory ("datadir"). For
+       example, to setup access to Montgomery files you downloaded locally at
+       the directory "/path/to/montgomery/files", edit the RC file (typically
+       ``$HOME/.config/ptbench.toml``), and add a line like the following:
+
+       .. code:: toml
+
+          [datadir]
+          montgomery = "/path/to/montgomery/files"
+
+       .. note::
+
+          This setting **is** case-sensitive.
+
+\b
+    2. List all raw databases supported (and configured):
+
+       .. code:: sh
+
+          $ ptbench database list
+
+""",
+)
+@verbosity_option(logger=logger, expose_value=False)
+def list():
+    """Lists all supported and configured databases."""
+    config = _get_raw_databases()
+
+    click.echo("Available databases:")
+    for k, v in config.items():
+        if "datadir" not in v:
+            # this database does not have a "datadir"
+            continue
+
+        if v["datadir"] is not None:
+            click.secho(f'- {k} ({v["module"]}): "{v["datadir"]}"', fg="green")
+        else:
+            click.echo(f'- {k} ({v["module"]}): NOT installed')
+
+
+@database.command(
+    epilog="""Examples:
+
+    1. Check if all files from the split 'montgomery-f0' of the Montgomery
+       database can be loaded:
+
+       .. code:: sh
+
+          ptbench database check -vv montgomery-f0
+
+""",
+)
+@click.argument(
+    "split",
+    nargs=1,
+)
+@click.option(
+    "--limit",
+    "-l",
+    help="Limit check to the first N samples in each split dataset, making the "
+    "check sensibly faster. Set it to zero (default) to check everything.",
+    required=True,
+    type=click.IntRange(0),
+    default=0,
+)
+@verbosity_option(logger=logger, expose_value=False)
+def check(split, limit):
+    """Checks file access on one database split."""
+    import importlib.metadata
+    import sys
+
+    click.secho(f"Checking split `{split}`...", fg="yellow")
+    try:
+        module = importlib.metadata.entry_points(group="ptbench.config")[
+            split
+        ].module
+    except KeyError:
+        raise Exception(f"Could not find database split `{split}`")
+
+    datamodule = importlib.import_module(module).datamodule
+
+    datamodule.model_transforms = []  # should be done before setup()
+    datamodule.batch_size = 1  # ensure one sample is loaded at a time
+    datamodule.setup("predict")  # sets up all datasets
+
+    loaders = datamodule.predict_dataloader()
+
+    errors = 0
+    for k, loader in loaders.items():
+        if limit == 0:
+            click.secho(
+                f"Checking all samples of dataset `{k}` at split `{split}`...",
+                fg="yellow",
+            )
+            loader_limit = sys.maxsize
+        else:
+            click.secho(
+                f"Checking first {limit} samples of dataset "
+                f"`{k}` at split `{split}`...",
+                fg="yellow",
+            )
+            loader_limit = limit
+        # the for loop will trigger raw data loading (ie. user code), protect
+        # it
+        try:
+            for i, batch in enumerate(loader):
+                if loader_limit == 0:
+                    break
+                logger.info(
+                    f"{batch[1]['name'][0]}: "
+                    f"{[s for s in batch[0][0].shape]}@{batch[0][0].dtype}"
+                )
+                loader_limit -= 1
+        except Exception:
+            logger.exception(f"Unable to load batch {i} in dataset {k}")
+            errors += 1
+
+    if not errors:
+        click.secho(
+            f"OK! No errors were reported for database split `{split}`.",
+            fg="green",
+        )
+    else:
+        click.secho(
+            f"Found {errors} errors loading datamodule `{split}`.", fg="red"
+        )
diff --git a/src/ptbench/scripts/datamodule.py b/src/ptbench/scripts/datamodule.py
deleted file mode 100644
index 66681ea7..00000000
--- a/src/ptbench/scripts/datamodule.py
+++ /dev/null
@@ -1,152 +0,0 @@
-# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
-#
-# SPDX-License-Identifier: GPL-3.0-or-later
-
-from __future__ import annotations
-
-import importlib.metadata
-import importlib.resources
-
-import click
-
-from clapper.click import AliasedGroup, verbosity_option
-from clapper.logging import setup
-
-from ..data.split import check_database_split_loading
-
-logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
-
-
-def _get_installed_datamodules() -> dict[str, str]:
-    """Returns a list of all installed datamodules.
-
-    Returns
-    -------
-
-    datamodules:
-        List of datamodules.
-    """
-    entrypoints = sorted(
-        [
-            entrypoint
-            for entrypoint in importlib.metadata.entry_points(
-                group="ptbench.config"
-            ).names
-        ]
-    )
-
-    datamodules = [
-        importlib.metadata.entry_points(group="ptbench.config")[
-            entrypoint
-        ].module
-        for entrypoint in entrypoints
-    ]
-
-    datamodules_dict = {
-        entrypoints[i]: datamodules[i]
-        for i in range(len(entrypoints))
-        if datamodules[i].split(".")[1] == "data"
-    }
-
-    return datamodules_dict
-
-
-@click.group(cls=AliasedGroup)
-def datamodule() -> None:
-    """Commands for listing and verifying datamodules."""
-    pass
-
-
-@datamodule.command(
-    epilog="""Examples:
-
-\b
-    1. To install a dataset, set up its data directory ("datadir"). For
-       example, to setup access to Montgomery files you downloaded locally at
-       the directory "/path/to/montgomery/files", edit the RC file (typically
-       ``$HOME/.config/ptbench.toml``), and add a line like the following:
-
-       .. code:: toml
-
-          [datadir]
-          montgomery = "/path/to/montgomery/files"
-
-       .. note::
-
-          This setting **is** case-sensitive.
-
-\b
-    2. List all raw datamodule supported (and configured):
-
-       .. code:: sh
-
-          $ ptbench datamodule list
-
-""",
-)
-@verbosity_option(logger=logger, expose_value=False)
-def list():
-    """Lists all supported and configured datamodules."""
-    installed = _get_installed_datamodules()
-
-    click.echo("Available datamodules:")
-    for k, v in installed.items():
-        click.echo(f'- {k}: "{v}"')
-
-
-@datamodule.command(
-    epilog="""Examples:
-
-    1. Check if all files from the fold_0 of the Montgomery database can be loaded:
-
-       .. code:: sh
-
-          ptbench datamodule check -vv montgomery_f0
-
-    2. Check if all files of multiple installed protocols can be loaded:
-
-       .. code:: sh
-
-          ptbench datamodule check -vv montgomery shenzhen
-
-""",
-)
-@click.argument(
-    "protocols",
-    nargs=-1,
-)
-@click.option(
-    "--limit",
-    "-l",
-    help="Limit check to the first N samples in each datamodule, making the "
-    "check sensibly faster. Set it to zero to check everything.",
-    required=True,
-    type=click.IntRange(0),
-    default=0,
-)
-@verbosity_option(logger=logger, expose_value=False)
-def check(protocols, limit):
-    """Checks file access on one or more datamodules."""
-    import importlib
-
-    errors = 0
-    for protocol in protocols:
-        logger.info(f"Checking {protocol}")
-        try:
-            module = importlib.metadata.entry_points(group="ptbench.config")[
-                protocol
-            ].module
-        except KeyError:
-            raise Exception(f"Could not find protocol {protocol}")
-
-        datamodule = importlib.import_module(module).datamodule
-
-        database_split = datamodule.database_split
-        raw_data_loader = datamodule.raw_data_loader
-
-        errors += check_database_split_loading(
-            database_split, raw_data_loader, limit=limit
-        )
-
-    if not errors:
-        click.echo("No errors reported")
diff --git a/tests/test_cli.py b/tests/test_cli.py
index 04fa19e4..e1a53b12 100644
--- a/tests/test_cli.py
+++ b/tests/test_cli.py
@@ -55,7 +55,6 @@ def test_config_list_help():
     _check_help(list)
 
 
-@pytest.mark.skip(reason="Test need to be updated")
 def test_config_list():
     from ptbench.scripts.config import list
 
@@ -66,7 +65,6 @@ def test_config_list():
     assert "module: ptbench.config.models" in result.output
 
 
-@pytest.mark.skip(reason="Test need to be updated")
 def test_config_list_v():
     from ptbench.scripts.config import list
 
@@ -82,7 +80,6 @@ def test_config_describe_help():
     _check_help(describe)
 
 
-@pytest.mark.skip(reason="Test need to be updated")
 @pytest.mark.skip_if_rc_var_not_set("datadir.montgomery")
 def test_config_describe_montgomery():
     from ptbench.scripts.config import describe
@@ -90,42 +87,42 @@ def test_config_describe_montgomery():
     runner = CliRunner()
     result = runner.invoke(describe, ["montgomery"])
     _assert_exit_0(result)
-    assert "montgomery dataset for TB detection" in result.output
+    assert "Montgomery datamodule for TB detection." in result.output
 
 
-def test_datamodule_help():
-    from ptbench.scripts.datamodule import datamodule
+def test_database_help():
+    from ptbench.scripts.database import database
 
-    _check_help(datamodule)
+    _check_help(database)
 
 
 def test_datamodule_list_help():
-    from ptbench.scripts.datamodule import list
+    from ptbench.scripts.database import list
 
     _check_help(list)
 
 
 def test_datamodule_list():
-    from ptbench.scripts.datamodule import list
+    from ptbench.scripts.database import list
 
     runner = CliRunner()
     result = runner.invoke(list)
     _assert_exit_0(result)
-    assert result.output.startswith("Available datamodules:")
+    assert result.output.startswith("Available databases:")
 
 
 def test_datamodule_check_help():
-    from ptbench.scripts.datamodule import check
+    from ptbench.scripts.database import check
 
     _check_help(check)
 
 
 @pytest.mark.skip_if_rc_var_not_set("datadir.montgomery")
-def test_datamodule_check():
-    from ptbench.scripts.datamodule import check
+def test_database_check():
+    from ptbench.scripts.database import check
 
     runner = CliRunner()
-    result = runner.invoke(check, ["--verbose", "--limit=2"])
+    result = runner.invoke(check, ["--verbose", "--limit=1", "montgomery"])
     _assert_exit_0(result)
 
 
-- 
GitLab
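
Not part of the patch itself: a minimal sketch of how the refactored commands can be
smoke-tested outside of pytest, driving them with Click's test runner exactly as
tests/test_cli.py above does. It assumes the package is installed and that
``datadir.montgomery`` is configured in ``$HOME/.config/ptbench.toml`` (see the epilog of
``database list``); the ``montgomery`` split name is taken from the examples in this patch.

.. code:: python

   from click.testing import CliRunner

   from ptbench.scripts.database import check
   from ptbench.scripts.database import list as database_list

   runner = CliRunner()

   # enumerate supported databases; configured ones are listed with their datadir
   result = runner.invoke(database_list)
   print(result.output)

   # try to load the first sample of every dataset in the "montgomery" split;
   # exit code 0 only means the command did not crash, loading problems are
   # reported per-batch in the captured output
   result = runner.invoke(check, ["--verbose", "--limit=1", "montgomery"])
   assert result.exit_code == 0
   print(result.output)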