Skip to content
Snippets Groups Projects
Commit 6d819c3d authored by Daniel CARRON
Browse files

Rename dataset script to datamodule and update

parent 76144516
No related branches found
No related tags found
No related merge requests found
...@@ -10,8 +10,9 @@ from . import ( ...@@ -10,8 +10,9 @@ from . import (
aggregpred, aggregpred,
compare, compare,
config, config,
dataset, datamodule,
evaluate, evaluate,
experiment,
predict, predict,
predtojson, predtojson,
train, train,
...@@ -31,8 +32,9 @@ def cli(): ...@@ -31,8 +32,9 @@ def cli():
cli.add_command(aggregpred.aggregpred) cli.add_command(aggregpred.aggregpred)
cli.add_command(compare.compare) cli.add_command(compare.compare)
cli.add_command(config.config) cli.add_command(config.config)
cli.add_command(dataset.dataset) cli.add_command(datamodule.datamodule)
cli.add_command(evaluate.evaluate) cli.add_command(evaluate.evaluate)
cli.add_command(experiment.experiment)
cli.add_command(predict.predict) cli.add_command(predict.predict)
cli.add_command(predtojson.predtojson) cli.add_command(predtojson.predtojson)
cli.add_command(train.train) cli.add_command(train.train)
......
...@@ -6,7 +6,6 @@ from __future__ import annotations ...@@ -6,7 +6,6 @@ from __future__ import annotations
import importlib.metadata import importlib.metadata
import importlib.resources import importlib.resources
import os
import click import click
...@@ -18,38 +17,45 @@ from ..data.split import check_database_split_loading ...@@ -18,38 +17,45 @@ from ..data.split import check_database_split_loading
logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
def _get_supported_datasets(): def _get_installed_protocols() -> dict[str, str]:
"""Returns a list of supported dataset names.""" """Returns a list of all installed protocols.
basedir = importlib.resources.files(__name__.split(".", 1)[0]).joinpath(
"data"
)
retval = []
for candidate in basedir.iterdir():
if candidate.is_dir() and "__init__.py" in os.listdir(str(candidate)):
retval.append(candidate.name)
return retval Returns
-------
protocols:
List of protocols.
"""
entrypoints = sorted(
[
entrypoint
for entrypoint in importlib.metadata.entry_points(
group="ptbench.config"
).names
]
)
def _get_installed_datasets() -> dict[str, str]: protocols = [
"""Returns a list of installed datasets as regular expressions. importlib.metadata.entry_points(group="ptbench.config")[
entrypoint
].module
for entrypoint in entrypoints
]
* group(0): the name of the key for the dataset directory protocols_dict = {
* group("name"): the short name for the dataset entrypoints[i]: protocols[i] for i in range(len(entrypoints))
""" }
from ..utils.rc import load_rc
return dict(load_rc().get("datadir", {})) return protocols_dict
@click.group(cls=AliasedGroup) @click.group(cls=AliasedGroup)
def dataset() -> None: def datamodule() -> None:
"""Commands for listing and verifying datasets.""" """Commands for listing and verifying datamodules."""
pass pass
@dataset.command( @datamodule.command(
epilog="""Examples: epilog="""Examples:
\b \b
...@@ -79,25 +85,21 @@ def dataset() -> None: ...@@ -79,25 +85,21 @@ def dataset() -> None:
@verbosity_option(logger=logger, expose_value=False) @verbosity_option(logger=logger, expose_value=False)
def list(): def list():
"""Lists all supported and configured datasets.""" """Lists all supported and configured datasets."""
supported = _get_supported_datasets() installed = _get_installed_protocols()
installed = _get_installed_datasets()
click.echo("Supported datasets:") click.echo("Available protocols:")
for k in supported: for k, v in installed.items():
if k in installed: click.echo(f'- {k}: "{v}"')
click.echo(f'- {k}: "{installed[k]}"')
else:
click.echo(f"* {k}: datadir.{k} (not set)")
@dataset.command( @datamodule.command(
epilog="""Examples: epilog="""Examples:
1. Check if all files of the Montgomery dataset can be loaded: 1. Check if all files of the Montgomery dataset can be loaded:
.. code:: sh .. code:: sh
ptbench dataset check -vv montgomery ptbench dataset check -vv shenzhen
2. Check if all files of multiple installed datasets can be loaded: 2. Check if all files of multiple installed datasets can be loaded:
...@@ -105,16 +107,10 @@ def list(): ...@@ -105,16 +107,10 @@ def list():
ptbench dataset check -vv montgomery shenzhen ptbench dataset check -vv montgomery shenzhen
3. Check if all files of all installed datasets can be loaded:
.. code:: sh
ptbench dataset check
""", """,
) )
@click.argument( @click.argument(
"dataset", "protocols",
nargs=-1, nargs=-1,
) )
@click.option( @click.option(
...@@ -127,60 +123,27 @@ def list(): ...@@ -127,60 +123,27 @@ def list():
default=0, default=0,
) )
@verbosity_option(logger=logger, expose_value=False) @verbosity_option(logger=logger, expose_value=False)
def check(dataset, limit): def check(protocols, limit):
"""Checks file access on one or more datasets.""" """Checks file access on one or more datamodules."""
import importlib import importlib
to_check = _get_installed_datasets() errors = 0
for protocol in protocols:
try:
module = importlib.metadata.entry_points(group="ptbench.config")[
protocol
].module
except KeyError:
raise Exception(f"Could not find protocol {protocol}")
if dataset: # check only some datamodule = importlib.import_module(module).datamodule
delete = [k for k in to_check.keys() if k not in dataset]
for k in delete:
del to_check[k]
if not to_check: database_split = datamodule.database_split
click.secho( raw_data_loader = datamodule.raw_data_loader
"WARNING: No configured datasets matching specifications",
fg="yellow", errors += check_database_split_loading(
bold=True, database_split, raw_data_loader, limit=limit
)
click.echo(
"Try ptbench dataset list --help to get help in "
"configuring a dataset"
) )
else:
errors = 0 if not errors:
for k in to_check.keys(): click.echo("No errors reported")
click.echo(f'Checking "{k}" dataset...')
# Gathering protocols for the dataset
entrypoints = [
i
for i in importlib.metadata.entry_points(
group="ptbench.config"
).names
if i == k
]
protocols_modules = sorted(
[
importlib.metadata.entry_points(group="ptbench.config")[
i
].module
for i in entrypoints
]
)
for protocol in protocols_modules:
datamodule = importlib.import_module(protocol).datamodule
database_split = datamodule.database_split
raw_data_loader = datamodule.raw_data_loader
logger.info(f"Checking protocol {protocol}")
errors += check_database_split_loading(
database_split._subsets, raw_data_loader, limit=limit
)
if not errors:
click.echo("No errors reported")
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment