From 1b311779933452b71afb4d5e405ace3677ace094 Mon Sep 17 00:00:00 2001
From: Andre Anjos <andre.dos.anjos@gmail.com>
Date: Fri, 15 Dec 2023 14:45:07 +0100
Subject: [PATCH] [scripts.saliency] Re-structure CLI to make it more
 consistent with submodule organisation

---
 src/ptbench/scripts/cli.py                    | 76 ++++++++++++-------
 src/ptbench/scripts/saliency/__init__.py      |  0
 .../completeness.py}                          | 12 +--
 .../evaluate.py}                              | 15 ++--
 .../generate.py}                              | 12 +--
 .../interpretability.py}                      |  8 +-
 .../{view_saliency.py => saliency/view.py}    | 14 ++--
 7 files changed, 78 insertions(+), 59 deletions(-)
 create mode 100644 src/ptbench/scripts/saliency/__init__.py
 rename src/ptbench/scripts/{saliency_completeness.py => saliency/completeness.py} (96%)
 rename src/ptbench/scripts/{evaluate_saliencymaps.py => saliency/evaluate.py} (86%)
 rename src/ptbench/scripts/{generate_saliencymaps.py => saliency/generate.py} (93%)
 rename src/ptbench/scripts/{saliency_interpretability.py => saliency/interpretability.py} (96%)
 rename src/ptbench/scripts/{view_saliency.py => saliency/view.py} (92%)

diff --git a/src/ptbench/scripts/cli.py b/src/ptbench/scripts/cli.py
index cd7ff093..e8f7ba24 100644
--- a/src/ptbench/scripts/cli.py
+++ b/src/ptbench/scripts/cli.py
@@ -2,25 +2,12 @@
 #
 # SPDX-License-Identifier: GPL-3.0-or-later
 
+import importlib
+
 import click
 
 from clapper.click import AliasedGroup
 
