Commit 42f89749 authored by André Anjos

[cli] Refactor config and database (old datamodule) commands

parent 331ad40c
1 merge request: !6 Making use of LightningDataModule and simplification of data loading
@@ -10,7 +10,7 @@ from . import (
     aggregpred,
     compare,
     config,
-    datamodule,
+    database,
     evaluate,
     experiment,
     predict,
@@ -32,7 +32,7 @@ def cli():
 cli.add_command(aggregpred.aggregpred)
 cli.add_command(compare.compare)
 cli.add_command(config.config)
-cli.add_command(datamodule.datamodule)
+cli.add_command(database.database)
 cli.add_command(evaluate.evaluate)
 cli.add_command(experiment.experiment)
 cli.add_command(predict.predict)
...
@@ -51,22 +51,15 @@ def list(verbose) -> None:
     )
     entry_point_dict = {k.name: k for k in entry_points}
 
-    # all modules with configuration resources
+    # all potential modules with configuration resources
     modules = {k.module.rsplit(".", 1)[0] for k in entry_point_dict.values()}
-    keep_modules: set[str] = set()
-    for k in sorted(modules):
-        if k not in keep_modules and not any(
-            k.startswith(element) for element in keep_modules
-        ):
-            keep_modules.add(k)
-    modules = keep_modules
 
     # sort data entries by originating module
     entry_points_by_module: dict[str, dict[str, typing.Any]] = {}
     for k in modules:
         entry_points_by_module[k] = {}
         for name, ep in entry_point_dict.items():
-            if ep.module.startswith(k):
+            if ep.module.rsplit(".", 1)[0] == k:
                 entry_points_by_module[k][name] = ep
 
     for config_type in sorted(entry_points_by_module):
...
# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later

from __future__ import annotations

import click

from clapper.click import AliasedGroup, verbosity_option
from clapper.logging import setup

logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")


def _get_raw_databases() -> dict[str, dict[str, str]]:
    """Returns a dictionary describing all supported (raw) databases.

    Returns
    -------

    d
        Dictionary where keys are database names, and values are dictionaries
        containing two string keys:

        * ``module``: the full Pythonic module name (e.g.
          ``ptbench.data.montgomery``)
        * ``datadir``: points to the user-configured data directory for the
          current dataset, if set, or ``None`` otherwise.
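
        For illustration, a returned value may look like this (the path below
        is hypothetical and depends on the user's RC file):

        .. code:: python

           {
               "montgomery": {
                   "module": "ptbench.data.montgomery",
                   "datadir": "/path/to/montgomery/files",
               },
           }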
"""
    import importlib
    import pkgutil

    from ..config import data
    from ..utils.rc import load_rc

    user_configuration = load_rc()

    retval = {}
    for k in pkgutil.iter_modules(data.__path__):
        for j in pkgutil.iter_modules(
            [next(iter(data.__path__)) + f"/{k.name}"]
        ):
            if j.name == "datamodule":
                # this is a submodule that can read raw data files
                module = importlib.import_module(
                    f".{j.name}", data.__package__ + f".{k.name}"
                )
                if hasattr(module, "CONFIGURATION_KEY_DATADIR"):
                    retval[k.name] = dict(
                        module=module.__name__.rsplit(".", 1)[0],
                        datadir=user_configuration.get(
                            module.CONFIGURATION_KEY_DATADIR
                        ),
                    )
                else:
                    retval[k.name] = dict(module=module.__name__)

    return retval


@click.group(cls=AliasedGroup)
def database() -> None:
    """Commands for listing and verifying installed databases."""
    pass


@database.command(
    epilog="""Examples:

\b
    1. To install a database, set up its data directory ("datadir"). For
       example, to set up access to the Montgomery files you downloaded locally
       at the directory "/path/to/montgomery/files", edit the RC file
       (typically ``$HOME/.config/ptbench.toml``), and add a line like the
       following:

       .. code:: toml

          [datadir]
          montgomery = "/path/to/montgomery/files"

       .. note::

          This setting **is** case-sensitive.

\b
    2. List all raw databases supported (and configured):

       .. code:: sh

          $ ptbench database list
""",
)
@verbosity_option(logger=logger, expose_value=False)
def list():
    """Lists all supported and configured databases."""
    config = _get_raw_databases()

    click.echo("Available databases:")
    for k, v in config.items():
        if "datadir" not in v:
            # this database does not have a "datadir"
            continue

        if v["datadir"] is not None:
            click.secho(f'- {k} ({v["module"]}): "{v["datadir"]}"', fg="green")
        else:
            click.echo(f'- {k} ({v["module"]}): NOT installed')
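
# Example output of `ptbench database list` (illustration only; the path below
# is hypothetical and depends on the "datadir" values configured in the RC
# file):
#
#   Available databases:
#   - montgomery (ptbench.data.montgomery): "/path/to/montgomery/files"
#   - shenzhen (ptbench.data.shenzhen): NOT installed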


@database.command(
    epilog="""Examples:

    1. Check if all files from the split 'montgomery-f0' of the Montgomery
       database can be loaded:

       .. code:: sh

          ptbench database check -vv montgomery-f0
""",
)
@click.argument(
    "split",
    nargs=1,
)
@click.option(
    "--limit",
    "-l",
    help="Limit check to the first N samples in each split dataset, making the "
    "check considerably faster. Set it to zero (default) to check everything.",
    required=True,
    type=click.IntRange(0),
    default=0,
)
@verbosity_option(logger=logger, expose_value=False)
def check(split, limit):
    """Checks file access on all samples of a database split."""
    import importlib.metadata
    import sys

    click.secho(f"Checking split `{split}`...", fg="yellow")
    try:
        module = importlib.metadata.entry_points(group="ptbench.config")[
            split
        ].module
    except KeyError:
        raise Exception(f"Could not find database split `{split}`")

    datamodule = importlib.import_module(module).datamodule
    datamodule.model_transforms = []  # should be done before setup()
    datamodule.batch_size = 1  # ensure one sample is loaded at a time
    datamodule.setup("predict")  # sets up all datasets

    loaders = datamodule.predict_dataloader()

    errors = 0
    for k, loader in loaders.items():
        if limit == 0:
            click.secho(
                f"Checking all samples of dataset `{k}` at split `{split}`...",
                fg="yellow",
            )
            loader_limit = sys.maxsize
        else:
            click.secho(
                f"Checking first {limit} samples of dataset "
                f"`{k}` at split `{split}`...",
                fg="yellow",
            )
            loader_limit = limit
        # the for loop will trigger raw data loading (ie. user code), protect
        # it
        try:
            for i, batch in enumerate(loader):
                if loader_limit == 0:
                    break
                logger.info(
                    f"{batch[1]['name'][0]}: "
                    f"{[s for s in batch[0][0].shape]}@{batch[0][0].dtype}"
                )
                loader_limit -= 1
        except Exception:
            logger.exception(f"Unable to load batch {i} in dataset {k}")
            errors += 1

    if not errors:
        click.secho(
            f"OK! No errors were reported for database split `{split}`.",
            fg="green",
        )
    else:
        click.secho(
            f"Found {errors} errors loading datamodule `{split}`.", fg="red"
        )
# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later

from __future__ import annotations

import importlib.metadata
import importlib.resources

import click

from clapper.click import AliasedGroup, verbosity_option
from clapper.logging import setup

from ..data.split import check_database_split_loading

logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")


def _get_installed_datamodules() -> dict[str, str]:
    """Returns a list of all installed datamodules.

    Returns
    -------

    datamodules:
        List of datamodules.
    """
    entrypoints = sorted(
        [
            entrypoint
            for entrypoint in importlib.metadata.entry_points(
                group="ptbench.config"
            ).names
        ]
    )

    datamodules = [
        importlib.metadata.entry_points(group="ptbench.config")[
            entrypoint
        ].module
        for entrypoint in entrypoints
    ]

    datamodules_dict = {
        entrypoints[i]: datamodules[i]
        for i in range(len(entrypoints))
        if datamodules[i].split(".")[1] == "data"
    }

    return datamodules_dict


@click.group(cls=AliasedGroup)
def datamodule() -> None:
    """Commands for listing and verifying datamodules."""
    pass


@datamodule.command(
    epilog="""Examples:

\b
    1. To install a dataset, set up its data directory ("datadir"). For
       example, to setup access to Montgomery files you downloaded locally at
       the directory "/path/to/montgomery/files", edit the RC file (typically
       ``$HOME/.config/ptbench.toml``), and add a line like the following:

       .. code:: toml

          [datadir]
          montgomery = "/path/to/montgomery/files"

       .. note::

          This setting **is** case-sensitive.

\b
    2. List all raw datamodule supported (and configured):

       .. code:: sh

          $ ptbench datamodule list
""",
)
@verbosity_option(logger=logger, expose_value=False)
def list():
    """Lists all supported and configured datamodules."""
    installed = _get_installed_datamodules()

    click.echo("Available datamodules:")
    for k, v in installed.items():
        click.echo(f'- {k}: "{v}"')


@datamodule.command(
    epilog="""Examples:

    1. Check if all files from the fold_0 of the Montgomery database can be loaded:

       .. code:: sh

          ptbench datamodule check -vv montgomery_f0

    2. Check if all files of multiple installed protocols can be loaded:

       .. code:: sh

          ptbench datamodule check -vv montgomery shenzhen
""",
)
@click.argument(
    "protocols",
    nargs=-1,
)
@click.option(
    "--limit",
    "-l",
    help="Limit check to the first N samples in each datamodule, making the "
    "check sensibly faster. Set it to zero to check everything.",
    required=True,
    type=click.IntRange(0),
    default=0,
)
@verbosity_option(logger=logger, expose_value=False)
def check(protocols, limit):
    """Checks file access on one or more datamodules."""
    import importlib

    errors = 0
    for protocol in protocols:
        logger.info(f"Checking {protocol}")
        try:
            module = importlib.metadata.entry_points(group="ptbench.config")[
                protocol
            ].module
        except KeyError:
            raise Exception(f"Could not find protocol {protocol}")

        datamodule = importlib.import_module(module).datamodule
        database_split = datamodule.database_split
        raw_data_loader = datamodule.raw_data_loader

        errors += check_database_split_loading(
            database_split, raw_data_loader, limit=limit
        )

    if not errors:
        click.echo("No errors reported")
@@ -55,7 +55,6 @@ def test_config_list_help():
     _check_help(list)
 
 
-@pytest.mark.skip(reason="Test need to be updated")
 def test_config_list():
     from ptbench.scripts.config import list
 
@@ -66,7 +65,6 @@ def test_config_list():
     assert "module: ptbench.config.models" in result.output
 
 
-@pytest.mark.skip(reason="Test need to be updated")
 def test_config_list_v():
     from ptbench.scripts.config import list
 
@@ -82,7 +80,6 @@ def test_config_describe_help():
     _check_help(describe)
 
 
-@pytest.mark.skip(reason="Test need to be updated")
 @pytest.mark.skip_if_rc_var_not_set("datadir.montgomery")
 def test_config_describe_montgomery():
     from ptbench.scripts.config import describe
 
@@ -90,42 +87,42 @@ def test_config_describe_montgomery():
     runner = CliRunner()
     result = runner.invoke(describe, ["montgomery"])
     _assert_exit_0(result)
-    assert "montgomery dataset for TB detection" in result.output
+    assert "Montgomery datamodule for TB detection." in result.output
 
 
-def test_datamodule_help():
-    from ptbench.scripts.datamodule import datamodule
+def test_database_help():
+    from ptbench.scripts.database import database
 
-    _check_help(datamodule)
+    _check_help(database)
 
 
 def test_datamodule_list_help():
-    from ptbench.scripts.datamodule import list
+    from ptbench.scripts.database import list
 
     _check_help(list)
 
 
 def test_datamodule_list():
-    from ptbench.scripts.datamodule import list
+    from ptbench.scripts.database import list
 
     runner = CliRunner()
     result = runner.invoke(list)
     _assert_exit_0(result)
-    assert result.output.startswith("Available datamodules:")
+    assert result.output.startswith("Available databases:")
 
 
 def test_datamodule_check_help():
-    from ptbench.scripts.datamodule import check
+    from ptbench.scripts.database import check
 
     _check_help(check)
 
 
 @pytest.mark.skip_if_rc_var_not_set("datadir.montgomery")
-def test_datamodule_check():
-    from ptbench.scripts.datamodule import check
+def test_database_check():
+    from ptbench.scripts.database import check
 
     runner = CliRunner()
-    result = runner.invoke(check, ["--verbose", "--limit=2"])
+    result = runner.invoke(check, ["--verbose", "--limit=1", "montgomery"])
     _assert_exit_0(result)
...