# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later

"""Defines functionality for the evaluation of predictions."""

import ast
import logging
import os
import re

import matplotlib.pyplot as plt
import numpy
import pandas as pd
import torch

from sklearn import metrics

from ..utils.measure import base_measures, get_centered_maxf1

logger = logging.getLogger(__name__)


def eer_threshold(neg, pos) -> float:
    """Evaluates the EER threshold from negative and positive scores.

    Parameters
    ----------

    neg : typing.Iterable[float]
        Negative (class 0) scores

    pos : typing.Iterable[float]
        Positive (class 1) scores


    Returns
    -------

    threshold : float
        Threshold at which the false-positive rate equals the
        false-negative rate (Equal Error Rate)
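
    Example
    -------

    A minimal usage sketch (the scores are assumed to be
    :py:class:`pandas.Series`, as passed by :py:func:`run`):

    >>> import pandas
    >>> neg = pandas.Series([0.3, 0.5, 0.9])
    >>> pos = pandas.Series([0.4, 0.7, 0.8])
    >>> threshold = eer_threshold(neg, pos)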
    """

    from scipy.interpolate import interp1d
    from scipy.optimize import brentq

    y_predictions = pd.concat((neg, pos))
    y_true = numpy.concatenate((numpy.zeros_like(neg), numpy.ones_like(pos)))

    fpr, tpr, thresholds = metrics.roc_curve(y_true, y_predictions, pos_label=1)

    # the EER is the operating point where FPR equals 1 - TPR (i.e. the FNR)
    eer = brentq(lambda x: 1.0 - x - interp1d(fpr, tpr)(x), 0.0, 1.0)
    return float(interp1d(fpr, thresholds)(eer))


def posneg(pred, gt, threshold):
    """Calculates true and false positives and negatives."""

    # threshold
    binary_pred = torch.gt(pred, threshold)

    # equals and not-equals
    equals = torch.eq(binary_pred, gt).type(torch.uint8)
    notequals = torch.ne(binary_pred, gt).type(torch.uint8)

    # true positives
    tp_tensor = (gt * binary_pred).type(torch.uint8)

    # false positives
    fp_tensor = torch.eq((binary_pred + tp_tensor), 1).type(torch.uint8)

    # true negatives
    tn_tensor = (equals - tp_tensor).type(torch.uint8)

    # false negatives
    fn_tensor = (notequals - fp_tensor).type(torch.uint8)

    return tp_tensor, fp_tensor, tn_tensor, fn_tensor


def sample_measures_for_threshold(pred, gt, threshold):
    """Calculates measures on one single sample, for a specific threshold.

    Parameters
    ----------

    pred : torch.Tensor
        pixel-wise predictions

    gt : torch.Tensor
        ground-truth (annotations)

    threshold : float
        a particular threshold at which to calculate the performance
        measures


    Returns
    -------

    precision: float

    recall: float

    specificity: float

    accuracy: float

    jaccard: float

    f1_score: float
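
    Example
    -------

    A minimal sketch (the returned values are those computed by the
    ``base_measures`` helper):

    >>> import torch
    >>> pred = torch.tensor([0.2, 0.8, 0.6, 0.4])
    >>> gt = torch.tensor([0.0, 1.0, 0.0, 1.0])
    >>> measures = sample_measures_for_threshold(pred, gt, 0.5)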
    """

    tp_tensor, fp_tensor, tn_tensor, fn_tensor = posneg(pred, gt, threshold)

    # calc measures from scalars
    tp_count = torch.sum(tp_tensor).item()
    fp_count = torch.sum(fp_tensor).item()
    tn_count = torch.sum(tn_tensor).item()
    fn_count = torch.sum(fn_tensor).item()
    return base_measures(tp_count, fp_count, tn_count, fn_count)


def run(
    dataset,
    name,
    predictions_folder,
    output_folder=None,
    f1_thresh=None,
    eer_thresh=None,
    steps=1000,
):
    """Runs inference and calculates measures.

    Parameters
    ----------

    dataset : :py:class:`torch.utils.data.Dataset`
        a dataset to iterate on

    name : str
        the local name of this dataset (e.g. ``train`` or ``test``), to be
        used when saving measures files.

    predictions_folder : str
        folder where predictions for the dataset images have been previously
        stored (as a ``predictions.csv`` file with ``likelihood`` and
        ``ground_truth`` columns)

    output_folder : :py:class:`str`, Optional
        folder where to store results.

    f1_thresh : :py:class:`float`, Optional
        threshold at which to report the *a priori* F1-score.  This number
        should come from the training set or a separate validation set.
        Using a test set value may bias your analysis.

    eer_thresh : :py:class:`float`, Optional
        threshold at which the Equal Error Rate was attained.  This number
        should come from the training set or a separate validation set.
        Using a test set value may bias your analysis.  It is only used to
        report the *a priori* EER threshold on the evaluated set.

    steps : :py:class:`int`, Optional
        number of threshold steps to consider when evaluating thresholds.


    Returns
    -------

    f1_threshold : float
        Threshold to achieve the highest possible F1-score for this dataset

    eer_threshold : float
        Threshold achieving Equal Error Rate for this dataset
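
    Example
    -------

    A hypothetical call (the dataset and folder names are illustrative; a
    ``predictions.csv`` file must have been generated beforehand):

    >>> maxf1_t, eer_t = run(  # doctest: +SKIP
    ...     dataset,
    ...     "test",
    ...     "results/predictions",
    ...     output_folder="results/evaluation",
    ... )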
    """

    predictions_path = os.path.join(predictions_folder, name, "predictions.csv")
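    # fall back to the path given by the user if the per-subset CSV is absent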
    if not os.path.exists(predictions_path):
        predictions_path = predictions_folder

    # Load predictions: the "likelihood" and "ground_truth" columns hold
    # whitespace-separated arrays (e.g. "[0.1 0.2]"); normalize them to
    # comma-separated literals and parse them safely
    pred_data = pd.read_csv(predictions_path)
    pred = torch.Tensor(
        [
            ast.literal_eval(
                re.sub(" +", " ", x.replace("\n", "")).replace(" ", ",")
            )
            for x in pred_data["likelihood"].values
        ]
    ).double()
    gt = torch.Tensor(
        [
            ast.literal_eval(
                re.sub(" +", " ", x.replace("\n", "")).replace(" ", ",")
            )
            for x in pred_data["ground_truth"].values
        ]
    ).double()

    if pred.ndim > 1 and pred.shape[1] == 1 and gt.shape[1] == 1:
        pred = torch.flatten(pred)
        gt = torch.flatten(gt)

    # Multiclass evaluation: only the AUC is reported
    if pred.ndim > 1:
        auc = metrics.roc_auc_score(gt, pred)
        logger.info("Evaluating multiclass classification")
        logger.info(f"AUC: {auc}")
        logger.info("F1 and EER are not implemented for multiclass")

        return None, None

    # Binary evaluation: store the (flattened) scores back in the data frame
    pred_data["likelihood"] = pred
    pred_data["ground_truth"] = gt

    # Generate measures for each threshold
    step_size = 1.0 / steps
    data = [
        (index, threshold) + sample_measures_for_threshold(pred, gt, threshold)
        for index, threshold in enumerate(numpy.arange(0.0, 1.0, step_size))
    ]

    data_df = pd.DataFrame(
        data,
        columns=(
            "index",
            "threshold",
            "precision",
            "recall",
            "specificity",
            "accuracy",
            "jaccard",
            "f1_score",
        ),
    )
    data_df = data_df.set_index("index")

    # Save evaluation csv
    if output_folder is not None:
        fullpath = os.path.join(output_folder, f"{name}.csv")
        logger.info(f"Saving {fullpath}...")
        os.makedirs(os.path.dirname(fullpath), exist_ok=True)
        data_df.to_csv(fullpath)

    # Find max F1 score
    f1_scores = numpy.asarray(data_df["f1_score"])
    thresholds = numpy.asarray(data_df["threshold"])

    maxf1, maxf1_threshold = get_centered_maxf1(f1_scores, thresholds)

    logger.info(
        f"Maximum F1-score of {maxf1:.5f}, achieved at "
        f"threshold {maxf1_threshold:.3f} (chosen *a posteriori*)"
    )

    # Find EER
    neg_gt = pred_data.loc[pred_data.loc[:, "ground_truth"] == 0, :]
    pos_gt = pred_data.loc[pred_data.loc[:, "ground_truth"] == 1, :]
    post_eer_threshold = eer_threshold(
        neg_gt["likelihood"], pos_gt["likelihood"]
    )

    logger.info(
        f"Equal error rate achieved at "
        f"threshold {post_eer_threshold:.3f} (chosen *a posteriori*)"
    )

    # Save score table
    if output_folder is not None:
        fig, axes = plt.subplots(1)
        fig.tight_layout(pad=3.0)

        # Names and bounds
        axes.set_xlabel("Score")
        axes.set_ylabel("Normalized counts")
        axes.set_xlim(0.0, 1.0)

        neg_weights = numpy.ones_like(neg_gt["likelihood"]) / len(
            pred_data["likelihood"]
        )
        pos_weights = numpy.ones_like(pos_gt["likelihood"]) / len(
            pred_data["likelihood"]
        )

        axes.hist(
            [neg_gt["likelihood"], pos_gt["likelihood"]],
            weights=[neg_weights, pos_weights],
            bins=100,
            color=["tab:blue", "tab:orange"],
            label=["Negatives", "Positives"],
        )
        axes.legend(prop={"size": 10}, loc="upper center")
        axes.set_title(f"Score table for {name} subset")

        # hide the top and right spines and offset the left spine slightly
        axes.spines["right"].set_visible(False)
        axes.spines["top"].set_visible(False)
        axes.spines["left"].set_position(("data", -0.015))

        fullpath = os.path.join(output_folder, f"{name}_score_table.pdf")
        fig.savefig(fullpath)
        plt.close(fig)

    if f1_thresh is not None and eer_thresh is not None:
        # find the closest threshold available on the evaluated grid
        index = min(int(round(steps * f1_thresh)), steps - 1)
        f1_a_priori = data_df["f1_score"][index]
        actual_threshold = data_df["threshold"][index]

        logger.info(
            f"F1-score of {f1_a_priori:.5f}, at threshold "
            f"{actual_threshold:.3f} (chosen *a priori*)"
        )

        # Print the a priori EER threshold
        logger.info(f"Equal error rate (chosen *a priori*) {eer_thresh:.3f}")

    return maxf1_threshold, post_eer_threshold