# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later

import pathlib
import typing

import click

from clapper.click import ResourceOption, verbosity_option
from clapper.logging import setup

from ...models.typing import SaliencyMapAlgorithm
from ..click import ConfigCommand

logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")


@click.command(
    entry_point_group="ptbench.config",
    cls=ConfigCommand,
    epilog="""Examples:

1. Calculates the ROAD scores for an existing dataset configuration and stores them in .csv files:

   .. code:: sh

      ptbench saliency completeness -vv pasa tbx11k-v1-healthy-vs-atb --device="cuda" --weight=path/to/model-at-lowest-validation-loss.ckpt --output-folder=path/to/completeness-scores/

""",
)
@click.option(
    "--model",
    "-m",
    help="""A lightining module instance implementing the network architecture
    (not the weights, necessarily) to be used for inference.  Currently, only
    supports pasa and densenet models.""",
    required=True,
    cls=ResourceOption,
)
@click.option(
    "--datamodule",
    "-d",
    help="""A lighting data module that will be asked for prediction data
    loaders. Typically, this includes all configured splits in a datamodule,
    however this is not a requirement.  A datamodule that returns a single
    dataloader for prediction (wrapped in a dictionary) is acceptable.""",
    required=True,
    cls=ResourceOption,
)
@click.option(
    "--output-folder",
    "-o",
    help="Path where to store saliency maps (created if does not exist)",
    required=True,
    type=click.Path(
        exists=False,
        file_okay=False,
        dir_okay=True,
        writable=True,
        path_type=pathlib.Path,
    ),
    default="saliency-maps",
    cls=ResourceOption,
)
@click.option(
    "--device",
    "-x",
    help='A string indicating the device to use (e.g. "cpu" or "cuda:0")',
    show_default=True,
    required=True,
    default="cpu",
    cls=ResourceOption,
)
@click.option(
    "--cache-samples/--no-cache-samples",
    help="If set to True, loads the sample into memory, "
    "otherwise loads them at runtime.",
    required=True,
    show_default=True,
    default=False,
    cls=ResourceOption,
)
@click.option(
    "--weight",
    "-w",
    help="""Path or URL to pretrained model file (`.ckpt` extension),
    corresponding to the architecture set with `--model`.""",
    required=True,
    cls=ResourceOption,
    type=click.Path(exists=True, file_okay=True, dir_okay=False, readable=True),
)
@click.option(
    "--parallel",
    "-P",
    help="""Use multiprocessing for data loading processing: if set to -1
    (default), disables multiprocessing.  Set to 0 to enable as many data
    processing instances as processing cores available in the system.  Set to
    >= 1 to enable that many multiprocessing instances.  Note that if you
    activate this option, then you must use --device=cpu, as using a GPU
    concurrently is not supported.""",
    type=click.IntRange(min=-1),
    show_default=True,
    required=True,
    default=-1,
    cls=ResourceOption,
)
@click.option(
    "--saliency-map-algorithm",
    "-s",
    help="""Saliency map algorithm(s) to be used. Can be called multiple times
    with different techniques.""",
    type=click.Choice(
        typing.get_args(SaliencyMapAlgorithm), case_sensitive=False
    ),
    multiple=True,
    default=["gradcam"],
    show_default=True,
    cls=ResourceOption,
)
@click.option(
    "--target-class",
    "-C",
    help="""This option should only be used with multiclass models.  It
    defines the class to target for saliency estimation. Can be either set to
    "all" or "highest". "highest" (the default), means only saliency maps for
    the class with the highest activation will be generated.""",
    required=False,
    type=click.Choice(
        ["highest", "all"],
        case_sensitive=False,
    ),
    default="highest",
    cls=ResourceOption,
)
@click.option(
    "--positive-only/--no-positive-only",
    "-z/-Z",
    help="""If set, and the model chosen has a single output (binary), then
    saliency maps will only be generated for samples of the positive class.
    This option has no effect for multiclass models.""",
    default=False,
    cls=ResourceOption,
)
@click.option(
    "--percentile",
    "-e",
    help="""One or more percentiles (percent x100) integer values indicating
    the proportion of pixels to perturb in the original image to calculate both
    MoRF and LeRF scores.""",
    multiple=True,
    default=[20, 40, 60, 80],
    show_default=True,
    cls=ResourceOption,
)
@verbosity_option(logger=logger, cls=ResourceOption, expose_value=False)
def completeness(
    model,
    datamodule,
    output_folder,
    device,
    cache_samples,
    weight,
    parallel,
    saliency_map_algorithm,
    target_class,
    positive_only,
    percentile,
    **_,
) -> None:
    """Evaluates saliency map algorithm completeness using RemOve And Debias
    (ROAD).

    For each selected saliency map algorithm, evaluates the completeness of
    explanations using the RemOve And Debias (ROAD) algorithm. The ROAD
    algorithm was first described at [ROAD-2022]_. It estimates explainability
    (in the completeness sense) of saliency mapping algorithms by substituting
    relevant pixels in the input image by a local average, re-running
    prediction on the altered image, and measuring changes in the output
    classification score when said perturbations are in place.  By substituting
    most or least relevant pixels with surrounding averages, the ROAD algorithm
    estimates the importance of such elements in the produced saliency map.  As
    2023, this measurement technique is considered to be one of the
    state-of-the-art metrics of explainability.

    This program outputs a JSON file containing the ROAD evaluations (using
    most-relevant-first, or MoRF, and least-relevant-first, or LeRF for each
    sample in the datamodule, and per saliency-mapping algorithm.  Each
    saliency-mapping algorithm yields a single JSON file with the target
    algorithm name on the ``output-folder``. Values for MoRF and LeRF represent
    averages by removing 20, 40, 60 and 80% of most or least relevant pixels
    respectively from the image, and averaging results for all these
    percentiles.

    .. note::

       This application is relatively slow when processing a large datamodule
       with many (positive) samples.
    """
    import json

    from ...engine.device import DeviceManager
    from ...engine.saliency.completeness import run

    logger.info(f"Output folder: {output_folder}")
    output_folder.mkdir(parents=True, exist_ok=True)

    if device in ("cuda", "mps") and (parallel == 0 or parallel > 1):
        raise RuntimeError(
            f"The number of multiprocessing instances is set to {parallel} and "
            f"you asked to use a GPU (device = `{device}`). The currently "
            f"implementation can only handle a single GPU.  Either disable GPU "
            f"utilisation or set the number of multiprocessing instances to "
            f"one, or disable multiprocessing entirely (ie. set it to -1)."
        )

    device_manager = DeviceManager(device)

    # batch_size must be == 1 for now (underlying code is NOT prepared to
    # treat multiple samples at once).
    datamodule.set_chunk_size(1, 1)
    datamodule.cache_samples = cache_samples
    datamodule.parallel = parallel
    datamodule.model_transforms = model.model_transforms

    datamodule.prepare_data()
    datamodule.setup(stage="predict")

    logger.info(f"Loading checkpoint from `{weight}`...")
    model = model.load_from_checkpoint(weight, strict=False)

    for algo in saliency_map_algorithm:
        logger.info(
            f"Evaluating RemOve And Debias (ROAD) average scores for "
            f"algorithm `{algo}` with percentiles "
            f"`{', '.join([str(k) for k in percentile])}`..."
        )
        results = run(
            model=model,
            datamodule=datamodule,
            device_manager=device_manager,
            saliency_map_algorithm=algo,
            target_class=target_class,
            positive_only=positive_only,
            percentiles=percentile,
            parallel=parallel,
        )

        output_json = output_folder / (algo + ".json")
        with output_json.open("w") as f:
            logger.info(f"Saving output file to `{str(output_json)}`...")
            json.dump(results, f, indent=2)