From 1b311779933452b71afb4d5e405ace3677ace094 Mon Sep 17 00:00:00 2001 From: Andre Anjos <andre.dos.anjos@gmail.com> Date: Fri, 15 Dec 2023 14:45:07 +0100 Subject: [PATCH] [scripts.saliency] Re-structure CLI to make it more consistent with submodule organisation --- src/ptbench/scripts/cli.py | 76 ++++++++++++------- src/ptbench/scripts/saliency/__init__.py | 0 .../completeness.py} | 12 +-- .../evaluate.py} | 15 ++-- .../generate.py} | 12 +-- .../interpretability.py} | 8 +- .../{view_saliency.py => saliency/view.py} | 14 ++-- 7 files changed, 78 insertions(+), 59 deletions(-) create mode 100644 src/ptbench/scripts/saliency/__init__.py rename src/ptbench/scripts/{saliency_completeness.py => saliency/completeness.py} (96%) rename src/ptbench/scripts/{evaluate_saliencymaps.py => saliency/evaluate.py} (86%) rename src/ptbench/scripts/{generate_saliencymaps.py => saliency/generate.py} (93%) rename src/ptbench/scripts/{saliency_interpretability.py => saliency/interpretability.py} (96%) rename src/ptbench/scripts/{view_saliency.py => saliency/view.py} (92%) diff --git a/src/ptbench/scripts/cli.py b/src/ptbench/scripts/cli.py index cd7ff093..e8f7ba24 100644 --- a/src/ptbench/scripts/cli.py +++ b/src/ptbench/scripts/cli.py @@ -2,25 +2,12 @@ # # SPDX-License-Identifier: GPL-3.0-or-later +import importlib + import click from clapper.click import AliasedGroup -from . import ( - config, - database, - evaluate, - evaluate_saliencymaps, - experiment, - generate_saliencymaps, - predict, - saliency_completeness, - saliency_interpretability, - train, - train_analysis, - view_saliency, -) - @click.group( cls=AliasedGroup, @@ -31,15 +18,50 @@ def cli(): pass -cli.add_command(config.config) -cli.add_command(database.database) -cli.add_command(evaluate.evaluate) -cli.add_command(saliency_completeness.saliency_completeness) -cli.add_command(saliency_interpretability.saliency_interpretability) -cli.add_command(evaluate_saliencymaps.evaluate_saliencymaps) -cli.add_command(experiment.experiment) -cli.add_command(generate_saliencymaps.generate_saliencymaps) -cli.add_command(predict.predict) -cli.add_command(train.train) -cli.add_command(train_analysis.train_analysis) -cli.add_command(view_saliency.view_saliency) +cli.add_command(importlib.import_module("..config", package=__name__).config) +cli.add_command( + importlib.import_module("..database", package=__name__).database +) +cli.add_command( + importlib.import_module("..evaluate", package=__name__).evaluate +) +cli.add_command( + importlib.import_module("..experiment", package=__name__).experiment +) +cli.add_command(importlib.import_module("..predict", package=__name__).predict) +cli.add_command(importlib.import_module("..train", package=__name__).train) +cli.add_command( + importlib.import_module("..train_analysis", package=__name__).train_analysis +) + + +@click.group( + cls=AliasedGroup, + context_settings=dict(help_option_names=["-?", "-h", "--help"]), +) +def saliency(): + """Sub-commands to generate, evaluate and view saliency maps.""" + pass + + +cli.add_command(saliency) + +saliency.add_command( + importlib.import_module("..saliency.generate", package=__name__).generate +) +saliency.add_command( + importlib.import_module( + "..saliency.completeness", package=__name__ + ).completeness +) +saliency.add_command( + importlib.import_module( + "..saliency.interpretability", package=__name__ + ).interpretability +) +saliency.add_command( + importlib.import_module("..saliency.evaluate", package=__name__).evaluate +) +saliency.add_command( + importlib.import_module("..saliency.view", package=__name__).view +) diff --git a/src/ptbench/scripts/saliency/__init__.py b/src/ptbench/scripts/saliency/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/ptbench/scripts/saliency_completeness.py b/src/ptbench/scripts/saliency/completeness.py similarity index 96% rename from src/ptbench/scripts/saliency_completeness.py rename to src/ptbench/scripts/saliency/completeness.py index dab343d4..29c7599e 100644 --- a/src/ptbench/scripts/saliency_completeness.py +++ b/src/ptbench/scripts/saliency/completeness.py @@ -10,8 +10,8 @@ import click from clapper.click import ResourceOption, verbosity_option from clapper.logging import setup -from ..models.typing import SaliencyMapAlgorithm -from .click import ConfigCommand +from ...models.typing import SaliencyMapAlgorithm +from ..click import ConfigCommand logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") @@ -25,7 +25,7 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") .. code:: sh - ptbench saliency-completeness -vv pasa tbx11k-v1-healthy-vs-atb --device="cuda" --weight=path/to/model-at-lowest-validation-loss.ckpt --output-folder=path/to/completeness-scores/ + ptbench saliency completeness -vv pasa tbx11k-v1-healthy-vs-atb --device="cuda" --weight=path/to/model-at-lowest-validation-loss.ckpt --output-folder=path/to/completeness-scores/ """, ) @@ -154,7 +154,7 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") cls=ResourceOption, ) @verbosity_option(logger=logger, cls=ResourceOption, expose_value=False) -def saliency_completeness( +def completeness( model, datamodule, output_folder, @@ -199,8 +199,8 @@ def saliency_completeness( """ import json - from ..engine.device import DeviceManager - from ..engine.saliency.completeness import run + from ...engine.device import DeviceManager + from ...engine.saliency.completeness import run logger.info(f"Output folder: {output_folder}") output_folder.mkdir(parents=True, exist_ok=True) diff --git a/src/ptbench/scripts/evaluate_saliencymaps.py b/src/ptbench/scripts/saliency/evaluate.py similarity index 86% rename from src/ptbench/scripts/evaluate_saliencymaps.py rename to src/ptbench/scripts/saliency/evaluate.py index 5b96b37a..cb3acc47 100644 --- a/src/ptbench/scripts/evaluate_saliencymaps.py +++ b/src/ptbench/scripts/saliency/evaluate.py @@ -10,8 +10,8 @@ import click from clapper.click import ResourceOption, verbosity_option from clapper.logging import setup -from ..models.typing import SaliencyMapAlgorithm -from .click import ConfigCommand +from ...models.typing import SaliencyMapAlgorithm +from ..click import ConfigCommand # avoids X11/graphical desktop requirement when creating plots __import__("matplotlib").use("agg") @@ -24,12 +24,11 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") cls=ConfigCommand, epilog="""Examples: -\b - 1. Tabulates and generates plots for two saliency map algorithms: +1. Tabulates and generates plots for two saliency map algorithms: - .. code:: sh + .. code:: sh - ptbench evaluate-saliencymaps -vv -e gradcam path/to/gradcam-completeness.json path/to/gradcam-interpretability.json -e gradcam++ path/to/gradcam++-completeness.json path/to/gradcam++-interpretability.json + ptbench saliency evaluate -vv -e gradcam path/to/gradcam-completeness.json path/to/gradcam-interpretability.json -e gradcam++ path/to/gradcam++-completeness.json path/to/gradcam++-interpretability.json """, ) @click.option( @@ -73,7 +72,7 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") cls=ResourceOption, ) @verbosity_option(logger=logger, expose_value=False) -def evaluate_saliencymaps( +def evaluate( entry, output_folder, **_, # ignored @@ -83,7 +82,7 @@ def evaluate_saliencymaps( from matplotlib.backends.backend_pdf import PdfPages - from ..engine.saliency.evaluator import run, summary_table + from ...engine.saliency.evaluator import run, summary_table summary = { algo: run(algo, json.load(complet.open()), json.load(interp.open())) diff --git a/src/ptbench/scripts/generate_saliencymaps.py b/src/ptbench/scripts/saliency/generate.py similarity index 93% rename from src/ptbench/scripts/generate_saliencymaps.py rename to src/ptbench/scripts/saliency/generate.py index 3277db96..6ccb9f5e 100644 --- a/src/ptbench/scripts/generate_saliencymaps.py +++ b/src/ptbench/scripts/saliency/generate.py @@ -10,8 +10,8 @@ import click from clapper.click import ResourceOption, verbosity_option from clapper.logging import setup -from ..models.typing import SaliencyMapAlgorithm -from .click import ConfigCommand +from ...models.typing import SaliencyMapAlgorithm +from ..click import ConfigCommand logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") @@ -27,7 +27,7 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") .. code:: sh - ptbench generate-saliencymaps -vv densenet tbx11k-v1-healthy-vs-atb --weight=path/to/model-at-lowest-validation-loss.ckpt --output-folder=path/to/output + ptbench saliency generate -vv densenet tbx11k-v1-healthy-vs-atb --weight=path/to/model-at-lowest-validation-loss.ckpt --output-folder=path/to/output """, ) @@ -143,7 +143,7 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") cls=ResourceOption, ) @verbosity_option(logger=logger, cls=ResourceOption, expose_value=False) -def generate_saliencymaps( +def generate( model, datamodule, output_folder, @@ -163,8 +163,8 @@ def generate_saliencymaps( algorithm and trained model. """ - from ..engine.device import DeviceManager - from ..engine.saliency.generator import run + from ...engine.device import DeviceManager + from ...engine.saliency.generator import run logger.info(f"Output folder: {output_folder}") output_folder.mkdir(parents=True, exist_ok=True) diff --git a/src/ptbench/scripts/saliency_interpretability.py b/src/ptbench/scripts/saliency/interpretability.py similarity index 96% rename from src/ptbench/scripts/saliency_interpretability.py rename to src/ptbench/scripts/saliency/interpretability.py index bca56565..155f81db 100644 --- a/src/ptbench/scripts/saliency_interpretability.py +++ b/src/ptbench/scripts/saliency/interpretability.py @@ -9,7 +9,7 @@ import click from clapper.click import ResourceOption, verbosity_option from clapper.logging import setup -from .click import ConfigCommand +from ..click import ConfigCommand logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") @@ -23,7 +23,7 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") .. code:: sh - ptbench saliency-interpretability -vv tbx11k-v1-healthy-vs-atb --input-folder=parent_folder/gradcam/ --output-json=parent_folder/gradcam/tbx11k-v1-interp.json + ptbench saliency interpretability -vv tbx11k-v1-healthy-vs-atb --input-folder=parent_folder/gradcam/ --output-json=parent_folder/gradcam/tbx11k-v1-interp.json """, ) @@ -78,7 +78,7 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") cls=ResourceOption, ) @verbosity_option(logger=logger, cls=ResourceOption, expose_value=False) -def saliency_interpretability( +def interpretability( datamodule, input_folder, target_label, @@ -134,7 +134,7 @@ def saliency_interpretability( import json - from ..engine.saliency.interpretability import run + from ...engine.saliency.interpretability import run datamodule.model_transforms = [] datamodule.prepare_data() diff --git a/src/ptbench/scripts/view_saliency.py b/src/ptbench/scripts/saliency/view.py similarity index 92% rename from src/ptbench/scripts/view_saliency.py rename to src/ptbench/scripts/saliency/view.py index c3b9379f..028b1417 100644 --- a/src/ptbench/scripts/view_saliency.py +++ b/src/ptbench/scripts/saliency/view.py @@ -18,13 +18,11 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") cls=ConfigCommand, epilog="""Examples: -\b - 1. Generates visualizations in form of heatmaps from existing saliency maps for a dataset configuration: +1. Generates visualizations in form of heatmaps from existing saliency maps for a dataset configuration: - .. code:: sh - - ptbench visualize -vv pasa tbx11k-v1-healthy-vs-atb --input-folder=parent_folder/gradcam/ --output-folder=path/to/visualizations + .. code:: sh + ptbench saliency view -vv pasa tbx11k-v1-healthy-vs-atb --input-folder=parent_folder/gradcam/ --output-folder=path/to/visualizations """, ) @click.option( @@ -91,7 +89,7 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") cls=ResourceOption, ) @verbosity_option(logger=logger, cls=ResourceOption, expose_value=False) -def view_saliency( +def view( model, datamodule, input_folder, @@ -102,8 +100,8 @@ def view_saliency( ) -> None: """Generates heatmaps for input CXRs based on existing saliency maps.""" - from ..engine.saliency.viewer import run - from .utils import save_sh_command + from ...engine.saliency.viewer import run + from ..utils import save_sh_command assert ( input_folder != output_folder -- GitLab