-from . import (
-    config,
-    database,
-    evaluate,
-    evaluate_saliencymaps,
-    experiment,
-    generate_saliencymaps,
-    predict,
-    saliency_completeness,
-    saliency_interpretability,
-    train,
-    train_analysis,
-    view_saliency,
-)
-
 
 @click.group(
     cls=AliasedGroup,
@@ -31,15 +18,50 @@ def cli():
     pass
 
 
-cli.add_command(config.config)
-cli.add_command(database.database)
-cli.add_command(evaluate.evaluate)
-cli.add_command(saliency_completeness.saliency_completeness)
-cli.add_command(saliency_interpretability.saliency_interpretability)
-cli.add_command(evaluate_saliencymaps.evaluate_saliencymaps)
-cli.add_command(experiment.experiment)
-cli.add_command(generate_saliencymaps.generate_saliencymaps)
-cli.add_command(predict.predict)
-cli.add_command(train.train)
-cli.add_command(train_analysis.train_analysis)
-cli.add_command(view_saliency.view_saliency)
+cli.add_command(importlib.import_module("..config", package=__name__).config)
+cli.add_command(
+    importlib.import_module("..database", package=__name__).database
+)
+cli.add_command(
+    importlib.import_module("..evaluate", package=__name__).evaluate
+)
+cli.add_command(
+    importlib.import_module("..experiment", package=__name__).experiment
+)
+cli.add_command(importlib.import_module("..predict", package=__name__).predict)
+cli.add_command(importlib.import_module("..train", package=__name__).train)
+cli.add_command(
+    importlib.import_module("..train_analysis", package=__name__).train_analysis
+)
+
+
+@click.group(
+    cls=AliasedGroup,
+    context_settings=dict(help_option_names=["-?", "-h", "--help"]),
+)
+def saliency():
+    """Sub-commands to generate, evaluate and view saliency maps."""
+    pass
+
+
+cli.add_command(saliency)
+
+saliency.add_command(
+    importlib.import_module("..saliency.generate", package=__name__).generate
+)
+saliency.add_command(
+    importlib.import_module(
+        "..saliency.completeness", package=__name__
+    ).completeness
+)
+saliency.add_command(
+    importlib.import_module(
+        "..saliency.interpretability", package=__name__
+    ).interpretability
+)
+saliency.add_command(
+    importlib.import_module("..saliency.evaluate", package=__name__).evaluate
+)
+saliency.add_command(
+    importlib.import_module("..saliency.view", package=__name__).view
+)
diff --git a/src/ptbench/scripts/saliency/__init__.py b/src/ptbench/scripts/saliency/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/src/ptbench/scripts/saliency_completeness.py b/src/ptbench/scripts/saliency/completeness.py
similarity index 96%
rename from src/ptbench/scripts/saliency_completeness.py
rename to src/ptbench/scripts/saliency/completeness.py
index dab343d4..29c7599e 100644
--- a/src/ptbench/scripts/saliency_completeness.py
+++ b/src/ptbench/scripts/saliency/completeness.py
@@ -10,8 +10,8 @@ import click
 from clapper.click import ResourceOption, verbosity_option
 from clapper.logging import setup
 
-from ..models.typing import SaliencyMapAlgorithm
-from .click import ConfigCommand
+from ...models.typing import SaliencyMapAlgorithm
+from ..click import ConfigCommand
 
 logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
 
@@ -25,7 +25,7 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
 
    .. code:: sh
 
-      ptbench saliency-completeness -vv pasa tbx11k-v1-healthy-vs-atb --device="cuda" --weight=path/to/model-at-lowest-validation-loss.ckpt --output-folder=path/to/completeness-scores/
+      ptbench saliency completeness -vv pasa tbx11k-v1-healthy-vs-atb --device="cuda" --weight=path/to/model-at-lowest-validation-loss.ckpt --output-folder=path/to/completeness-scores/
 
 """,
 )
@@ -154,7 +154,7 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
     cls=ResourceOption,
 )
 @verbosity_option(logger=logger, cls=ResourceOption, expose_value=False)
-def saliency_completeness(
+def completeness(
     model,
     datamodule,
     output_folder,
@@ -199,8 +199,8 @@ def saliency_completeness(
     """
     import json
 
-    from ..engine.device import DeviceManager
-    from ..engine.saliency.completeness import run
+    from ...engine.device import DeviceManager
+    from ...engine.saliency.completeness import run
 
     logger.info(f"Output folder: {output_folder}")
     output_folder.mkdir(parents=True, exist_ok=True)
diff --git a/src/ptbench/scripts/evaluate_saliencymaps.py b/src/ptbench/scripts/saliency/evaluate.py
similarity index 86%
rename from src/ptbench/scripts/evaluate_saliencymaps.py
rename to src/ptbench/scripts/saliency/evaluate.py
index 5b96b37a..cb3acc47 100644
--- a/src/ptbench/scripts/evaluate_saliencymaps.py
+++ b/src/ptbench/scripts/saliency/evaluate.py
@@ -10,8 +10,8 @@ import click
 from clapper.click import ResourceOption, verbosity_option
 from clapper.logging import setup
 
-from ..models.typing import SaliencyMapAlgorithm
-from .click import ConfigCommand
+from ...models.typing import SaliencyMapAlgorithm
+from ..click import ConfigCommand
 
 # avoids X11/graphical desktop requirement when creating plots
 __import__("matplotlib").use("agg")
@@ -24,12 +24,11 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
     cls=ConfigCommand,
     epilog="""Examples:
 
-\b
-    1. Tabulates and generates plots for two saliency map algorithms:
+1. Tabulates and generates plots for two saliency map algorithms:
 
-       .. code:: sh
+   .. code:: sh
 
-          ptbench evaluate-saliencymaps -vv -e gradcam path/to/gradcam-completeness.json path/to/gradcam-interpretability.json -e gradcam++ path/to/gradcam++-completeness.json path/to/gradcam++-interpretability.json
+      ptbench saliency evaluate -vv -e gradcam path/to/gradcam-completeness.json path/to/gradcam-interpretability.json -e gradcam++ path/to/gradcam++-completeness.json path/to/gradcam++-interpretability.json
 """,
 )
 @click.option(
@@ -73,7 +72,7 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
     cls=ResourceOption,
 )
 @verbosity_option(logger=logger, expose_value=False)
