Skip to content
Snippets Groups Projects
Commit 1b311779 authored by André Anjos's avatar André Anjos :speech_balloon: Committed by Daniel CARRON
Browse files

[scripts.saliency] Re-structure CLI to make it more consistent with submodule organisation

parent 910e58c1
No related branches found
No related tags found
1 merge request!12Adds grad-cam support on classifiers
...@@ -2,25 +2,12 @@ ...@@ -2,25 +2,12 @@
# #
# SPDX-License-Identifier: GPL-3.0-or-later # SPDX-License-Identifier: GPL-3.0-or-later
import importlib
import click import click
from clapper.click import AliasedGroup from clapper.click import AliasedGroup
from . import (
config,
database,
evaluate,
evaluate_saliencymaps,
experiment,
generate_saliencymaps,
predict,
saliency_completeness,
saliency_interpretability,
train,
train_analysis,
view_saliency,
)
@click.group( @click.group(
cls=AliasedGroup, cls=AliasedGroup,
...@@ -31,15 +18,50 @@ def cli(): ...@@ -31,15 +18,50 @@ def cli():
pass pass
cli.add_command(config.config) cli.add_command(importlib.import_module("..config", package=__name__).config)
cli.add_command(database.database) cli.add_command(
cli.add_command(evaluate.evaluate) importlib.import_module("..database", package=__name__).database
cli.add_command(saliency_completeness.saliency_completeness) )
cli.add_command(saliency_interpretability.saliency_interpretability) cli.add_command(
cli.add_command(evaluate_saliencymaps.evaluate_saliencymaps) importlib.import_module("..evaluate", package=__name__).evaluate
cli.add_command(experiment.experiment) )
cli.add_command(generate_saliencymaps.generate_saliencymaps) cli.add_command(
cli.add_command(predict.predict) importlib.import_module("..experiment", package=__name__).experiment
cli.add_command(train.train) )
cli.add_command(train_analysis.train_analysis) cli.add_command(importlib.import_module("..predict", package=__name__).predict)
cli.add_command(view_saliency.view_saliency) cli.add_command(importlib.import_module("..train", package=__name__).train)
cli.add_command(
importlib.import_module("..train_analysis", package=__name__).train_analysis
)
@click.group(
cls=AliasedGroup,
context_settings=dict(help_option_names=["-?", "-h", "--help"]),
)
def saliency():
"""Sub-commands to generate, evaluate and view saliency maps."""
pass
cli.add_command(saliency)
saliency.add_command(
importlib.import_module("..saliency.generate", package=__name__).generate
)
saliency.add_command(
importlib.import_module(
"..saliency.completeness", package=__name__
).completeness
)
saliency.add_command(
importlib.import_module(
"..saliency.interpretability", package=__name__
).interpretability
)
saliency.add_command(
importlib.import_module("..saliency.evaluate", package=__name__).evaluate
)
saliency.add_command(
importlib.import_module("..saliency.view", package=__name__).view
)
...@@ -10,8 +10,8 @@ import click ...@@ -10,8 +10,8 @@ import click
from clapper.click import ResourceOption, verbosity_option from clapper.click import ResourceOption, verbosity_option
from clapper.logging import setup from clapper.logging import setup
from ..models.typing import SaliencyMapAlgorithm from ...models.typing import SaliencyMapAlgorithm
from .click import ConfigCommand from ..click import ConfigCommand
logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
...@@ -25,7 +25,7 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") ...@@ -25,7 +25,7 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
.. code:: sh .. code:: sh
ptbench saliency-completeness -vv pasa tbx11k-v1-healthy-vs-atb --device="cuda" --weight=path/to/model-at-lowest-validation-loss.ckpt --output-folder=path/to/completeness-scores/ ptbench saliency completeness -vv pasa tbx11k-v1-healthy-vs-atb --device="cuda" --weight=path/to/model-at-lowest-validation-loss.ckpt --output-folder=path/to/completeness-scores/
""", """,
) )
...@@ -154,7 +154,7 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") ...@@ -154,7 +154,7 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
cls=ResourceOption, cls=ResourceOption,
) )
@verbosity_option(logger=logger, cls=ResourceOption, expose_value=False) @verbosity_option(logger=logger, cls=ResourceOption, expose_value=False)
def saliency_completeness( def completeness(
model, model,
datamodule, datamodule,
output_folder, output_folder,
...@@ -199,8 +199,8 @@ def saliency_completeness( ...@@ -199,8 +199,8 @@ def saliency_completeness(
""" """
import json import json
from ..engine.device import DeviceManager from ...engine.device import DeviceManager
from ..engine.saliency.completeness import run from ...engine.saliency.completeness import run
logger.info(f"Output folder: {output_folder}") logger.info(f"Output folder: {output_folder}")
output_folder.mkdir(parents=True, exist_ok=True) output_folder.mkdir(parents=True, exist_ok=True)
......
...@@ -10,8 +10,8 @@ import click ...@@ -10,8 +10,8 @@ import click
from clapper.click import ResourceOption, verbosity_option from clapper.click import ResourceOption, verbosity_option
from clapper.logging import setup from clapper.logging import setup
from ..models.typing import SaliencyMapAlgorithm from ...models.typing import SaliencyMapAlgorithm
from .click import ConfigCommand from ..click import ConfigCommand
# avoids X11/graphical desktop requirement when creating plots # avoids X11/graphical desktop requirement when creating plots
__import__("matplotlib").use("agg") __import__("matplotlib").use("agg")
...@@ -24,12 +24,11 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") ...@@ -24,12 +24,11 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
cls=ConfigCommand, cls=ConfigCommand,
epilog="""Examples: epilog="""Examples:
\b 1. Tabulates and generates plots for two saliency map algorithms:
1. Tabulates and generates plots for two saliency map algorithms:
.. code:: sh .. code:: sh
ptbench evaluate-saliencymaps -vv -e gradcam path/to/gradcam-completeness.json path/to/gradcam-interpretability.json -e gradcam++ path/to/gradcam++-completeness.json path/to/gradcam++-interpretability.json ptbench saliency evaluate -vv -e gradcam path/to/gradcam-completeness.json path/to/gradcam-interpretability.json -e gradcam++ path/to/gradcam++-completeness.json path/to/gradcam++-interpretability.json
""", """,
) )
@click.option( @click.option(
...@@ -73,7 +72,7 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") ...@@ -73,7 +72,7 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
cls=ResourceOption, cls=ResourceOption,
) )
@verbosity_option(logger=logger, expose_value=False) @verbosity_option(logger=logger, expose_value=False)
def evaluate_saliencymaps( def evaluate(
entry, entry,
output_folder, output_folder,
**_, # ignored **_, # ignored
...@@ -83,7 +82,7 @@ def evaluate_saliencymaps( ...@@ -83,7 +82,7 @@ def evaluate_saliencymaps(
from matplotlib.backends.backend_pdf import PdfPages from matplotlib.backends.backend_pdf import PdfPages
from ..engine.saliency.evaluator import run, summary_table from ...engine.saliency.evaluator import run, summary_table
summary = { summary = {
algo: run(algo, json.load(complet.open()), json.load(interp.open())) algo: run(algo, json.load(complet.open()), json.load(interp.open()))
......
...@@ -10,8 +10,8 @@ import click ...@@ -10,8 +10,8 @@ import click
from clapper.click import ResourceOption, verbosity_option from clapper.click import ResourceOption, verbosity_option
from clapper.logging import setup from clapper.logging import setup
from ..models.typing import SaliencyMapAlgorithm from ...models.typing import SaliencyMapAlgorithm
from .click import ConfigCommand from ..click import ConfigCommand
logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
...@@ -27,7 +27,7 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") ...@@ -27,7 +27,7 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
.. code:: sh .. code:: sh
ptbench generate-saliencymaps -vv densenet tbx11k-v1-healthy-vs-atb --weight=path/to/model-at-lowest-validation-loss.ckpt --output-folder=path/to/output ptbench saliency generate -vv densenet tbx11k-v1-healthy-vs-atb --weight=path/to/model-at-lowest-validation-loss.ckpt --output-folder=path/to/output
""", """,
) )
...@@ -143,7 +143,7 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") ...@@ -143,7 +143,7 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
cls=ResourceOption, cls=ResourceOption,
) )
@verbosity_option(logger=logger, cls=ResourceOption, expose_value=False) @verbosity_option(logger=logger, cls=ResourceOption, expose_value=False)
def generate_saliencymaps( def generate(
model, model,
datamodule, datamodule,
output_folder, output_folder,
...@@ -163,8 +163,8 @@ def generate_saliencymaps( ...@@ -163,8 +163,8 @@ def generate_saliencymaps(
algorithm and trained model. algorithm and trained model.
""" """
from ..engine.device import DeviceManager from ...engine.device import DeviceManager
from ..engine.saliency.generator import run from ...engine.saliency.generator import run
logger.info(f"Output folder: {output_folder}") logger.info(f"Output folder: {output_folder}")
output_folder.mkdir(parents=True, exist_ok=True) output_folder.mkdir(parents=True, exist_ok=True)
......
...@@ -9,7 +9,7 @@ import click ...@@ -9,7 +9,7 @@ import click
from clapper.click import ResourceOption, verbosity_option from clapper.click import ResourceOption, verbosity_option
from clapper.logging import setup from clapper.logging import setup
from .click import ConfigCommand from ..click import ConfigCommand
logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
...@@ -23,7 +23,7 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") ...@@ -23,7 +23,7 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
.. code:: sh .. code:: sh
ptbench saliency-interpretability -vv tbx11k-v1-healthy-vs-atb --input-folder=parent_folder/gradcam/ --output-json=parent_folder/gradcam/tbx11k-v1-interp.json ptbench saliency interpretability -vv tbx11k-v1-healthy-vs-atb --input-folder=parent_folder/gradcam/ --output-json=parent_folder/gradcam/tbx11k-v1-interp.json
""", """,
) )
...@@ -78,7 +78,7 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") ...@@ -78,7 +78,7 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
cls=ResourceOption, cls=ResourceOption,
) )
@verbosity_option(logger=logger, cls=ResourceOption, expose_value=False) @verbosity_option(logger=logger, cls=ResourceOption, expose_value=False)
def saliency_interpretability( def interpretability(
datamodule, datamodule,
input_folder, input_folder,
target_label, target_label,
...@@ -134,7 +134,7 @@ def saliency_interpretability( ...@@ -134,7 +134,7 @@ def saliency_interpretability(
import json import json
from ..engine.saliency.interpretability import run from ...engine.saliency.interpretability import run
datamodule.model_transforms = [] datamodule.model_transforms = []
datamodule.prepare_data() datamodule.prepare_data()
......
...@@ -18,13 +18,11 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") ...@@ -18,13 +18,11 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
cls=ConfigCommand, cls=ConfigCommand,
epilog="""Examples: epilog="""Examples:
\b 1. Generates visualizations in form of heatmaps from existing saliency maps for a dataset configuration:
1. Generates visualizations in form of heatmaps from existing saliency maps for a dataset configuration:
.. code:: sh .. code:: sh
ptbench visualize -vv pasa tbx11k-v1-healthy-vs-atb --input-folder=parent_folder/gradcam/ --output-folder=path/to/visualizations
ptbench saliency view -vv pasa tbx11k-v1-healthy-vs-atb --input-folder=parent_folder/gradcam/ --output-folder=path/to/visualizations
""", """,
) )
@click.option( @click.option(
...@@ -91,7 +89,7 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") ...@@ -91,7 +89,7 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
cls=ResourceOption, cls=ResourceOption,
) )
@verbosity_option(logger=logger, cls=ResourceOption, expose_value=False) @verbosity_option(logger=logger, cls=ResourceOption, expose_value=False)
def view_saliency( def view(
model, model,
datamodule, datamodule,
input_folder, input_folder,
...@@ -102,8 +100,8 @@ def view_saliency( ...@@ -102,8 +100,8 @@ def view_saliency(
) -> None: ) -> None:
"""Generates heatmaps for input CXRs based on existing saliency maps.""" """Generates heatmaps for input CXRs based on existing saliency maps."""
from ..engine.saliency.viewer import run from ...engine.saliency.viewer import run
from .utils import save_sh_command from ..utils import save_sh_command
assert ( assert (
input_folder != output_folder input_folder != output_folder
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment