# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later
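
"""Generates saliency maps for chest X-ray (CXR) samples.

This CLI command loads a pretrained model checkpoint, instantiates one of the
CAM techniques supported by ``pytorch_grad_cam``, runs it over the
datamodule's prediction dataloaders, and saves the resulting saliency maps to
the chosen output folder.
"""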

import click
import pathlib

from clapper.click import ConfigCommand, ResourceOption, verbosity_option
from clapper.logging import setup
from pytorch_grad_cam import (
    AblationCAM,
    EigenCAM,
    EigenGradCAM,
    FullGrad,
    GradCAM,
    GradCAMElementWise,
    GradCAMPlusPlus,
    HiResCAM,
    LayerCAM,
    RandomCAM,
    ScoreCAM,
    XGradCAM,
)

logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")

allowed_visualization_types = {
    "gradcam",
    "scorecam",
    "fullgrad",
    "randomcam",
    "hirescam",
    "gradcamelementwise",
    "gradcam++",
    "gradcamplusplus",
    "xgradcam",
    "ablationcam",
    "eigencam",
    "eigengradcam",
    "layercam",
}


# Ensures that the user has selected only supported visualization types
def check_vis_types(vis_types):
    """Normalizes visualization type names to lowercase and validates them.

    Example: ``check_vis_types(("GradCAM", "FullGrad"))`` returns
    ``["gradcam", "fullgrad"]``.
    """
    if isinstance(vis_types, str):
        vis_types = [vis_types]

    normalized = []
    for s in vis_types:
        if not isinstance(s, str):
            raise click.BadParameter(
                "Visualization type must be a string or a list of strings"
            )
        s = s.lower()
        if s not in allowed_visualization_types:
            raise click.BadParameter(
                "Visualization type must be one of: {}".format(
                    ", ".join(sorted(allowed_visualization_types))
                )
            )
        normalized.append(s)
    return normalized


# CAM factory: maps each visualization type to its pytorch_grad_cam class
_CAM_CLASSES = {
    "gradcam": GradCAM,
    "scorecam": ScoreCAM,
    "fullgrad": FullGrad,
    "randomcam": RandomCAM,
    "hirescam": HiResCAM,
    "gradcamelementwise": GradCAMElementWise,
    "gradcam++": GradCAMPlusPlus,
    "gradcamplusplus": GradCAMPlusPlus,
    "xgradcam": XGradCAM,
    "ablationcam": AblationCAM,
    "eigencam": EigenCAM,
    "eigengradcam": EigenGradCAM,
    "layercam": LayerCAM,
}


def create_cam(vis_type, model, target_layers, use_cuda):
    """Instantiates the CAM object corresponding to ``vis_type``."""
    cam_class = _CAM_CLASSES.get(vis_type)
    if cam_class is None:
        raise ValueError(f"Unsupported visualization type: {vis_type}")
    return cam_class(
        model=model, target_layers=target_layers, use_cuda=use_cuda
    )


@click.command(
    entry_point_group="ptbench.config",
    cls=ConfigCommand,
    epilog="""Examples:

\b
    1. Generates saliency maps and saves them as pickled objects:

       .. code:: sh

          ptbench generate-saliencymaps -vv densenet tbx11k_simplified_bbox_rgb --device="cuda" --weight=path/to/model_final.pth --output-folder=path/to/visualizations

""",
)
@click.option(
    "--model",
    "-m",
    help="A lightining module instance implementing the network to be trained.",
    required=True,
    cls=ResourceOption,
)
@click.option(
    "--datamodule",
    "-d",
    help="A lighting data module containing the training and validation sets.",
    required=True,
    cls=ResourceOption,
)
@click.option(
    "--output-folder",
    "-o",
    help="Path where to store the visualizations (created if does not exist)",
    required=True,
    type=click.Path(
        file_okay=False,
        dir_okay=True,
        writable=True,
        path_type=pathlib.Path,
    ),
    default="visualizations",
    cls=ResourceOption,
)
@click.option(
    "--batch-size",
    "-b",
    help="Number of samples in every batch (this parameter affects memory requirements for the network)",
    required=True,
    show_default=True,
    default=1,
    type=click.IntRange(min=1),
    cls=ResourceOption,
)
@click.option(
    "--device",
    "-x",
    help='A string indicating the device to use (e.g. "cpu" or "cuda:0")',
    show_default=True,
    required=True,
    default="cpu",
    cls=ResourceOption,
)
@click.option(
    "--weight",
    "-w",
    help="""Path or URL to pretrained model file (`.ckpt` extension),
    corresponding to the architecture set with `--model`.""",
    required=True,
    cls=ResourceOption,
)
@click.option(
    "--visualization-types",
    "-vt",
    help="Visualization techniques to be used. Can be called multiple times with different techniques. Currently supported ones are: "
    '"GradCAM", "ScoreCAM", "FullGrad", "RandomCAM", "HiResCAM", "GradCAMElementWise", "GradCAMPlusPlus", "XGradCAM", "AblationCAM", '
    '"EigenCAM", "EigenGradCAM", "LayerCAM"',
    multiple=True,
    default=["GradCAM"],
    cls=ResourceOption,
)
@click.option(
    "--target-class",
    "-tc",
    help='(Use only with multi-label models) Which class to target for CAM calculation. Can be set to either "all" or "highest". "highest" is the default, meaning saliency maps are only generated for the class with the highest activation.',
    type=str,
    required=False,
    default="highest",
    cls=ResourceOption,
)
@click.option(
    "--tb-positive-only",
    "-tb",
    help="If set, saliency maps will only be generated for TB positive samples.",
    is_flag=True,
    default=False,
    cls=ResourceOption,
)
@verbosity_option(logger=logger, cls=ResourceOption, expose_value=False)
def generate_saliencymaps(
    model,
    datamodule,
    output_folder,
    batch_size,
    device,
    weight,
    visualization_types,
    target_class,
    tb_positive_only,
    **_,
) -> None:
    """Generates saliency maps for locations with aTB for input CXRs, depending
    on visualization technique and model."""

    from ..engine.device import DeviceManager
    from .utils import save_sh_command
    from ..engine.saliencymap_generator import run

    save_sh_command(output_folder / "command.sh")

    device_manager = DeviceManager(device)
    device = device_manager.torch_device()
    use_cuda = device_manager.device_type == "cuda"

    datamodule.set_chunk_size(batch_size, 1)
    datamodule.drop_incomplete_batch = False
    # datamodule.cache_samples = cache_samples
    # datamodule.parallel = parallel
    datamodule.model_transforms = model.model_transforms

    datamodule.prepare_data()
    datamodule.setup(stage="predict")

    dataloaders = datamodule.predict_dataloader()

    logger.info(f"Loading checkpoint from {weight}")

    model = model.load_from_checkpoint(weight, strict=False)

    visualization_types = check_vis_types(visualization_types)

    model_name = model.__class__.__name__

    if model_name == "Pasa":
        if "fullgrad" in visualization_types:
            raise ValueError(
                "Fullgrad visualization is not supported for the Pasa model."
            )
        target_layers = [model.fc14]  # Last non-1x1 Conv2d layer
    else:
        # Assumes a DenseNet-style model; if this attribute path does not
        # resolve, print(model) and point target_layers at the last
        # convolutional layer instead
        target_layers = [model.denseblock4.denselayer16.conv2]

    for vis_type in visualization_types:
        cam = create_cam(vis_type, model, target_layers, use_cuda)

        for k, data_loader in dataloaders.items():
            logger.info(f"Generating saliency maps for '{k}' set...")

            run(
                model,
                data_loader,
                output_folder=output_folder,
                dataset_split_name=k,
                device=device,
                cam=cam,
                visualization_type=vis_type,
                target_class=target_class,
                tb_positive_only=tb_positive_only,
            )