# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later

import click

from clapper.click import ConfigCommand, ResourceOption, verbosity_option
from clapper.logging import setup
from pytorch_grad_cam import (
    AblationCAM,
    EigenCAM,
    EigenGradCAM,
    FullGrad,
    GradCAM,
    GradCAMElementWise,
    GradCAMPlusPlus,
    HiResCAM,
    LayerCAM,
    RandomCAM,
    ScoreCAM,
    XGradCAM,
)

# Package-level logger (clapper's setup keys on the top-level package name).
logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")

# Lower-cased names of every CAM technique this module can instantiate.
# NOTE: "gradcam++" and "gradcamplusplus" are accepted as aliases for the
# same technique (GradCAMPlusPlus).
allowed_visualization_types = {
    "gradcam",
    "scorecam",
    "fullgrad",
    "randomcam",
    "hirescam",
    "gradcamelementwise",
    "gradcam++",
    "gradcamplusplus",
    "xgradcam",
    "ablationcam",
    "eigencam",
    "eigengradcam",
    "layercam",
}


def check_vis_types(vis_types):
    """Normalize and validate user-selected visualization type names.

    Accepts either a single string or an iterable of strings, lower-cases
    every entry and checks it against ``allowed_visualization_types``.

    Parameters
    ----------
    vis_types
        A visualization type name, or a list/tuple of such names
        (case-insensitive).

    Returns
    -------
    list[str]
        The lower-cased visualization type names.

    Raises
    ------
    click.BadParameter
        If any entry is not a string, or names an unsupported technique.
    """
    if isinstance(vis_types, str):
        vis_types = [vis_types]

    normalized = []
    for entry in vis_types:
        # Type-check BEFORE lowering: calling .lower() on a non-string
        # would raise AttributeError and bypass the intended error message.
        if not isinstance(entry, str):
            raise click.BadParameter(
                "Visualization type must be a string or a list of strings"
            )
        name = entry.lower()
        if name not in allowed_visualization_types:
            # Sort for a deterministic error message (set order is arbitrary).
            raise click.BadParameter(
                "Visualization type must be one of: {}".format(
                    ", ".join(sorted(allowed_visualization_types))
                )
            )
        normalized.append(name)
    return normalized


def create_cam(vis_type, model, target_layers, use_cuda):
    """Factory returning the pytorch-grad-cam object for *vis_type*.

    Parameters
    ----------
    vis_type
        Lower-cased name of the CAM technique (see
        ``allowed_visualization_types``).
    model
        The torch.nn.Module the CAM hooks into.
    target_layers
        List of layers whose activations/gradients are used.
    use_cuda
        Whether the CAM computation should run on CUDA.

    Returns
    -------
    The instantiated CAM object.

    Raises
    ------
    ValueError
        If *vis_type* does not name a supported technique.
    """
    # All CAM constructors share the exact same keyword arguments.
    common = dict(model=model, target_layers=target_layers, use_cuda=use_cuda)

    if vis_type == "gradcam":
        return GradCAM(**common)
    elif vis_type == "scorecam":
        return ScoreCAM(**common)
    elif vis_type == "fullgrad":
        return FullGrad(**common)
    elif vis_type == "randomcam":
        return RandomCAM(**common)
    elif vis_type == "hirescam":
        return HiResCAM(**common)
    elif vis_type == "gradcamelementwise":
        return GradCAMElementWise(**common)
    elif vis_type in ("gradcam++", "gradcamplusplus"):
        return GradCAMPlusPlus(**common)
    elif vis_type == "xgradcam":
        return XGradCAM(**common)
    elif vis_type == "ablationcam":
        return AblationCAM(**common)
    elif vis_type == "eigencam":
        return EigenCAM(**common)
    elif vis_type == "eigengradcam":
        return EigenGradCAM(**common)
    elif vis_type == "layercam":
        return LayerCAM(**common)
    else:
        raise ValueError(f"Unsupported visualization type: {vis_type}")


