# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later

import click

from clapper.click import ConfigCommand, ResourceOption, verbosity_option
from clapper.logging import setup

# Logger configured through clapper for this module's top-level package (the
# first dotted component of ``__name__``); records are printed as
# "<LEVEL>: <message>".
logger = setup(__name__.partition(".")[0], format="%(levelname)s: %(message)s")


@click.command(
    entry_point_group="ptbench.config",
    cls=ConfigCommand,
    epilog="""Examples:

\b
    1. Runs prediction on an existing dataset configuration:

       .. code:: sh

          ptbench predict -vv pasa montgomery --weight=path/to/model_final.ckpt --output-folder=path/to/predictions

""",
)
@click.option(
    "--output-folder",
    "-o",
    help="Path where to store the predictions (created if it does not exist)",
    required=True,
    default="results",
    cls=ResourceOption,
    type=click.Path(),
)
@click.option(
    "--model",
    "-m",
    help="A torch.nn.Module instance implementing the network to be evaluated",
    required=True,
    cls=ResourceOption,
)
# NOTE: no short flag here on purpose. ``-d`` belongs to ``--device`` below;
# when both options declared it, click silently routed ``-d`` to ``--device``
# (the later registration wins), so ``-d`` never selected a datamodule.
@click.option(
    "--datamodule",
    help="A torch.utils.data.dataset.Dataset instance implementing a dataset "
    "to be used for running prediction, possibly including all pre-processing "
    "pipelines required or, optionally, a dictionary mapping string keys to "
    "torch.utils.data.dataset.Dataset instances.  All keys that do not start "
    "with an underscore (_) will be processed.",
    required=True,
    cls=ResourceOption,
)
@click.option(
    "--batch-size",
    "-b",
    help="Number of samples in every batch (this parameter affects memory requirements for the network)",
    required=True,
    show_default=True,
    default=1,
    type=click.IntRange(min=1),
    cls=ResourceOption,
)
@click.option(
    "--device",
    "-d",
    help='A string indicating the device to use (e.g. "cpu" or "cuda:0")',
    show_default=True,
    required=True,
    default="cpu",
    cls=ResourceOption,
)
@click.option(
    "--weight",
    "-w",
    help="Path or URL to pretrained model file (.ckpt extension)",
    required=True,
    cls=ResourceOption,
)
@verbosity_option(logger=logger, cls=ResourceOption, expose_value=False)
def predict(
    output_folder,
    model,
    datamodule,
    batch_size,
    device,
    weight,
    **_,
) -> None:
    """Predicts Tuberculosis presence (probabilities) on input images."""

    # (The docstring above doubles as the click help text, so further
    # documentation lives in comments.)
    #
    # Steps:
    #   1. configure and set up ``datamodule`` for the "predict" stage,
    #      batching with ``batch_size`` and using the model's transforms;
    #   2. restore ``model`` from the checkpoint at ``weight``;
    #   3. if the model is named "logistic_regression", plot its first
    #      parameter tensor to ``<output_folder>/LogReg_Weights.pdf``;
    #   4. hand everything to ``engine.predictor.run`` on ``device``.

    import os

    import numpy as np

    from matplotlib.backends.backend_pdf import PdfPages

    from ..engine.device import DeviceManager
    from ..engine.predictor import run
    from ..utils.plot import relevance_analysis_plot

    datamodule.set_chunk_size(batch_size, 1)
    datamodule.model_transforms = model.model_transforms

    datamodule.prepare_data()
    datamodule.setup(stage="predict")

    logger.info(f"Loading checkpoint from {weight}")
    model = model.load_from_checkpoint(weight, strict=False)

    # Logistic regressor weights
    if model.name == "logistic_regression":
        logger.info("Logistic regression identified: saving model weights")
        # Only the first registered parameter is plotted. Presumably this is
        # the linear layer's weight vector (any bias would come after it).
        # TODO confirm against the model definition.
        first_param = next(iter(model.parameters()), None)
        if first_param is None:
            # Previously this path crashed with UnboundLocalError; skip the
            # optional plot instead and still run prediction.
            logger.warning(
                "Logistic regression model exposes no parameters: "
                "skipping weights plot"
            )
        else:
            # detach()/cpu() make the numpy conversion work even when the
            # checkpoint restored parameters onto an accelerator device.
            model_weights = np.array(first_param.detach().cpu().reshape(-1))
            filepath = os.path.join(output_folder, "LogReg_Weights.pdf")
            logger.info(f"Creating and saving weights plot at {filepath}...")
            folder = os.path.dirname(filepath)
            if folder:  # os.makedirs("") raises; "" means current directory
                os.makedirs(folder, exist_ok=True)
            # The context manager guarantees the PDF is closed (and
            # finalized) even if plotting or saving raises.
            with PdfPages(filepath) as pdf:
                pdf.savefig(
                    relevance_analysis_plot(
                        model_weights, title="LogReg model weights"
                    )
                )

    _ = run(model, datamodule, DeviceManager(device), output_folder)