# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later

"""Defines functionality for the evaluation of predictions."""

import logging
import os
import re

from typing import Iterable, Optional

import matplotlib.pyplot as plt
import numpy
import pandas as pd
import torch

from sklearn import metrics

from ..utils.measure import base_measures, get_centered_maxf1

logger = logging.getLogger(__name__)


def eer_threshold(neg: Iterable[float], pos: Iterable[float]) -> float:
    """Evaluates the EER threshold from negative and positive scores.

    Parameters
    ----------

    neg :
        Negative scores (e.g. a :py:class:`pandas.Series`).

    pos :
        Positive scores (e.g. a :py:class:`pandas.Series`).


    Returns
    -------

    float
        The EER threshold value.
    """
    from scipy.interpolate import interp1d
    from scipy.optimize import brentq

    y_predictions = pd.concat((neg, pos))
    y_true = numpy.concatenate((numpy.zeros_like(neg), numpy.ones_like(pos)))

    fpr, tpr, thresholds = metrics.roc_curve(y_true, y_predictions, pos_label=1)

    # the EER is the operating point where the false-positive rate equals the
    # false-negative rate (1 - TPR); find it as a root on the interpolated ROC
    # curve, then map it back to the corresponding score threshold
    eer = brentq(lambda x: 1.0 - x - interp1d(fpr, tpr)(x), 0.0, 1.0)

    return interp1d(fpr, thresholds)(eer)


def posneg(
    pred, gt, threshold
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Calculates true and false positives and negatives.

    Parameters
    ----------

    pred :
        Pixel-wise predictions.

    gt :
        Ground-truth (annotations).

    threshold :
        A particular threshold in which to calculate the performance
        measures.


    Returns
    -------

    tp_tensor:
        The true positive values.

    fp_tensor:
        The false positive values.

    tn_tensor:
        The true negative values.

    fn_tensor:
        The false negative values.
    """
    # threshold the predictions (element-wise "greater than")
    binary_pred = torch.gt(pred, threshold)

    # element-wise agreement and disagreement with the ground-truth
    equals = torch.eq(binary_pred, gt).type(torch.uint8)
    notequals = torch.ne(binary_pred, gt).type(torch.uint8)

    # true positives: predicted positive and annotated positive
    tp_tensor = (gt * binary_pred).type(torch.uint8)

    # false positives: predicted positive, but not a true positive
    fp_tensor = torch.eq((binary_pred + tp_tensor), 1).type(torch.uint8)

    # true negatives: correct predictions that are not true positives
    tn_tensor = (equals - tp_tensor).type(torch.uint8)

    # false negatives: wrong predictions that are not false positives
    fn_tensor = (notequals - fp_tensor).type(torch.uint8)

    return tp_tensor, fp_tensor, tn_tensor, fn_tensor
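
# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original API): how ``eer_threshold``
# is meant to be used.  The scores below are made-up; in practice they would
# be a model's outputs for negative and positive samples.  This helper is
# never called by the module itself.
# ---------------------------------------------------------------------------
def _example_eer_threshold() -> float:
    """Runs :py:func:`eer_threshold` on a small set of synthetic scores."""
    neg = pd.Series([0.1, 0.2, 0.25, 0.3, 0.45, 0.6])  # hypothetical negatives
    pos = pd.Series([0.35, 0.5, 0.55, 0.7, 0.8, 0.9])  # hypothetical positives
    threshold = eer_threshold(neg, pos)
    # the threshold lies in the region where the two score distributions overlap
    return float(threshold)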
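
# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original API): ``posneg`` splits a
# thresholded prediction into the four entries of the confusion matrix.  The
# tensors below are made-up 2x2 "images".  This helper is never called by the
# module itself.
# ---------------------------------------------------------------------------
def _example_posneg() -> tuple[int, ...]:
    """Decomposes a toy prediction into TP/FP/TN/FN pixel counts."""
    pred = torch.tensor([[0.9, 0.2], [0.7, 0.1]])  # hypothetical probabilities
    gt = torch.tensor([[1.0, 0.0], [0.0, 0.0]])  # hypothetical annotations
    tp, fp, tn, fn = posneg(pred, gt, threshold=0.5)
    counts = tuple(int(t.sum().item()) for t in (tp, fp, tn, fn))
    # the four counts always add up to the total number of pixels (here, 4)
    assert sum(counts) == pred.numel()
    return counts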
def sample_measures_for_threshold(
    pred: torch.Tensor, gt: torch.Tensor, threshold: float
) -> tuple[float, float, float, float, float, float]:
    """Calculates measures on one single sample, for a specific threshold.

    Parameters
    ----------

    pred :
        Pixel-wise predictions.

    gt :
        Ground-truth (annotations).

    threshold :
        A particular threshold in which to calculate the performance
        measures.


    Returns
    -------

    precision : float
        P, AKA positive predictive value (PPV).  It corresponds
        arithmetically to ``tp/(tp+fp)``.  In the case ``tp+fp == 0``, this
        function returns zero for precision.

    recall : float
        R, AKA sensitivity, hit rate, or true positive rate (TPR).  It
        corresponds arithmetically to ``tp/(tp+fn)``.  In the special case
        where ``tp+fn == 0``, this function returns zero for recall.

    specificity : float
        S, AKA selectivity or true negative rate (TNR).  It corresponds
        arithmetically to ``tn/(tn+fp)``.  In the special case where
        ``tn+fp == 0``, this function returns zero for specificity.

    accuracy : float
        A, see `Accuracy
        <https://en.wikipedia.org/wiki/Evaluation_of_binary_classifiers>`_.
        It is the proportion of correct predictions (both true positives and
        true negatives) among the total number of pixels examined.  It
        corresponds arithmetically to ``(tp+tn)/(tp+tn+fp+fn)``.  This
        measure includes both true negatives and true positives in the
        numerator, which makes it sensitive to data or regions without
        annotations.

    jaccard : float
        J, see `Jaccard Index or Similarity
        <https://en.wikipedia.org/wiki/Jaccard_index>`_.  It corresponds
        arithmetically to ``tp/(tp+fp+fn)``.  In the special case where
        ``tp+fp+fn == 0``, this function returns zero for the Jaccard index.
        The Jaccard index depends on a TP-only numerator, similarly to the
        F1 score.  For regions where there are no annotations, the Jaccard
        index will always be zero, irrespective of the model output.
        Accuracy may be a better proxy if one needs to consider the true
        absence of annotations in a region as part of the measure.

    f1_score : float
        F1, see `F1-score <https://en.wikipedia.org/wiki/F1_score>`_.  It
        corresponds arithmetically to ``2*P*R/(P+R)`` or
        ``2*tp/(2*tp+fp+fn)``.  In the special case where
        ``P+R == (2*tp+fp+fn) == 0``, this function returns zero for the
        F1-score.  The F1 or Dice score depends on a TP-only numerator,
        similarly to the Jaccard index.  For regions where there are no
        annotations, the F1-score will always be zero, irrespective of the
        model output.  Accuracy may be a better proxy if one needs to
        consider the true absence of annotations in a region as part of the
        measure.
    """
    tp_tensor, fp_tensor, tn_tensor, fn_tensor = posneg(pred, gt, threshold)

    # reduce the pixel-wise tensors to scalar counts
    tp_count = torch.sum(tp_tensor).item()
    fp_count = torch.sum(fp_tensor).item()
    tn_count = torch.sum(tn_tensor).item()
    fn_count = torch.sum(fn_tensor).item()

    return base_measures(tp_count, fp_count, tn_count, fn_count)
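
# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original API): the formulas documented
# above, written out for one set of made-up counts.  ``base_measures``
# (imported from ``..utils.measure``) is assumed to implement the same
# arithmetic, including the "empty denominator yields zero" convention.
# ---------------------------------------------------------------------------
def _example_measures_by_hand() -> dict[str, float]:
    """Computes the documented measures from hypothetical TP/FP/TN/FN counts."""
    tp, fp, tn, fn = 80, 10, 100, 20  # hypothetical pixel counts

    def _ratio(num: float, den: float) -> float:
        # convention used by the docstrings above: empty denominator -> 0.0
        return num / den if den else 0.0

    return {
        "precision": _ratio(tp, tp + fp),  # 80/90
        "recall": _ratio(tp, tp + fn),  # 80/100
        "specificity": _ratio(tn, tn + fp),  # 100/110
        "accuracy": _ratio(tp + tn, tp + tn + fp + fn),  # 180/210
        "jaccard": _ratio(tp, tp + fp + fn),  # 80/110
        "f1_score": _ratio(2 * tp, 2 * tp + fp + fn),  # 160/190
    }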
def run(
    dataset,
    name: str,
    predictions_folder: str,
    output_folder: Optional[str] = None,
    f1_thresh: Optional[float] = None,
    eer_thresh: Optional[float] = None,
    steps: int = 1000,
):
    """Runs inference and calculates measures.

    Parameters
    ----------

    dataset : :py:class:`torch.utils.data.Dataset`
        a dataset to iterate on

    name:
        the local name of this dataset (e.g. ``train``, or ``test``), to be
        used when saving measures files.

    predictions_folder:
        folder where predictions for the dataset images have been previously
        stored

    output_folder:
        folder where to store results.

    f1_thresh:
        This number should come from the training set or a separate
        validation set.  Using a test set value may bias your analysis.
        This number is also used to print the a priori F1-score on the
        evaluated set.

    eer_thresh:
        This number should come from the training set or a separate
        validation set.  Using a test set value may bias your analysis.
        This number is used to print the a priori EER.

    steps:
        number of threshold steps to consider when evaluating thresholds.


    Returns
    -------

    maxf1_threshold : float
        Threshold to achieve the highest possible F1-score for this dataset.

    post_eer_threshold : float
        Threshold achieving Equal Error Rate for this dataset.
    """
    predictions_path = os.path.join(
        predictions_folder, f"predictions_{name}", "predictions.csv"
    )

    if not os.path.exists(predictions_path):
        predictions_path = predictions_folder

    # Load predictions: the "likelihood" and "ground_truth" columns contain
    # bracketed, space-separated arrays stored as strings, which are
    # normalized and evaluated back into Python lists below
    pred_data = pd.read_csv(predictions_path)
    pred = torch.Tensor(
        [
            eval(re.sub(" +", " ", x.replace("\n", "")).replace(" ", ","))
            for x in pred_data["likelihood"].values
        ]
    ).double()
    gt = torch.Tensor(
        [
            eval(re.sub(" +", " ", x.replace("\n", "")).replace(" ", ","))
            for x in pred_data["ground_truth"].values
        ]
    ).double()

    if pred.shape[1] == 1 and gt.shape[1] == 1:
        pred = torch.flatten(pred)
        gt = torch.flatten(gt)

        pred_data["likelihood"] = pred
        pred_data["ground_truth"] = gt

    # Multiclass f1 score computation
    if pred.ndim > 1:
        auc = metrics.roc_auc_score(gt, pred)
        logger.info("Evaluating multiclass classification")
        logger.info(f"AUC: {auc}")
        logger.info("F1 and EER are not implemented for multiclass")

        return None, None

    # Generate measures for each threshold
    step_size = 1.0 / steps
    data = [
        (index, threshold) + sample_measures_for_threshold(pred, gt, threshold)
        for index, threshold in enumerate(numpy.arange(0.0, 1.0, step_size))
    ]

    data_df = pd.DataFrame(
        data,
        columns=(
            "index",
            "threshold",
            "precision",
            "recall",
            "specificity",
            "accuracy",
            "jaccard",
            "f1_score",
        ),
    )
    data_df = data_df.set_index("index")

    # Save evaluation csv
    if output_folder is not None:
        fullpath = os.path.join(output_folder, f"{name}.csv")
        logger.info(f"Saving {fullpath}...")
        os.makedirs(os.path.dirname(fullpath), exist_ok=True)
        data_df.to_csv(fullpath)

    # Find max F1 score
    f1_scores = numpy.asarray(data_df["f1_score"])
    thresholds = numpy.asarray(data_df["threshold"])

    maxf1, maxf1_threshold = get_centered_maxf1(f1_scores, thresholds)

    logger.info(
        f"Maximum F1-score of {maxf1:.5f}, achieved at "
        f"threshold {maxf1_threshold:.3f} (chosen *a posteriori*)"
    )

    # Find EER
    neg_gt = pred_data.loc[pred_data.loc[:, "ground_truth"] == 0, :]
    pos_gt = pred_data.loc[pred_data.loc[:, "ground_truth"] == 1, :]

    post_eer_threshold = eer_threshold(
        neg_gt["likelihood"], pos_gt["likelihood"]
    )

    logger.info(
        f"Equal error rate achieved at "
        f"threshold {post_eer_threshold:.3f} (chosen *a posteriori*)"
    )

    # Save score table
    if output_folder is not None:
        fig, axes = plt.subplots(1)
        fig.tight_layout(pad=3.0)

        # Names and bounds
        axes.set_xlabel("Score")
        axes.set_ylabel("Normalized counts")
        axes.set_xlim(0.0, 1.0)

        neg_weights = numpy.ones_like(neg_gt["likelihood"]) / len(
            pred_data["likelihood"]
        )
        pos_weights = numpy.ones_like(pos_gt["likelihood"]) / len(
            pred_data["likelihood"]
        )

        axes.hist(
            [neg_gt["likelihood"], pos_gt["likelihood"]],
            weights=[neg_weights, pos_weights],
            bins=100,
            color=["tab:blue", "tab:orange"],
            label=["Negatives", "Positives"],
        )
        axes.legend(prop={"size": 10}, loc="upper center")
        axes.set_title(f"Score table for {name} subset")

        # hide the top and right spines, and nudge the left spine outwards
        axes.spines["right"].set_visible(False)
        axes.spines["top"].set_visible(False)
        axes.spines["left"].set_position(("data", -0.015))

        fullpath = os.path.join(output_folder, f"{name}_score_table.pdf")
        fig.savefig(fullpath)

    if f1_thresh is not None and eer_thresh is not None:
        # get the closest possible threshold we have
        index = int(round(steps * f1_thresh))
        f1_a_priori = data_df["f1_score"][index]
        actual_threshold = data_df["threshold"][index]

        logger.info(
            f"F1-score of {f1_a_priori:.5f}, at threshold "
f"{actual_threshold:.3f} (chosen *a priori*)" ) # Print the a priori EER threshold logger.info(f"Equal error rate (chosen *a priori*) {eer_thresh:.3f}") return maxf1_threshold, post_eer_threshold