From 0bed6e1d2c6ebb666de3232456db64b604dbc7a3 Mon Sep 17 00:00:00 2001
From: dcarron <daniel.carron@idiap.ch>
Date: Thu, 13 Jul 2023 14:16:40 +0200
Subject: [PATCH] Functional evaluation scripts
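
Predictions for each dataloader are now written to a dedicated sub-folder,
and the evaluator looks them up in the same place.  For a split named
``test`` and an output folder ``results`` (illustrative names), both sides
now resolve the same path:

    os.path.join("results", "predictions_test", "predictions.csv")
    # -> results/predictions_test/predictions.csv, replacing the flat
    #    results/predictions_test_set.csv layout used previously

If this path does not exist, the evaluator falls back to using
``predictions_folder`` as given, as before.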
---
 src/ptbench/engine/callbacks.py |   2 +-
 src/ptbench/engine/evaluator.py | 154 +++++++++++++++++++++++---------
 src/ptbench/scripts/evaluate.py |  83 +++++++++++------
 3 files changed, 169 insertions(+), 70 deletions(-)

diff --git a/src/ptbench/engine/callbacks.py b/src/ptbench/engine/callbacks.py
index 5adada7b..031acaae 100644
--- a/src/ptbench/engine/callbacks.py
+++ b/src/ptbench/engine/callbacks.py
@@ -407,7 +407,7 @@ class PredictionsWriter(lightning.pytorch.callbacks.BasePredictionWriter):
         ]
 
         logfile = os.path.join(
-            self.output_dir, f"predictions_{dataloader_name}_set.csv"
+            self.output_dir, f"predictions_{dataloader_name}", "predictions.csv"
         )
 
         os.makedirs(os.path.dirname(logfile), exist_ok=True)
diff --git a/src/ptbench/engine/evaluator.py b/src/ptbench/engine/evaluator.py
index acbf7f36..c157d5fe 100644
--- a/src/ptbench/engine/evaluator.py
+++ b/src/ptbench/engine/evaluator.py
@@ -8,6 +8,8 @@ import logging
 import os
 import re
 
+from typing import Iterable, Optional
+
 import matplotlib.pyplot as plt
 import numpy
 import pandas as pd
@@ -20,22 +22,22 @@ from ..utils.measure import base_measures, get_centered_maxf1
 logger = logging.getLogger(__name__)
 
 
-def eer_threshold(neg, pos) -> float:
+def eer_threshold(neg: Iterable[float], pos: Iterable[float]) -> float:
     """Evaluates the EER threshold from negative and positive scores.
 
     Parameters
     ----------
 
-    neg : typing.Iterable[float]
+    neg :
         Negative scores
 
-    pos : typing.Iterable[float]
+    pos :
         Positive scores
 
-    Returns:
+    Returns
+    -------
 
-        Threshold
+        The EER threshold value.
     """
     from scipy.interpolate import interp1d
     from scipy.optimize import brentq
@@ -49,8 +51,40 @@ def eer_threshold(neg, pos) -> float:
     return interp1d(fpr, thresholds)(eer)
 
 
-def posneg(pred, gt, threshold):
-    """Calculates true and false positives and negatives."""
+def posneg(
+    pred, gt, threshold
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+    """Calculates true and false positives and negatives.
+
+    Parameters
+    ----------
+
+    pred :
+        Pixel-wise predictions.
+
+    gt :
+        Ground-truth (annotations).
+
+    threshold :
+        A particular threshold at which to calculate the performance
+        measures.
+
+    Returns
+    -------
+
+    tp_tensor :
+        The true positive values.
+
+    fp_tensor :
+        The false positive values.
+
+    tn_tensor :
+        The true negative values.
+
+    fn_tensor :
+        The false negative values.
+    """
+
     # threshold
     binary_pred = torch.gt(pred, threshold)
@@ -73,38 +107,74 @@ def posneg(pred, gt, threshold):
     return tp_tensor, fp_tensor, tn_tensor, fn_tensor
 
 
-def sample_measures_for_threshold(pred, gt, threshold):
+def sample_measures_for_threshold(
+    pred: torch.Tensor, gt: torch.Tensor, threshold: float
+) -> tuple[float, float, float, float, float, float]:
     """Calculates measures on one single sample, for a specific threshold.
 
     Parameters
     ----------
 
-    pred : torch.Tensor
-        pixel-wise predictions
+    pred :
+        Pixel-wise predictions.
 
-    gt : torch.Tensor
-        ground-truth (annotations)
+    gt :
+        Ground-truth (annotations).
 
-    threshold : float
-        a particular threshold in which to calculate the performance
-        measures
+    threshold :
+        A particular threshold at which to calculate the performance
+        measures.
 
     Returns
     -------
 
-    precision: float
-
-    recall: float
-
-    specificity: float
-
-    accuracy: float
-
-    jaccard: float
-
-    f1_score: float
+    precision : float
+        P, AKA positive predictive value (PPV).  It corresponds arithmetically
+        to ``tp/(tp+fp)``.  In the case ``tp+fp == 0``, this function returns
+        zero for precision.
+
+    recall : float
+        R, AKA sensitivity, hit rate, or true positive rate (TPR).  It
+        corresponds arithmetically to ``tp/(tp+fn)``.  In the special case
+        where ``tp+fn == 0``, this function returns zero for recall.
+
+    specificity : float
+        S, AKA selectivity or true negative rate (TNR).  It corresponds
+        arithmetically to ``tn/(tn+fp)``.  In the special case where
+        ``tn+fp == 0``, this function returns zero for specificity.
+
+    accuracy : float
+        A, see `Accuracy
+        <https://en.wikipedia.org/wiki/Evaluation_of_binary_classifiers>`_.
+        It is the proportion of correct predictions (both true positives and
+        true negatives) among the total number of pixels examined.  It
+        corresponds arithmetically to ``(tp+tn)/(tp+tn+fp+fn)``.  This
+        measure includes both true negatives and true positives in the
+        numerator, which makes it sensitive to data or regions without
+        annotations.
+
+    jaccard : float
+        J, see `Jaccard Index or Similarity
+        <https://en.wikipedia.org/wiki/Jaccard_index>`_.  It corresponds
+        arithmetically to ``tp/(tp+fp+fn)``.  In the special case where
+        ``tp+fp+fn == 0``, this function returns zero for the Jaccard index.
+        The Jaccard index depends on a TP-only numerator, similarly to the F1
+        score.  For regions where there are no annotations, the Jaccard index
+        will always be zero, irrespective of the model output.  Accuracy may
+        be a better proxy if one needs to consider the true absence of
+        annotations in a region as part of the measure.
+
+    f1_score : float
+        F1, see `F1-score <https://en.wikipedia.org/wiki/F1_score>`_.  It
+        corresponds arithmetically to ``2*P*R/(P+R)`` or ``2*tp/(2*tp+fp+fn)``.
+        In the special case where ``P+R == (2*tp+fp+fn) == 0``, this function
+        returns zero for the F1 score.  The F1 or Dice score depends on a
+        TP-only numerator, similarly to the Jaccard index.  For regions where
+        there are no annotations, the F1-score will always be zero,
+        irrespective of the model output.  Accuracy may be a better proxy if
+        one needs to consider the true absence of annotations in a region as
+        part of the measure.
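+
+    Example (an illustrative sketch, assuming ``gt`` is a binary mask; the
+    numbers follow directly from the definitions above)::
+
+        pred = torch.tensor([0.1, 0.9, 0.8, 0.3])
+        gt = torch.tensor([0, 1, 0, 1])
+        # thresholding at 0.5 gives tp=1, fp=1, tn=1, fn=1, so precision,
+        # recall, specificity, accuracy and F1 all equal 0.5, while the
+        # Jaccard index equals 1/3
+        measures = sample_measures_for_threshold(pred, gt, 0.5)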
     """
+
     tp_tensor, fp_tensor, tn_tensor, fn_tensor = posneg(pred, gt, threshold)
 
     # calc measures from scalars
@@ -117,12 +187,12 @@ def sample_measures_for_threshold(pred, gt, threshold):
 
 def run(
     dataset,
-    name,
-    predictions_folder,
-    output_folder=None,
-    f1_thresh=None,
-    eer_thresh=None,
-    steps=1000,
+    name: str,
+    predictions_folder: str,
+    output_folder: Optional[str] = None,
+    f1_thresh: Optional[float] = None,
+    eer_thresh: Optional[float] = None,
+    steps: int = 1000,
 ):
     """Runs inference and calculates measures.
 
@@ -132,44 +202,46 @@ def run(
     dataset : py:class:`torch.utils.data.Dataset`
         a dataset to iterate on
 
-    name : str
+    name:
         the local name of this dataset (e.g. ``train``, or ``test``), to be
         used when saving measures files.
 
-    predictions_folder : str
-        folder where predictions for the dataset images has been previously
-        stored
+    predictions_folder:
+        folder where predictions for the dataset images have been previously
+        stored
 
-    output_folder : :py:class:`str`, Optional
+    output_folder:
         folder where to store results.
 
-    f1_thresh : :py:class:`float`, Optional
+    f1_thresh:
         This number should come from the training set or a separate
         validation set.  Using a test set value may bias your analysis.
         This number is also used to print the a priori F1-score on the
         evaluated set.
 
-    eer_thresh : :py:class:`float`, Optional
+    eer_thresh:
         This number should come from the training set or a separate
         validation set.  Using a test set value may bias your analysis.
         This number is used to print the a priori EER.
 
-    steps : :py:class:`float`, Optional
+    steps:
         number of threshold steps to consider when evaluating thresholds.
 
     Returns
     -------
 
-    f1_threshold : float
+    maxf1_threshold : float
         Threshold to achieve the highest possible F1-score for this dataset
 
-    eer_threshold : float
+    post_eer_threshold : float
         Threshold achieving Equal Error Rate for this dataset
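+
+    Example (an illustrative sketch; paths are hypothetical, and
+    ``predictions_folder`` is expected to contain the per-split
+    ``predictions_<name>/predictions.csv`` files written at prediction
+    time)::
+
+        maxf1, eer = run(
+            dataset,
+            "validation",
+            "results/predictions",
+            output_folder="results/evaluation",
+        )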
""" - if t is None: + if threshold is None: return 0.5 try: # we try to convert it to float first - t = float(t) - if t < 0.0 or t > 1.0: + threshold = float(threshold) + if threshold < 0.0 or threshold > 1.0: raise ValueError("Float thresholds must be within range [0.0, 1.0]") except ValueError: # it is a bit of text - assert dataset with name is available - if not isinstance(dataset, dict): + if not isinstance(dataloader_dict, dict): raise ValueError( "Threshold should be a floating-point number " "if your provide only a single dataset for evaluation" ) - if t not in dataset: + if threshold not in dataloader_dict: raise ValueError( f"Text thresholds should match dataset names, " - f"but {t} is not available among the datasets provided (" - f"({', '.join(dataset.keys())})" + f"but {threshold} is not available among the datasets provided (" + f"({', '.join(dataloader_dict.keys())})" ) - return t + return threshold @click.command( @@ -71,13 +98,9 @@ def _validate_threshold(t, dataset): cls=ResourceOption, ) @click.option( - "--dataset", + "--datamodule", "-d", - help="A torch.utils.data.dataset.Dataset instance implementing a dataset " - "to be used for evaluation purposes, possibly including all pre-processing " - "pipelines required or, optionally, a dictionary mapping string keys to " - "torch.utils.data.dataset.Dataset instances. All keys that do not start " - "with an underscore (_) will be processed.", + help="A lighting data module containing the training and validation sets.", required=True, cls=ResourceOption, ) @@ -109,11 +132,11 @@ def _validate_threshold(t, dataset): ) @verbosity_option(logger=logger, cls=ResourceOption, expose_value=False) def evaluate( - output_folder, - predictions_folder, - dataset, - threshold, - steps, + output_folder: str, + predictions_folder: str, + datamodule: CachingDataModule, + threshold: Union[int, float, str], + steps: int, **_, ) -> None: """Evaluates a CNN on a tuberculosis prediction task. @@ -123,17 +146,22 @@ def evaluate( from ..engine.evaluator import run - threshold = _validate_threshold(threshold, dataset) + datamodule.prepare_data() + datamodule.setup(stage="test") + + datamodule.set_chunk_size(1, 1) - if not isinstance(dataset, dict): - dataset = {"test": dataset} + dataloader = datamodule.test_dataloader() + + threshold = _validate_threshold(threshold, dataloader) if isinstance(threshold, str): # first run evaluation for reference dataset logger.info(f"Evaluating threshold on '{threshold}' set") f1_threshold, eer_threshold = run( - dataset[threshold], threshold, predictions_folder, steps=steps + _, threshold, predictions_folder, steps=steps ) + if (f1_threshold is not None) and (eer_threshold is not None): logger.info(f"Set --f1_threshold={f1_threshold:.5f}") logger.info(f"Set --eer_threshold={eer_threshold:.5f}") @@ -144,12 +172,11 @@ def evaluate( else: raise ValueError("Threshold value is neither an int nor a float") - # now run with the - for k, v in dataset.items(): + for k, v in dataloader.items(): if k.startswith("_"): logger.info(f"Skipping dataset '{k}' (not to be evaluated)") continue - logger.info(f"Analyzing '{k}' set...") + logger.info(f"Analyzing '{threshold}' set...") run( v, k, -- GitLab