# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later

import pathlib

import click

from clapper.click import ResourceOption, verbosity_option
from clapper.logging import setup

from .click import ConfigCommand

# Logger shared by this package: it is named after the top-level package
# (the part of ``__name__`` before the first dot) and configured through
# clapper's ``setup`` helper with a ``LEVEL: message`` line format.
logger = setup(
    __name__.partition(".")[0],
    format="%(levelname)s: %(message)s",
)


@click.command(
    entry_point_group="ptbench.config",
    cls=ConfigCommand,
    epilog="""Examples:

1. Runs prediction on an existing datamodule configuration:

   .. code:: sh

      ptbench predict -vv pasa montgomery --weight=path/to/model.ckpt --output=path/to/predictions.json

2. Enables multi-processing data loading with 6 processes:

   .. code:: sh

      ptbench predict -vv pasa montgomery --parallel=6 --weight=path/to/model.ckpt --output=path/to/predictions.json

""",
)
@click.option(
    "--output",
    "-o",
    help="""Path where to store the JSON predictions for all samples in the
    input datamodule (leading directories are created if they do not
    exist).""",
    required=True,
    default="results",
    cls=ResourceOption,
    type=click.Path(
        file_okay=True, dir_okay=False, writable=True, path_type=pathlib.Path
    ),
)
@click.option(
    "--model",
    "-m",
    help="""A Lightning module instance implementing the network architecture
    (not the weights, necessarily) to be used for prediction.""",
    required=True,
    cls=ResourceOption,
)
# NOTE: no short flag here on purpose.  ``-d`` was previously declared on both
# ``--datamodule`` and ``--device``; click keeps a single option per short
# flag and the later-registered ``--device`` won, so ``-d`` never reached this
# option.  It stays with ``--device`` to preserve existing command lines.
@click.option(
    "--datamodule",
    help="""A Lightning data module that will be asked for prediction data
    loaders. Typically, this includes all configured splits in a datamodule,
    however this is not a requirement.  A datamodule that returns a single
    dataloader for prediction (wrapped in a dictionary) is acceptable.""",
    required=True,
    cls=ResourceOption,
)
@click.option(
    "--batch-size",
    "-b",
    help="""Number of samples in every batch (this parameter affects memory
    requirements for the network).""",
    required=True,
    show_default=True,
    default=1,
    type=click.IntRange(min=1),
    cls=ResourceOption,
)
@click.option(
    "--device",
    "-d",
    help='A string indicating the device to use (e.g. "cpu" or "cuda:0")',
    show_default=True,
    required=True,
    default="cpu",
    cls=ResourceOption,
)
@click.option(
    "--weight",
    "-w",
    help="""Path to a pretrained model file (`.ckpt` extension),
    corresponding to the architecture set with `--model`.""",
    required=True,
    cls=ResourceOption,
    # must be an existing, readable local file: URLs are rejected here
    type=click.Path(exists=True, file_okay=True, dir_okay=False, readable=True),
)
@click.option(
    "--parallel",
    "-P",
    help="""Use multiprocessing for data loading: if set to -1 (default),
    disables multiprocessing data loading.  Set to 0 to enable as many data
    loading instances as there are processing cores available in the system.
    Set to >= 1 to enable that many multiprocessing instances for data
    loading.""",
    type=click.IntRange(min=-1),
    show_default=True,
    required=True,
    default=-1,
    cls=ResourceOption,
)
@verbosity_option(logger=logger, cls=ResourceOption, expose_value=False)
def predict(
    output,
    model,
    datamodule,
    batch_size,
    device,
    weight,
    parallel,
    **_,
) -> None:
    """Runs inference (generates scores) on all input images, using a
    pre-trained model."""

    # Deferred imports: only needed once the command actually runs, not for
    # option parsing or ``--help``.
    import json
    import shutil

    from ..engine.device import DeviceManager
    from ..engine.predictor import run

    # Configure the datamodule for prediction: batching (the second argument
    # to ``set_chunk_size`` is fixed at 1; its meaning is defined by the
    # datamodule), data-loading parallelism, and the model's own transforms
    # so that samples are prepared the way the model expects.
    datamodule.set_chunk_size(batch_size, 1)
    datamodule.parallel = parallel
    datamodule.model_transforms = model.model_transforms

    datamodule.prepare_data()
    datamodule.setup(stage="predict")

    # Rebuild the model from the checkpoint using the class of the configured
    # instance.  ``strict=False`` tolerates state-dict keys that do not match
    # exactly between the checkpoint and the module.
    logger.info(f"Loading checkpoint from `{weight}`...")
    model = type(model).load_from_checkpoint(weight, strict=False)

    predictions = run(model, datamodule, DeviceManager(device))

    # Create leading directories, and keep a copy of any previous predictions
    # file (as ``<name>~``) before it is overwritten below.
    output.parent.mkdir(parents=True, exist_ok=True)
    if output.exists():
        backup = output.parent / (output.name + "~")
        logger.warning(
            f"Output predictions file `{str(output)}` exists - "
            f"backing it up to `{str(backup)}`..."
        )
        shutil.copy(output, backup)

    with output.open("w", encoding="utf-8") as f:
        json.dump(predictions, f, indent=2)
    logger.info(f"Predictions saved to `{str(output)}`")