-def evaluate_saliencymaps(
+def evaluate(
     entry,
     output_folder,
     **_,  # ignored
@@ -83,7 +82,7 @@ def evaluate_saliencymaps(
 
     from matplotlib.backends.backend_pdf import PdfPages
 
-    from ..engine.saliency.evaluator import run, summary_table
+    from ...engine.saliency.evaluator import run, summary_table
 
     summary = {
         algo: run(algo, json.load(complet.open()), json.load(interp.open()))
diff --git a/src/ptbench/scripts/generate_saliencymaps.py b/src/ptbench/scripts/saliency/generate.py
similarity index 93%
rename from src/ptbench/scripts/generate_saliencymaps.py
rename to src/ptbench/scripts/saliency/generate.py
index 3277db96..6ccb9f5e 100644
--- a/src/ptbench/scripts/generate_saliencymaps.py
+++ b/src/ptbench/scripts/saliency/generate.py
@@ -10,8 +10,8 @@ import click
 from clapper.click import ResourceOption, verbosity_option
 from clapper.logging import setup
 
-from ..models.typing import SaliencyMapAlgorithm
-from .click import ConfigCommand
+from ...models.typing import SaliencyMapAlgorithm
+from ..click import ConfigCommand
 
 logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
 
@@ -27,7 +27,7 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
 
    .. code:: sh
 
-      ptbench generate-saliencymaps -vv densenet tbx11k-v1-healthy-vs-atb --weight=path/to/model-at-lowest-validation-loss.ckpt --output-folder=path/to/output
+      ptbench saliency generate -vv densenet tbx11k-v1-healthy-vs-atb --weight=path/to/model-at-lowest-validation-loss.ckpt --output-folder=path/to/output
 
 """,
 )
@@ -143,7 +143,7 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
     cls=ResourceOption,
 )
 @verbosity_option(logger=logger, cls=ResourceOption, expose_value=False)
-def generate_saliencymaps(
+def generate(
     model,
     datamodule,
     output_folder,
@@ -163,8 +163,8 @@ def generate_saliencymaps(
     algorithm and trained model.
     """
 
-    from ..engine.device import DeviceManager
-    from ..engine.saliency.generator import run
+    from ...engine.device import DeviceManager
+    from ...engine.saliency.generator import run
 
     logger.info(f"Output folder: {output_folder}")
     output_folder.mkdir(parents=True, exist_ok=True)
diff --git a/src/ptbench/scripts/saliency_interpretability.py b/src/ptbench/scripts/saliency/interpretability.py
similarity index 96%
rename from src/ptbench/scripts/saliency_interpretability.py
rename to src/ptbench/scripts/saliency/interpretability.py
index bca56565..155f81db 100644
--- a/src/ptbench/scripts/saliency_interpretability.py
+++ b/src/ptbench/scripts/saliency/interpretability.py
@@ -9,7 +9,7 @@ import click
 from clapper.click import ResourceOption, verbosity_option
 from clapper.logging import setup
 
-from .click import ConfigCommand
+from ..click import ConfigCommand
 
 logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
 
@@ -23,7 +23,7 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
 
    .. code:: sh
 
-      ptbench saliency-interpretability -vv tbx11k-v1-healthy-vs-atb --input-folder=parent_folder/gradcam/ --output-json=parent_folder/gradcam/tbx11k-v1-interp.json
+      ptbench saliency interpretability -vv tbx11k-v1-healthy-vs-atb --input-folder=parent_folder/gradcam/ --output-json=parent_folder/gradcam/tbx11k-v1-interp.json
 
 """,
 )
@@ -78,7 +78,7 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
     cls=ResourceOption,
 )
 @verbosity_option(logger=logger, cls=ResourceOption, expose_value=False)
-def saliency_interpretability(
+def interpretability(
     datamodule,
     input_folder,
     target_label,
@@ -134,7 +134,7 @@ def saliency_interpretability(
 
     import json
 
-    from ..engine.saliency.interpretability import run
+    from ...engine.saliency.interpretability import run
 
     datamodule.model_transforms = []
     datamodule.prepare_data()
diff --git a/src/ptbench/scripts/view_saliency.py b/src/ptbench/scripts/saliency/view.py
similarity index 92%
rename from src/ptbench/scripts/view_saliency.py
rename to src/ptbench/scripts/saliency/view.py
index c3b9379f..028b1417 100644
--- a/src/ptbench/scripts/view_saliency.py
+++ b/src/ptbench/scripts/saliency/view.py
@@ -18,13 +18,11 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
     cls=ConfigCommand,
     epilog="""Examples:
 
-\b
-    1. Generates visualizations in form of heatmaps from existing saliency maps for a dataset configuration:
+1. Generates visualizations in form of heatmaps from existing saliency maps for a dataset configuration:
 
-       .. code:: sh
-
-          ptbench visualize -vv pasa tbx11k-v1-healthy-vs-atb --input-folder=parent_folder/gradcam/ --output-folder=path/to/visualizations
+   .. code:: sh
 
+      ptbench saliency view -vv pasa tbx11k-v1-healthy-vs-atb --input-folder=parent_folder/gradcam/ --output-folder=path/to/visualizations
 """,
 )
 @click.option(
@@ -91,7 +89,7 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
     cls=ResourceOption,
 )
 @verbosity_option(logger=logger, cls=ResourceOption, expose_value=False)
-def view_saliency(
+def view(
     model,
     datamodule,
     input_folder,
@@ -102,8 +100,8 @@ def view_saliency(
 ) -> None:
     """Generates heatmaps for input CXRs based on existing saliency maps."""
 
-    from ..engine.saliency.viewer import run
-    from .utils import save_sh_command
+    from ...engine.saliency.viewer import run
+    from ..utils import save_sh_command
 
     assert (
         input_folder != output_folder
-- 
GitLab