@click.command(
    entry_point_group="ptbench.config",
    cls=ConfigCommand,
    epilog="""Examples:

\b
    1. Generates saliency maps and saves them as pickled objects:

       .. code:: sh

          ptbench generate-saliencymaps -vv densenet tbx11k_simplified_bbox_rgb --accelerator="cuda" --weight=path/to/model_final.pth --output-folder=path/to/visualizations

""",
)
@click.option(
    "--model",
    "-m",
    help="A torch.nn.Module instance implementing the network to be evaluated",
    required=True,
    cls=ResourceOption,
)
@click.option(
    "--dataset",
    "-d",
    help="A torch.utils.data.dataset.Dataset instance implementing a dataset "
    "to be used for generating visualizations, possibly including all pre-processing "
    "pipelines required or, optionally, a dictionary mapping string keys to "
    "torch.utils.data.dataset.Dataset instances.  All keys that do not start "
    "with an underscore (_) will be processed.",
    required=True,
    cls=ResourceOption,
)
@click.option(
    "--output-folder",
    "-o",
    help="Path where to store the visualizations (created if does not exist)",
    required=True,
    default="visualizations",
    cls=ResourceOption,
    type=click.Path(),
)
@click.option(
    "--batch-size",
    "-b",
    help="Number of samples in every batch (this parameter affects memory requirements for the network)",
    required=True,
    show_default=True,
    default=1,
    type=click.IntRange(min=1),
    cls=ResourceOption,
)
@click.option(
    "--accelerator",
    "-a",
    help='A string indicating the accelerator to use (e.g. "cpu" or "gpu"). The device can also be specified (gpu:0)',
    show_default=True,
    required=True,
    default="cpu",
    cls=ResourceOption,
)
@click.option(
    "--weight",
    "-w",
    help="Path or URL to pretrained model file (.ckpt extension)",
    required=True,
    cls=ResourceOption,
)
@click.option(
    "--visualization-types",
    "-vt",
    help="Visualization techniques to be used. Can be called multiple times with different techniques. Currently supported ones are: "
    '"GradCAM", "ScoreCAM", "FullGrad", "RandomCAM", "HiResCAM", "GradCAMElementWise", "GradCAMPlusPlus", "XGradCAM", "AblationCAM", '
    '"EigenCAM", "EigenGradCAM", "LayerCAM"',
    multiple=True,
    default=["GradCAM"],
    cls=ResourceOption,
)
@click.option(
    "--target-class",
    "-tc",
    help='(Use only with multi-label models) Which class to target for CAM calculation. Can be either set to "all" or "highest". "highest" is default, which means only saliency maps for the class with the highest activation will be generated.',
    type=str,
    required=False,
    default="highest",
    cls=ResourceOption,
)
@click.option(
    "--tb-positive-only",
    "-tb",
    help="If set, saliency maps will only be generated for TB positive samples.",
    is_flag=True,
    default=False,
    cls=ResourceOption,
)
@verbosity_option(logger=logger, cls=ResourceOption, expose_value=False)
def generate_saliencymaps(
    model,
    dataset,
    output_folder,
    batch_size,
    accelerator,
    weight,
    visualization_types,
    target_class,
    tb_positive_only,
    **_,
) -> None:
    """Generates saliency maps for locations with aTB for input CXRs, depending
    on visualization technique and model.

    Loads a model checkpoint, builds one CAM object per requested
    visualization technique, and runs saliency-map generation over every
    (non-underscore-prefixed) split of the dataset, delegating the actual
    work to ``..engine.saliencymap_generator.run``.
    """

    import torch

    from torch.utils.data import DataLoader

    from ..engine.saliencymap_generator import run

    # Temporary solution due to transition to PyTorch Lightning: map the
    # lightning-style accelerator string to a plain torch device, silently
    # falling back to CPU when CUDA is requested but not available.
    if accelerator.startswith("cuda") or accelerator.startswith("gpu"):
        use_cuda = torch.cuda.is_available()
        device = "cuda:0" if use_cuda else "cpu"
    else:
        use_cuda = False
        device = "cpu"

    # Some dataset configs are dictionaries carrying a "datadir" entry and the
    # actual split(s) under "dataset"; unwrap that layer first.  In both
    # branches, a bare (non-dict) dataset is treated as a single "test" split.
    if "datadir" in dataset:
        dataset = (
            dataset["dataset"]
            if isinstance(dataset["dataset"], dict)
            else dict(test=dataset["dataset"])
        )
    else:
        dataset = dataset if isinstance(dataset, dict) else dict(test=dataset)

    logger.info(f"Loading checkpoint from {weight}")

    # This is a temporary solution due to transition to PyTorch Lightning
    # This will not be necessary for future users of this package:
    # the checkpoint stores the weights under a "model" key, with every
    # parameter name prefixed by "model." — strip that prefix before loading.
    state_dict = torch.load(weight, map_location=torch.device("cpu")).pop(
        "model"
    )
    new_state_dict = {k.replace("model.", ""): v for k, v in state_dict.items()}
    model.load_state_dict(new_state_dict)

    # This code should work for future users of this package (no guarantee)
    # model = model.load_from_checkpoint(weight, strict=False)

    model.eval()

    visualization_types = check_vis_types(visualization_types)

    model_name = model.__class__.__name__

    if model_name == "PASA":
        if "fullgrad" in visualization_types:
            raise ValueError(
                "Fullgrad visualization is not supported for the Pasa model."
            )
        target_layers = [model.fc14]  # Last non-1x1 Conv2d layer
    else:
        # NOTE(review): assumes any non-PASA model is a DenseNet-121-style
        # wrapper exposing model_ft.features.denseblock4.denselayer16.conv2 —
        # confirm before adding new architectures.
        target_layers = [model.model_ft.features.denseblock4.denselayer16.conv2]

    for vis_type in visualization_types:
        cam = create_cam(vis_type, model, target_layers, use_cuda)

        for k, v in dataset.items():
            # Splits whose key starts with "_" are excluded by convention.
            if k.startswith("_"):
                logger.info(f"Skipping dataset '{k}' (not to be evaluated)")
                continue

            logger.info(f"Generating saliency maps for '{k}' set...")

            data_loader = DataLoader(
                dataset=v,
                batch_size=batch_size,
                shuffle=False,
                pin_memory=torch.cuda.is_available(),
            )

            run(
                model,
                data_loader,
                output_folder=output_folder,
                dataset_split_name=k,
                device=device,
                cam=cam,
                visualization_type=vis_type,
                target_class=target_class,
                tb_positive_only=tb_positive_only,
            )