# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later

import pathlib
import typing

import click

from clapper.click import ResourceOption, verbosity_option
from clapper.logging import setup

from ..models.typing import SaliencyMapAlgorithm
from .click import ConfigCommand

# Configure (via clapper) the logger of the top-level package this module
# lives in, i.e. the part of ``__name__`` before the first dot.
logger = setup(
    __name__.partition(".")[0],
    format="%(levelname)s: %(message)s",
)


@click.command(
    entry_point_group="ptbench.config",
    cls=ConfigCommand,
    epilog="""Examples:

1. Generates saliency maps for all prediction dataloaders on a datamodule,
   using a pre-trained DenseNet model, and saves them as numpy-pickled
   objects in the output directory:

   .. code:: sh

      ptbench generate-saliencymaps -vv densenet tbx11k-v1-healthy-vs-atb --weight=path/to/model_final.ckpt --output-folder=path/to/output

""",
)
@click.option(
    "--model",
    "-m",
    help="""A lightning module instance implementing the network architecture
    (not the weights, necessarily) to be used for inference.  Currently, only
    pasa and densenet models are supported.""",
    required=True,
    cls=ResourceOption,
)
@click.option(
    "--datamodule",
    "-d",
    help="""A lightning data module that will be asked for prediction data
    loaders. Typically, this includes all configured splits in a datamodule,
    however this is not a requirement.  A datamodule that returns a single
    dataloader for prediction (wrapped in a dictionary) is acceptable.""",
    required=True,
    cls=ResourceOption,
)
@click.option(
    "--output-folder",
    "-o",
    help="Path where to store saliency maps (created if it does not exist)",
    required=True,
    type=click.Path(
        exists=False,
        file_okay=False,
        dir_okay=True,
        writable=True,
        path_type=pathlib.Path,
    ),
    default="saliency-maps",
    cls=ResourceOption,
)
@click.option(
    "--device",
    "-x",
    help='A string indicating the device to use (e.g. "cpu" or "cuda:0")',
    show_default=True,
    required=True,
    default="cpu",
    cls=ResourceOption,
)
@click.option(
    "--cache-samples/--no-cache-samples",
    help="If set, loads the samples into memory, "
    "otherwise loads them at runtime.",
    required=True,
    show_default=True,
    default=False,
    cls=ResourceOption,
)
@click.option(
    "--weight",
    "-w",
    # NOTE: ``click.Path(exists=True)`` only accepts an existing local file,
    # so the help text must not advertise URLs.
    help="""Path to pretrained model file (`.ckpt` extension), corresponding
    to the architecture set with `--model`.  The file must exist locally.""",
    required=True,
    cls=ResourceOption,
    type=click.Path(exists=True, file_okay=True, dir_okay=False, readable=True),
)
@click.option(
    "--parallel",
    "-P",
    help="""Use multiprocessing for data loading: if set to -1 (default),
    disables multiprocessing data loading.  Set to 0 to enable as many data
    loading instances as processing cores available in the system.  Set to
    >= 1 to enable that many multiprocessing instances for data loading.""",
    type=click.IntRange(min=-1),
    show_default=True,
    required=True,
    default=-1,
    cls=ResourceOption,
)
@click.option(
    "--saliency-map-algorithm",
    "-s",
    help="""Saliency map algorithm(s) to be used. Can be called multiple times
    with different techniques.""",
    # Valid choices are the members of the ``SaliencyMapAlgorithm`` Literal.
    type=click.Choice(
        typing.get_args(SaliencyMapAlgorithm), case_sensitive=False
    ),
    multiple=True,
    default=["gradcam"],
    show_default=True,
    cls=ResourceOption,
)
@click.option(
    "--target-class",
    "-C",
    help="""This option should only be used with multiclass models.  It
    defines the class to target for saliency estimation. Can be either set to
    "all" or "highest". "highest" (the default), means only saliency maps for
    the class with the highest activation will be generated.""",
    required=False,
    type=click.Choice(
        ["highest", "all"],
        case_sensitive=False,
    ),
    default="highest",
    cls=ResourceOption,
)
@click.option(
    "--positive-only/--no-positive-only",
    "-z/-Z",
    help="""If set, and the model chosen has a single output (binary), then
    saliency maps will only be generated for samples of the positive class.
    This option has no effect for multiclass models.""",
    default=False,
    cls=ResourceOption,
)
@verbosity_option(logger=logger, cls=ResourceOption, expose_value=False)
def generate_saliencymaps(
    model,
    datamodule,
    output_folder,
    device,
    cache_samples,
    weight,
    parallel,
    saliency_map_algorithm,
    target_class,
    positive_only,
    **_,
) -> None:
    """Generates saliency maps for locations on input images that affected the
    prediction.

    The quality of saliency information depends on the saliency map
    algorithm and trained model.
    """

    # Heavy engine modules are imported lazily so that ``--help`` and CLI
    # start-up stay fast.
    from ..engine.device import DeviceManager
    from ..engine.saliency.generator import run

    logger.info("Output folder: %s", output_folder)
    output_folder.mkdir(parents=True, exist_ok=True)

    device_manager = DeviceManager(device)

    # batch_size must be == 1 for now (underlying code is NOT prepared to
    # treat multiple samples at once).
    datamodule.set_chunk_size(1, 1)
    datamodule.cache_samples = cache_samples
    datamodule.parallel = parallel
    # The datamodule applies the model-specific input transforms; they are
    # taken from the (not yet weight-loaded) model instance given on the CLI.
    datamodule.model_transforms = model.model_transforms

    datamodule.prepare_data()
    datamodule.setup(stage="predict")

    logger.info("Loading checkpoint from `%s`...", weight)
    # ``strict=False`` tolerates checkpoint keys that do not match the model
    # state exactly.
    model = model.load_from_checkpoint(weight, strict=False)

    run(
        model=model,
        datamodule=datamodule,
        device_manager=device_manager,
        saliency_map_algorithms=saliency_map_algorithm,
        target_class=target_class,
        positive_only=positive_only,
        output_folder=output_folder,
    )