From 0bed6e1d2c6ebb666de3232456db64b604dbc7a3 Mon Sep 17 00:00:00 2001
From: dcarron <daniel.carron@idiap.ch>
Date: Thu, 13 Jul 2023 14:16:40 +0200
Subject: [PATCH] Functional evaluation scripts
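
Predictions for each dataloader are now written to a dedicated sub-folder,
and the evaluator looks them up in the same place.  For a split named
``test`` and an output folder ``results`` (illustrative names), both sides
now resolve the same path:

    os.path.join("results", "predictions_test", "predictions.csv")
    # -> results/predictions_test/predictions.csv, replacing the flat
    #    results/predictions_test_set.csv layout used previously

If this path does not exist, the evaluator falls back to using
``predictions_folder`` as given, as before.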
---
 src/ptbench/engine/callbacks.py |   2 +-
 src/ptbench/engine/evaluator.py | 154 +++++++++++++++++++++++---------
 src/ptbench/scripts/evaluate.py |  83 +++++++++++------
 3 files changed, 169 insertions(+), 70 deletions(-)

diff --git a/src/ptbench/engine/callbacks.py b/src/ptbench/engine/callbacks.py
index 5adada7b..031acaae 100644
--- a/src/ptbench/engine/callbacks.py
+++ b/src/ptbench/engine/callbacks.py
@@ -407,7 +407,7 @@ class PredictionsWriter(lightning.pytorch.callbacks.BasePredictionWriter):
         ]
 
         logfile = os.path.join(
-            self.output_dir, f"predictions_{dataloader_name}_set.csv"
+            self.output_dir, f"predictions_{dataloader_name}", "predictions.csv"
         )
 
         os.makedirs(os.path.dirname(logfile), exist_ok=True)
diff --git a/src/ptbench/engine/evaluator.py b/src/ptbench/engine/evaluator.py
index acbf7f36..c157d5fe 100644
--- a/src/ptbench/engine/evaluator.py
+++ b/src/ptbench/engine/evaluator.py
@@ -8,6 +8,8 @@ import logging
 import os
 import re
 
+from typing import Iterable, Optional
+
 import matplotlib.pyplot as plt
 import numpy
 import pandas as pd
@@ -20,22 +22,22 @@ from ..utils.measure import base_measures, get_centered_maxf1
 logger = logging.getLogger(__name__)
 
 
-def eer_threshold(neg, pos) -> float:
+def eer_threshold(neg: Iterable[float], pos: Iterable[float]) -> float:
     """Evaluates the EER threshold from negative and positive scores.
 
     Parameters
     ----------
 
-    neg : typing.Iterable[float]
+    neg :
         Negative scores
 
-    pos : typing.Iterable[float]
+    pos :
         Positive scores
 
-    Returns:
+    Returns
+    -------
 
-        Threshold
+        The EER threshold value.
     """
     from scipy.interpolate import interp1d
     from scipy.optimize import brentq
@@ -49,8 +51,40 @@ def eer_threshold(neg, pos) -> float:
     return interp1d(fpr, thresholds)(eer)
 
 
-def posneg(pred, gt, threshold):
-    """Calculates true and false positives and negatives."""
+def posneg(
+    pred, gt, threshold
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+    """Calculates true and false positives and negatives.
+
+    Parameters
+    ----------
+
+    pred :
+        Pixel-wise predictions.
+
+    gt :
+        Ground-truth (annotations).
+
+    threshold :
+        A particular threshold at which to calculate the performance
+        measures.
+
+    Returns
+    -------
+
+    tp_tensor :
+        The true positive values.
+
+    fp_tensor :
+        The false positive values.
+
+    tn_tensor :
+        The true negative values.
+
+    fn_tensor :
+        The false negative values.
+    """
+
     # threshold
     binary_pred = torch.gt(pred, threshold)
@@ -73,38 +107,74 @@ def posneg(pred, gt, threshold):
     return tp_tensor, fp_tensor, tn_tensor, fn_tensor
 
 
-def sample_measures_for_threshold(pred, gt, threshold):
+def sample_measures_for_threshold(
+    pred: torch.Tensor, gt: torch.Tensor, threshold: float
+) -> tuple[float, float, float, float, float, float]:
     """Calculates measures on one single sample, for a specific threshold.
 
     Parameters
     ----------
 
-    pred : torch.Tensor
-        pixel-wise predictions
+    pred :
+        Pixel-wise predictions.
 
-    gt : torch.Tensor
-        ground-truth (annotations)
+    gt :
+        Ground-truth (annotations).
 
-    threshold : float
-        a particular threshold in which to calculate the performance
-        measures
+    threshold :
+        A particular threshold at which to calculate the performance
+        measures.
 
     Returns
     -------
 
-    precision: float
-
-    recall: float
-
-    specificity: float
-
-    accuracy: float
-
-    jaccard: float
-
-    f1_score: float
+    precision : float
+        P, AKA positive predictive value (PPV).  It corresponds arithmetically
+        to ``tp/(tp+fp)``.  In the case ``tp+fp == 0``, this function returns
+        zero for precision.
+
+    recall : float
+        R, AKA sensitivity, hit rate, or true positive rate (TPR).  It
+        corresponds arithmetically to ``tp/(tp+fn)``.  In the special case
+        where ``tp+fn == 0``, this function returns zero for recall.
+
+    specificity : float
+        S, AKA selectivity or true negative rate (TNR).  It corresponds
+        arithmetically to ``tn/(tn+fp)``.  In the special case where
+        ``tn+fp == 0``, this function returns zero for specificity.
+
+    accuracy : float
+        A, see `Accuracy
+        <https://en.wikipedia.org/wiki/Evaluation_of_binary_classifiers>`_.
+        It is the proportion of correct predictions (both true positives and
+        true negatives) among the total number of pixels examined.  It
+        corresponds arithmetically to ``(tp+tn)/(tp+tn+fp+fn)``.  This
+        measure includes both true negatives and true positives in the
+        numerator, which makes it sensitive to data or regions without
+        annotations.
+
+    jaccard : float
+        J, see `Jaccard Index or Similarity
+        <https://en.wikipedia.org/wiki/Jaccard_index>`_.  It corresponds
+        arithmetically to ``tp/(tp+fp+fn)``.  In the special case where
+        ``tp+fp+fn == 0``, this function returns zero for the Jaccard index.
+        The Jaccard index depends on a TP-only numerator, similarly to the F1
+        score.  For regions where there are no annotations, the Jaccard index
+        will always be zero, irrespective of the model output.  Accuracy may
+        be a better proxy if one needs to consider the true absence of
+        annotations in a region as part of the measure.
+
+    f1_score : float
+        F1, see `F1-score <https://en.wikipedia.org/wiki/F1_score>`_.  It
+        corresponds arithmetically to ``2*P*R/(P+R)`` or ``2*tp/(2*tp+fp+fn)``.
+        In the special case where ``P+R == (2*tp+fp+fn) == 0``, this function
+        returns zero for the F1 score.  The F1 or Dice score depends on a
+        TP-only numerator, similarly to the Jaccard index.  For regions where
+        there are no annotations, the F1-score will always be zero,
+        irrespective of the model output.  Accuracy may be a better proxy if
+        one needs to consider the true absence of annotations in a region as
+        part of the measure.
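+
+    Example (an illustrative sketch, assuming ``gt`` is a binary mask; the
+    numbers follow directly from the definitions above)::
+
+        pred = torch.tensor([0.1, 0.9, 0.8, 0.3])
+        gt = torch.tensor([0, 1, 0, 1])
+        # thresholding at 0.5 gives tp=1, fp=1, tn=1, fn=1, so precision,
+        # recall, specificity, accuracy and F1 all equal 0.5, while the
+        # Jaccard index equals 1/3
+        measures = sample_measures_for_threshold(pred, gt, 0.5)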
     """
+
     tp_tensor, fp_tensor, tn_tensor, fn_tensor = posneg(pred, gt, threshold)
 
     # calc measures from scalars
@@ -117,12 +187,12 @@ def sample_measures_for_threshold(pred, gt, threshold):
 
 def run(
     dataset,
-    name,
-    predictions_folder,
-    output_folder=None,
-    f1_thresh=None,
-    eer_thresh=None,
-    steps=1000,
+    name: str,
+    predictions_folder: str,
+    output_folder: Optional[str] = None,
+    f1_thresh: Optional[float] = None,
+    eer_thresh: Optional[float] = None,
+    steps: int = 1000,
 ):
     """Runs inference and calculates measures.
 
@@ -132,44 +202,46 @@ def run(
     dataset : py:class:`torch.utils.data.Dataset`
         a dataset to iterate on
 
-    name : str
+    name:
         the local name of this dataset (e.g. ``train``, or ``test``), to be
         used when saving measures files.
 
-    predictions_folder : str
-        folder where predictions for the dataset images has been previously
-        stored
+    predictions_folder:
+        folder where predictions for the dataset images have been previously
+        stored
 
-    output_folder : :py:class:`str`, Optional
+    output_folder:
         folder where to store results.
 
-    f1_thresh : :py:class:`float`, Optional
+    f1_thresh:
         This number should come from the training set or a separate
         validation set.  Using a test set value may bias your analysis.
         This number is also used to print the a priori F1-score on the
         evaluated set.
 
-    eer_thresh : :py:class:`float`, Optional
+    eer_thresh:
         This number should come from the training set or a separate
         validation set.  Using a test set value may bias your analysis.
         This number is used to print the a priori EER.
 
-    steps : :py:class:`float`, Optional
+    steps:
         number of threshold steps to consider when evaluating thresholds.
 
     Returns
     -------
 
-    f1_threshold : float
+    maxf1_threshold : float
         Threshold to achieve the highest possible F1-score for this dataset
 
-    eer_threshold : float
+    post_eer_threshold : float
         Threshold achieving Equal Error Rate for this dataset
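+
+    Example (an illustrative sketch; paths are hypothetical, and
+    ``predictions_folder`` is expected to contain the per-split
+    ``predictions_<name>/predictions.csv`` files written at prediction
+    time)::
+
+        maxf1, eer = run(
+            dataset,
+            "validation",
+            "results/predictions",
+            output_folder="results/evaluation",
+        )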
""" - if t is None: + if threshold is None: return 0.5 try: # we try to convert it to float first - t = float(t) - if t < 0.0 or t > 1.0: + threshold = float(threshold) + if threshold < 0.0 or threshold > 1.0: raise ValueError("Float thresholds must be within range [0.0, 1.0]") except ValueError: # it is a bit of text - assert dataset with name is available - if not isinstance(dataset, dict): + if not isinstance(dataloader_dict, dict): raise ValueError( "Threshold should be a floating-point number " "if your provide only a single dataset for evaluation" ) - if t not in dataset: + if threshold not in dataloader_dict: raise ValueError( f"Text thresholds should match dataset names, " - f"but {t} is not available among the datasets provided (" - f"({', '.join(dataset.keys())})" + f"but {threshold} is not available among the datasets provided (" + f"({', '.join(dataloader_dict.keys())})" ) - return t + return threshold @click.command( @@ -71,13 +98,9 @@ def _validate_threshold(t, dataset): cls=ResourceOption, ) @click.option( - "--dataset", + "--datamodule", "-d", - help="A torch.utils.data.dataset.Dataset instance implementing a dataset " - "to be used for evaluation purposes, possibly including all pre-processing " - "pipelines required or, optionally, a dictionary mapping string keys to " - "torch.utils.data.dataset.Dataset instances. All keys that do not start " - "with an underscore (_) will be processed.", + help="A lighting data module containing the training and validation sets.", required=True, cls=ResourceOption, ) @@ -109,11 +132,11 @@ def _validate_threshold(t, dataset): ) @verbosity_option(logger=logger, cls=ResourceOption, expose_value=False) def evaluate( - output_folder, - predictions_folder, - dataset, - threshold, - steps, + output_folder: str, + predictions_folder: str, + datamodule: CachingDataModule, + threshold: Union[int, float, str], + steps: int, **_, ) -> None: """Evaluates a CNN on a tuberculosis prediction task. @@ -123,17 +146,22 @@ def evaluate( from ..engine.evaluator import run - threshold = _validate_threshold(threshold, dataset) + datamodule.prepare_data() + datamodule.setup(stage="test") + + datamodule.set_chunk_size(1, 1) - if not isinstance(dataset, dict): - dataset = {"test": dataset} + dataloader = datamodule.test_dataloader() + + threshold = _validate_threshold(threshold, dataloader) if isinstance(threshold, str): # first run evaluation for reference dataset logger.info(f"Evaluating threshold on '{threshold}' set") f1_threshold, eer_threshold = run( - dataset[threshold], threshold, predictions_folder, steps=steps + _, threshold, predictions_folder, steps=steps ) + if (f1_threshold is not None) and (eer_threshold is not None): logger.info(f"Set --f1_threshold={f1_threshold:.5f}") logger.info(f"Set --eer_threshold={eer_threshold:.5f}") @@ -144,12 +172,11 @@ def evaluate( else: raise ValueError("Threshold value is neither an int nor a float") - # now run with the - for k, v in dataset.items(): + for k, v in dataloader.items(): if k.startswith("_"): logger.info(f"Skipping dataset '{k}' (not to be evaluated)") continue - logger.info(f"Analyzing '{k}' set...") + logger.info(f"Analyzing '{threshold}' set...") run( v, k, -- GitLab