From 6d819c3d790b4ce675a1f01dfaca9d4e6188e16a Mon Sep 17 00:00:00 2001
From: dcarron <daniel.carron@idiap.ch>
Date: Tue, 25 Jul 2023 16:43:23 +0200
Subject: [PATCH] Rename dataset script to datamodule and update

---
Review note: the line structure of this patch was restored, and the added
datamodule.py was amended during review (help text now names the
"datamodule" command, docstrings match the code, entry-points are queried
once, the KeyError cause is chained).  The blob id on the datamodule.py
"index" line therefore still refers to the pre-review content.

 src/ptbench/scripts/cli.py        |   6 +-
 src/ptbench/scripts/datamodule.py | 137 ++++++++++++++++++++++
 src/ptbench/scripts/dataset.py    | 186 ------------------------------
 3 files changed, 141 insertions(+), 188 deletions(-)
 create mode 100644 src/ptbench/scripts/datamodule.py
 delete mode 100644 src/ptbench/scripts/dataset.py

diff --git a/src/ptbench/scripts/cli.py b/src/ptbench/scripts/cli.py
index 1be33f6e..479d6b6b 100644
--- a/src/ptbench/scripts/cli.py
+++ b/src/ptbench/scripts/cli.py
@@ -10,8 +10,9 @@ from . import (
     aggregpred,
     compare,
     config,
-    dataset,
+    datamodule,
     evaluate,
+    experiment,
     predict,
     predtojson,
     train,
@@ -31,8 +32,9 @@ def cli():
 cli.add_command(aggregpred.aggregpred)
 cli.add_command(compare.compare)
 cli.add_command(config.config)
-cli.add_command(dataset.dataset)
+cli.add_command(datamodule.datamodule)
 cli.add_command(evaluate.evaluate)
+cli.add_command(experiment.experiment)
 cli.add_command(predict.predict)
 cli.add_command(predtojson.predtojson)
 cli.add_command(train.train)
diff --git a/src/ptbench/scripts/datamodule.py b/src/ptbench/scripts/datamodule.py
new file mode 100644
index 00000000..6cadcc72
--- /dev/null
+++ b/src/ptbench/scripts/datamodule.py
@@ -0,0 +1,137 @@
+# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
+#
+# SPDX-License-Identifier: GPL-3.0-or-later
+
+from __future__ import annotations
+
+import importlib.metadata
+import importlib.resources
+
+import click
+
+from clapper.click import AliasedGroup, verbosity_option
+from clapper.logging import setup
+
+from ..data.split import check_database_split_loading
+
+logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
+
+
+def _get_installed_protocols() -> dict[str, str]:
+    """Returns a dictionary of all installed protocols.
+
+    Returns
+    -------
+
+    protocols:
+        Mapping from each entry-point name registered in the
+        ``ptbench.config`` group, sorted by name, to the name of the module
+        that entry-point refers to.
+    """
+    # Query the entry-points once and reuse the result for every lookup.
+    entrypoints = importlib.metadata.entry_points(group="ptbench.config")
+
+    return {
+        name: entrypoints[name].module for name in sorted(entrypoints.names)
+    }
+
+
+@click.group(cls=AliasedGroup)
+def datamodule() -> None:
+    """Commands for listing and verifying datamodules."""
+    pass
+
+
+@datamodule.command(
+    epilog="""Examples:
+
+\b
+    1. To install a dataset, set up its data directory ("datadir"). For
+       example, to setup access to Montgomery files you downloaded locally at
+       the directory "/path/to/montgomery/files", edit the RC file (typically
+       ``$HOME/.config/ptbench.toml``), and add a line like the following:
+
+       .. code:: toml
+
+          [datadir]
+          montgomery = "/path/to/montgomery/files"
+
+       .. note::
+
+          This setting **is** case-sensitive.
+
+\b
+    2. List all installed protocols:
+
+       .. code:: sh
+
+          $ ptbench datamodule list
+
+""",
+)
+@verbosity_option(logger=logger, expose_value=False)
+def list():
+    """Lists all installed protocols."""
+    installed = _get_installed_protocols()
+
+    click.echo("Available protocols:")
+    for k, v in installed.items():
+        click.echo(f'- {k}: "{v}"')
+
+
+@datamodule.command(
+    epilog="""Examples:
+
+    1. Check if all files of the Montgomery dataset can be loaded:
+
+       .. code:: sh
+
+          ptbench datamodule check -vv montgomery
+
+    2. Check if all files of multiple installed datasets can be loaded:
+
+       .. code:: sh
+
+          ptbench datamodule check -vv montgomery shenzhen
+
+""",
+)
+@click.argument(
+    "protocols",
+    nargs=-1,
+)
+@click.option(
+    "--limit",
+    "-l",
+    help="Limit check to the first N samples in each dataset, making the "
+    "check sensibly faster. Set it to zero to check everything.",
+    required=True,
+    type=click.IntRange(0),
+    default=0,
+)
+@verbosity_option(logger=logger, expose_value=False)
+def check(protocols, limit):
+    """Checks file access on one or more datamodules."""
+    import importlib
+
+    # Query the entry-points once instead of once per requested protocol.
+    entrypoints = importlib.metadata.entry_points(group="ptbench.config")
+
+    errors = 0
+    for protocol in protocols:
+        try:
+            module = entrypoints[protocol].module
+        except KeyError as e:
+            raise Exception(f"Could not find protocol {protocol}") from e
+
+        datamodule = importlib.import_module(module).datamodule
+
+        database_split = datamodule.database_split
+        raw_data_loader = datamodule.raw_data_loader
+
+        errors += check_database_split_loading(
+            database_split, raw_data_loader, limit=limit
+        )
+
+    if not errors:
+        click.echo("No errors reported")
diff --git a/src/ptbench/scripts/dataset.py b/src/ptbench/scripts/dataset.py
deleted file mode 100644
index 844311b5..00000000
--- a/src/ptbench/scripts/dataset.py
+++ /dev/null
@@ -1,186 +0,0 @@
-# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
-#
-# SPDX-License-Identifier: GPL-3.0-or-later
-
-from __future__ import annotations
-
-import importlib.metadata
-import importlib.resources
-import os
-
-import click
-
-from clapper.click import AliasedGroup, verbosity_option
-from clapper.logging import setup
-
-from ..data.split import check_database_split_loading
-
-logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
-
-
-def _get_supported_datasets():
-    """Returns a list of supported dataset names."""
-    basedir = importlib.resources.files(__name__.split(".", 1)[0]).joinpath(
-        "data"
-    )
-
-    retval = []
-    for candidate in basedir.iterdir():
-        if candidate.is_dir() and "__init__.py" in os.listdir(str(candidate)):
-            retval.append(candidate.name)
-
-    return retval
-
-
-def _get_installed_datasets() -> dict[str, str]:
-    """Returns a list of installed datasets as regular expressions.
-
-    * group(0): the name of the key for the dataset directory
-    * group("name"): the short name for the dataset
-    """
-    from ..utils.rc import load_rc
-
-    return dict(load_rc().get("datadir", {}))
-
-
-@click.group(cls=AliasedGroup)
-def dataset() -> None:
-    """Commands for listing and verifying datasets."""
-    pass
-
-
-@dataset.command(
-    epilog="""Examples:
-
-\b
-    1. To install a dataset, set up its data directory ("datadir"). For
-       example, to setup access to Montgomery files you downloaded locally at
-       the directory "/path/to/montgomery/files", edit the RC file (typically
-       ``$HOME/.config/ptbench.toml``), and add a line like the following:
-
-       .. code:: toml
-
-          [datadir]
-          montgomery = "/path/to/montgomery/files"
-
-       .. note::
-
-          This setting **is** case-sensitive.
-
-\b
-    2. List all raw datasets supported (and configured):
-
-       .. code:: sh
-
-          $ ptbench dataset list
-
-""",
-)
-@verbosity_option(logger=logger, expose_value=False)
-def list():
-    """Lists all supported and configured datasets."""
-    supported = _get_supported_datasets()
-    installed = _get_installed_datasets()
-
-    click.echo("Supported datasets:")
-    for k in supported:
-        if k in installed:
-            click.echo(f'- {k}: "{installed[k]}"')
-        else:
-            click.echo(f"* {k}: datadir.{k} (not set)")
-
-
-@dataset.command(
-    epilog="""Examples:
-
-    1. Check if all files of the Montgomery dataset can be loaded:
-
-       .. code:: sh
-
-          ptbench dataset check -vv montgomery
-
-    2. Check if all files of multiple installed datasets can be loaded:
-
-       .. code:: sh
-
-          ptbench dataset check -vv montgomery shenzhen
-
-    3. Check if all files of all installed datasets can be loaded:
-
-       .. code:: sh
-
-          ptbench dataset check
-
-""",
-)
-@click.argument(
-    "dataset",
-    nargs=-1,
-)
-@click.option(
-    "--limit",
-    "-l",
-    help="Limit check to the first N samples in each dataset, making the "
-    "check sensibly faster. Set it to zero to check everything.",
-    required=True,
-    type=click.IntRange(0),
-    default=0,
-)
-@verbosity_option(logger=logger, expose_value=False)
-def check(dataset, limit):
-    """Checks file access on one or more datasets."""
-    import importlib
-
-    to_check = _get_installed_datasets()
-
-    if dataset:  # check only some
-        delete = [k for k in to_check.keys() if k not in dataset]
-        for k in delete:
-            del to_check[k]
-
-    if not to_check:
-        click.secho(
-            "WARNING: No configured datasets matching specifications",
-            fg="yellow",
-            bold=True,
-        )
-        click.echo(
-            "Try ptbench dataset list --help to get help in "
-            "configuring a dataset"
-        )
-    else:
-        errors = 0
-        for k in to_check.keys():
-            click.echo(f'Checking "{k}" dataset...')
-
-            # Gathering protocols for the dataset
-            entrypoints = [
-                i
-                for i in importlib.metadata.entry_points(
-                    group="ptbench.config"
-                ).names
-                if i == k
-            ]
-            protocols_modules = sorted(
-                [
-                    importlib.metadata.entry_points(group="ptbench.config")[
-                        i
-                    ].module
-                    for i in entrypoints
-                ]
-            )
-
-            for protocol in protocols_modules:
-                datamodule = importlib.import_module(protocol).datamodule
-
-                database_split = datamodule.database_split
-                raw_data_loader = datamodule.raw_data_loader
-
-                logger.info(f"Checking protocol {protocol}")
-
-                errors += check_database_split_loading(
-                    database_split._subsets, raw_data_loader, limit=limit
-                )
-
-        if not errors:
-            click.echo("No errors reported")
-- 
GitLab