# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later

import os
import pathlib

import click

from clapper.click import ConfigCommand, ResourceOption, verbosity_option
from clapper.logging import setup

# Logger configured for the top-level package (the part of ``__name__``
# before the first dot), emitting "LEVEL: message" lines.
logger = setup(
    __name__.partition(".")[0],
    format="%(levelname)s: %(message)s",
)


def _validate_folders(input_folder, output_folder) -> None:
    """Reject output folders that equal, or live inside, the input folder.

    Paths are resolved before comparison so that different spellings of the
    same location (``./a`` vs ``a``, symlinks) are detected. The check
    compares path components rather than string prefixes, so a sibling such
    as ``gradcam-viz`` is not confused with a child of ``gradcam``.

    Raises
    ------
    click.BadParameter
        If the output folder is the input folder or one of its descendants.
    """
    source = pathlib.Path(input_folder).resolve()
    target = pathlib.Path(output_folder).resolve()

    if target == source:
        raise click.BadParameter(
            "Output folder must not be the same as the input folder.",
            param_hint="--output-folder",
        )

    if source in target.parents:
        raise click.BadParameter(
            "Output folder must not be a subdirectory of the input folder.",
            param_hint="--output-folder",
        )


@click.command(
    entry_point_group="ptbench.config",
    cls=ConfigCommand,
    epilog="""Examples:

\b
    1. Generates visualizations in form of heatmaps from existing saliency maps for a dataset configuration:

       .. code:: sh

          ptbench visualize -vv pasa tbx11k-v1-healthy-vs-atb --input-folder=parent_folder/gradcam/ --road-path=parent_folder/gradcam/ --output-folder=path/to/visualizations

""",
)
@click.option(
    "--model",
    "-m",
    help="A lightning module instance implementing the network to be used for applying the necessary data transformations.",
    required=True,
    cls=ResourceOption,
)
@click.option(
    "--datamodule",
    "-d",
    help="A lightning data module containing the training, validation and test sets.",
    required=True,
    cls=ResourceOption,
)
@click.option(
    "--input-folder",
    "-i",
    help="Path to the folder containing the saliency maps for a specific visualization type.",
    required=True,
    type=click.Path(
        file_okay=False,
        dir_okay=True,
        writable=True,
        path_type=pathlib.Path,
    ),
    default="visualizations",
    cls=ResourceOption,
)
@click.option(
    "--output-folder",
    "-o",
    help="Path where to store the visualizations (created if it does not exist)",
    required=True,
    type=click.Path(
        file_okay=False,
        dir_okay=True,
        writable=True,
        path_type=pathlib.Path,
    ),
    default="visualizations",
    cls=ResourceOption,
)
@click.option(
    "--visualize-groundtruth",
    "-vgt",
    help="If set, visualizations for ground truth labels will be generated. Only works for datasets with bounding boxes.",
    is_flag=True,
    default=False,
    cls=ResourceOption,
)
@click.option(
    "--road-path",
    "-r",
    help="If the path to the previously calculated ROAD scores is provided, MoRF, LeRF, and combined ROAD scores will be visualized on each image.",
    required=False,
    default=None,
    cls=ResourceOption,
    type=click.Path(),
)
@click.option(
    "--visualize-detected-bbox",
    "-vb",
    help="If set, largest component bounding boxes will be visualized on each image. Only works if the bounding boxes have been previously generated.",
    is_flag=True,
    default=False,
    cls=ResourceOption,
)
@click.option(
    "--threshold",
    "-t",
    help="Only activations above this threshold will be visualized.",
    show_default=True,
    required=True,
    default=0.0,
    type=click.FloatRange(min=0, max=1),
    cls=ResourceOption,
)
@verbosity_option(logger=logger, cls=ResourceOption, expose_value=False)
def visualize(
    model,
    datamodule,
    input_folder,
    output_folder,
    visualize_groundtruth,
    road_path,
    visualize_detected_bbox,
    threshold,
    **_,
) -> None:
    """Generates heatmaps for input CXRs based on existing saliency maps."""

    # Imported lazily so that ``--help`` does not pay for the engine imports.
    from ..engine.visualizer import run
    from .utils import save_sh_command

    # Values may come from a config resource rather than the command line,
    # so normalise to ``Path`` before using the ``/`` operator below.
    input_folder = pathlib.Path(input_folder)
    output_folder = pathlib.Path(output_folder)

    # Real CLI errors (not ``assert``, which ``python -O`` strips).
    _validate_folders(input_folder, output_folder)

    logger.info(f"Output folder: {output_folder}")
    os.makedirs(output_folder, exist_ok=True)

    save_sh_command(output_folder / "command.sh")

    # One sample per chunk and keep the trailing partial batch, so every
    # sample of every split gets a visualization.
    datamodule.set_chunk_size(1, 1)
    datamodule.drop_incomplete_batch = False
    datamodule.model_transforms = model.model_transforms

    datamodule.prepare_data()
    datamodule.setup(stage="predict")

    # ``predict_dataloader`` returns a mapping of split name -> data loader.
    dataloaders = datamodule.predict_dataloader()

    for split_name, loader in dataloaders.items():
        run(
            data_loader=loader,
            input_folder=input_folder,
            output_folder=output_folder,
            dataset_split_name=split_name,
            visualize_groundtruth=visualize_groundtruth,
            road_path=road_path,
            visualize_detected_bbox=visualize_detected_bbox,
            threshold=threshold,
        )