Skip to content
Snippets Groups Projects
Commit 128b7b70 authored by André Anjos's avatar André Anjos :speech_balloon:
Browse files

[saliencymap_generator] Save maps directly

parent d7a3c159
No related branches found
No related tags found
No related merge requests found
...@@ -12,14 +12,14 @@ import torch ...@@ -12,14 +12,14 @@ import torch
import torch.nn import torch.nn
import tqdm import tqdm
from ..models.typing import VisualisationType from ..models.typing import SaliencyMapAlgorithm
from .device import DeviceManager from .device import DeviceManager
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def _create_visualiser( def _create_saliency_map_callable(
vis_type: VisualisationType, algo_type: SaliencyMapAlgorithm,
model: torch.nn.Module, model: torch.nn.Module,
target_layers: list[torch.nn.Module] | None, target_layers: list[torch.nn.Module] | None,
use_cuda: bool, use_cuda: bool,
...@@ -28,7 +28,7 @@ def _create_visualiser( ...@@ -28,7 +28,7 @@ def _create_visualiser(
import pytorch_grad_cam import pytorch_grad_cam
match vis_type: match algo_type:
case "gradcam": case "gradcam":
return pytorch_grad_cam.GradCAM( return pytorch_grad_cam.GradCAM(
model=model, target_layers=target_layers, use_cuda=use_cuda model=model, target_layers=target_layers, use_cuda=use_cuda
...@@ -82,45 +82,45 @@ def _create_visualiser( ...@@ -82,45 +82,45 @@ def _create_visualiser(
) )
case _: case _:
raise ValueError( raise ValueError(
f"Visualisation type `{vis_type}` is not currently supported." f"Saliency map algorithm `{algo_type}` is not currently "
f"supported."
) )
def _save_visualisation( def _save_saliency_map(
output_folder: pathlib.Path, name: str, vis: torch.Tensor output_folder: pathlib.Path, name: str, saliency_map: torch.Tensor
) -> None: ) -> None:
"""Helper function to save a saliency map to disk.""" """Helper function to save a saliency map to disk."""
n = pathlib.Path(name) n = pathlib.Path(name)
(output_folder / n.parent).mkdir(parents=True, exist_ok=True) (output_folder / n.parent).mkdir(parents=True, exist_ok=True)
numpy.save(output_folder / n.with_suffix(".npy"), vis) numpy.save(output_folder / n.with_suffix(".npy"), saliency_map[0])
def run( def run(
model: lightning.pytorch.LightningModule, model: lightning.pytorch.LightningModule,
datamodule: lightning.pytorch.LightningDataModule, datamodule: lightning.pytorch.LightningDataModule,
device_manager: DeviceManager, device_manager: DeviceManager,
visualisation_types: typing.Sequence[VisualisationType], saliency_map_algorithms: typing.Sequence[SaliencyMapAlgorithm],
target_class: typing.Literal["highest", "all"], target_class: typing.Literal["highest", "all"],
positive_only: bool, positive_only: bool,
output_folder: pathlib.Path, output_folder: pathlib.Path,
) -> None: ) -> None:
"""Applies visualisation techniques on input CXR, outputs pickled saliency """Applies saliency mapping techniques on input CXR, outputs pickled
maps directly to disk. saliency maps directly to disk.
Parameters Parameters
--------- ---------
model model
Neural network model (e.g. pasa). Neural network model (e.g. pasa).
datamodule datamodule
The lightning datamodule to use for training **and** validation The lightning datamodule to iterate on.
device_manager device_manager
An internal device representation, to be used for training and An internal device representation, to be used for training and
validation. This representation can be converted into a pytorch device validation. This representation can be converted into a pytorch device
or a torch lightning accelerator setup. or a torch lightning accelerator setup.
visualisation_types saliency_map_algorithms
The types of visualisations to generate for the current model and The algorithms for saliency map estimation to use.
datamodule.
target_class target_class
(Use only with multi-label models) Which class to target for CAM (Use only with multi-label models) Which class to target for CAM
calculation. Can be either set to "all" or "highest". "highest" is calculation. Can be either set to "all" or "highest". "highest" is
...@@ -131,7 +131,7 @@ def run( ...@@ -131,7 +131,7 @@ def run(
label == 1 in a binary classification task). This option is ignored on label == 1 in a binary classification task). This option is ignored on
a multi-class output model. a multi-class output model.
output_folder output_folder
Where to save all the visualisations (this path should exist before Where to save all the saliency maps (this path should exist before
this function is called) this function is called)
""" """
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
...@@ -140,9 +140,10 @@ def run( ...@@ -140,9 +140,10 @@ def run(
from ..models.pasa import Pasa from ..models.pasa import Pasa
if isinstance(model, Pasa): if isinstance(model, Pasa):
if "fullgrad" in visualisation_types: if "fullgrad" in saliency_map_algorithms:
raise ValueError( raise ValueError(
"Fullgrad visualisation is not supported for the Pasa model." "Fullgrad saliency map algorithm is not supported for the "
"Pasa model."
) )
target_layers = [model.fc14] # Last non-1x1 Conv2d layer target_layers = [model.fc14] # Last non-1x1 Conv2d layer
elif isinstance(model, Densenet): elif isinstance(model, Densenet):
...@@ -159,12 +160,17 @@ def run( ...@@ -159,12 +160,17 @@ def run(
model = model.to(device) model = model.to(device)
model.eval() model.eval()
for vis_type in visualisation_types: for algo_type in saliency_map_algorithms:
cam = _create_visualiser(vis_type, model, target_layers, use_cuda) # type: ignore saliency_map_callable = _create_saliency_map_callable(
algo_type,
model,
target_layers, # type: ignore
use_cuda,
)
for k, v in datamodule.predict_dataloader().items(): for k, v in datamodule.predict_dataloader().items():
logger.info( logger.info(
f"Generating saliency maps for dataset `{k}` via `{vis_type}`..." f"Generating saliency maps for dataset `{k}` via `{algo_type}`..."
) )
for sample in tqdm.tqdm( for sample in tqdm.tqdm(
...@@ -180,37 +186,42 @@ def run( ...@@ -180,37 +186,42 @@ def run(
if positive_only and (model.num_classes == 1) and (label == 0): if positive_only and (model.num_classes == 1) and (label == 0):
continue continue
# chooses target outputs to generate visualisations for # chooses target outputs to generate saliency maps for
if model.num_classes > 1: if model.num_classes > 1:
if target_class == "all": if target_class == "all":
# just blindly generate visualisations for all outputs # just blindly generate saliency maps for all outputs
# - make one directory for every target output and lay # - make one directory for every target output and lay
# images there like in the original dataset. # images there like in the original dataset.
for output_num in range(model.num_classes): for output_num in range(model.num_classes):
use_folder = ( use_folder = (
output_folder / vis_type / str(output_num) output_folder / algo_type / str(output_num)
) )
vis = cam( saliency_map = saliency_map_callable(
input_tensor=image, input_tensor=image,
targets=[ClassifierOutputTarget(output_num)], # type: ignore targets=[ClassifierOutputTarget(output_num)], # type: ignore
) )
_save_visualisation(use_folder, name, vis) # type: ignore _save_saliency_map(use_folder, name, saliency_map) # type: ignore
else: else:
# pytorch-grad-cam will evaluate the output with the # pytorch-grad-cam will evaluate the output with the
# highest value and produce a visualisation for it - we will # highest value and produce a saliency map for it - we
# save it to disk. # will save it to disk.
use_folder = output_folder / vis_type / "highest-output" use_folder = (
vis = cam(input_tensor=image, targets=None) # type: ignore output_folder / algo_type / "highest-output"
_save_visualisation(use_folder, name, vis) # type: ignore )
saliency_map = saliency_map_callable(
input_tensor=image,
targets=None, # type: ignore
)
_save_saliency_map(use_folder, name, saliency_map) # type: ignore
else: else:
# binary classification model with a single output - just # binary classification model with a single output - just
# lay all cams uniformily like the original dataset # lay all cams uniformily like the original dataset
use_folder = output_folder / vis_type use_folder = output_folder / algo_type
vis = cam( saliency_map = saliency_map_callable(
input_tensor=image, input_tensor=image,
targets=[ targets=[
ClassifierOutputTarget(0), # type: ignore ClassifierOutputTarget(0), # type: ignore
], ],
) )
_save_visualisation(use_folder, name, vis) # type: ignore _save_saliency_map(use_folder, name, saliency_map) # type: ignore
...@@ -26,7 +26,7 @@ MultiClassPredictionSplit: typing.TypeAlias = typing.Mapping[ ...@@ -26,7 +26,7 @@ MultiClassPredictionSplit: typing.TypeAlias = typing.Mapping[
] ]
"""A series of predictions for different database splits.""" """A series of predictions for different database splits."""
VisualisationType: typing.TypeAlias = typing.Literal[ SaliencyMapAlgorithm: typing.TypeAlias = typing.Literal[
"ablationcam", "ablationcam",
"eigencam", "eigencam",
"eigengradcam", "eigengradcam",
...@@ -41,4 +41,4 @@ VisualisationType: typing.TypeAlias = typing.Literal[ ...@@ -41,4 +41,4 @@ VisualisationType: typing.TypeAlias = typing.Literal[
"scorecam", "scorecam",
"xgradcam", "xgradcam",
] ]
"""Supported visualisation types.""" """Supported saliency map algorithms."""
...@@ -10,7 +10,7 @@ import click ...@@ -10,7 +10,7 @@ import click
from clapper.click import ResourceOption, verbosity_option from clapper.click import ResourceOption, verbosity_option
from clapper.logging import setup from clapper.logging import setup
from ..models.typing import VisualisationType from ..models.typing import SaliencyMapAlgorithm
from .click import ConfigCommand from .click import ConfigCommand
logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
...@@ -27,7 +27,7 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") ...@@ -27,7 +27,7 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
.. code:: sh .. code:: sh
ptbench generate-saliencymaps -vv densenet tbx11k-v1-healthy-vs-atb --weight=path/to/model_final.pth --output-folder=path/to/visualisations ptbench generate-saliencymaps -vv densenet tbx11k-v1-healthy-vs-atb --weight=path/to/model_final.pth --output-folder=path/to/output
""", """,
) )
...@@ -53,7 +53,7 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") ...@@ -53,7 +53,7 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
@click.option( @click.option(
"--output-folder", "--output-folder",
"-o", "-o",
help="Path where to store the visualisations (created if does not exist)", help="Path where to store saliency maps (created if does not exist)",
required=True, required=True,
type=click.Path( type=click.Path(
exists=False, exists=False,
...@@ -62,7 +62,7 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") ...@@ -62,7 +62,7 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
writable=True, writable=True,
path_type=pathlib.Path, path_type=pathlib.Path,
), ),
default="visualisations", default="saliency-maps",
cls=ResourceOption, cls=ResourceOption,
) )
@click.option( @click.option(
...@@ -106,11 +106,13 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") ...@@ -106,11 +106,13 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
cls=ResourceOption, cls=ResourceOption,
) )
@click.option( @click.option(
"--visualisation-types", "--saliency-map-algorithm",
"-vt", "-s",
help="""visualisation techniques to be used. Can be called multiple times help="""Saliency map algorithm(s) to be used. Can be called multiple times
with different techniques.""", with different techniques.""",
type=click.Choice(typing.get_args(VisualisationType), case_sensitive=False), type=click.Choice(
typing.get_args(SaliencyMapAlgorithm), case_sensitive=False
),
multiple=True, multiple=True,
default=["gradcam"], default=["gradcam"],
show_default=True, show_default=True,
...@@ -149,7 +151,7 @@ def generate_saliencymaps( ...@@ -149,7 +151,7 @@ def generate_saliencymaps(
cache_samples, cache_samples,
weight, weight,
parallel, parallel,
visualisation_types, saliency_map_algorithm,
target_class, target_class,
positive_only, positive_only,
**_, **_,
...@@ -157,8 +159,8 @@ def generate_saliencymaps( ...@@ -157,8 +159,8 @@ def generate_saliencymaps(
"""Generates saliency maps for locations on input images that affected the """Generates saliency maps for locations on input images that affected the
prediction. prediction.
The quality of saliency information depends on visualisation The quality of saliency information depends on the saliency map
technique and trained model. algorithm and trained model.
""" """
from ..engine.device import DeviceManager from ..engine.device import DeviceManager
...@@ -189,7 +191,7 @@ def generate_saliencymaps( ...@@ -189,7 +191,7 @@ def generate_saliencymaps(
model=model, model=model,
datamodule=datamodule, datamodule=datamodule,
device_manager=device_manager, device_manager=device_manager,
visualisation_types=visualisation_types, saliency_map_algorithms=saliency_map_algorithm,
target_class=target_class, target_class=target_class,
positive_only=positive_only, positive_only=positive_only,
output_folder=output_folder, output_folder=output_folder,
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment