Commit 0bed6e1d authored by Daniel CARRON, committed by André Anjos

Functional evaluation scripts

parent 0e3f0f1c
Merge request: !6 Making use of LightningDataModule and simplification of data loading
@@ -407,7 +407,7 @@ class PredictionsWriter(lightning.pytorch.callbacks.BasePredictionWriter):
]
logfile = os.path.join(
-self.output_dir, f"predictions_{dataloader_name}_set.csv"
+self.output_dir, f"predictions_{dataloader_name}", "predictions.csv"
)
os.makedirs(os.path.dirname(logfile), exist_ok=True)
......
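The writer now emits one subdirectory per dataloader instead of a flat per-set CSV. A minimal sketch of the resulting layout (directory and split names are hypothetical):

import os

output_dir = "results"  # hypothetical output directory
for dataloader_name in ("train", "validation", "test"):  # hypothetical splits
    logfile = os.path.join(
        output_dir, f"predictions_{dataloader_name}", "predictions.csv"
    )
    # as in the diff above, the parent directory is created before writing
    os.makedirs(os.path.dirname(logfile), exist_ok=True)
    print(logfile)  # e.g. results/predictions_train/predictions.csv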
@@ -8,6 +8,8 @@ import logging
import os
import re
+from typing import Iterable, Optional
+import matplotlib.pyplot as plt
import numpy
import pandas as pd
@@ -20,22 +22,22 @@ from ..utils.measure import base_measures, get_centered_maxf1
logger = logging.getLogger(__name__)
-def eer_threshold(neg, pos) -> float:
+def eer_threshold(neg: Iterable[float], pos: Iterable[float]) -> float:
"""Evaluates the EER threshold from negative and positive scores.
Parameters
----------
-neg : typing.Iterable[float]
+neg :
Negative scores
-pos : typing.Iterable[float]
+pos :
Positive scores
-Returns:
-Threshold
+Returns
+-------
+The EER threshold value.
"""
from scipy.interpolate import interp1d
from scipy.optimize import brentq
@@ -49,8 +51,40 @@ def eer_threshold(neg, pos) -> float:
return interp1d(fpr, thresholds)(eer)
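For context, this is the usual scipy pattern for finding the EER operating point on a ROC curve. A self-contained sketch, assuming scikit-learn's roc_curve to build the curve (only the interp1d/brentq calls appear in the diff):

import numpy
from scipy.interpolate import interp1d
from scipy.optimize import brentq
from sklearn.metrics import roc_curve

def _eer_threshold_sketch(neg, pos):
    # label negatives 0 and positives 1, then build a ROC curve
    y_true = numpy.concatenate([numpy.zeros(len(neg)), numpy.ones(len(pos))])
    y_score = numpy.concatenate([neg, pos])
    fpr, tpr, thresholds = roc_curve(y_true, y_score)
    # the EER is where FPR == 1 - TPR; find that crossing point
    eer = brentq(lambda x: 1.0 - x - interp1d(fpr, tpr)(x), 0.0, 1.0)
    # map the EER operating point back to a score threshold
    return float(interp1d(fpr, thresholds)(eer))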
-def posneg(pred, gt, threshold):
-"""Calculates true and false positives and negatives."""
+def posneg(
+pred, gt, threshold
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+"""Calculates true and false positives and negatives.
+Parameters
+----------
+pred :
+Pixel-wise predictions.
+gt :
+Ground-truth (annotations).
+threshold :
+A particular threshold at which to calculate the performance
+measures.
+Returns
+-------
+tp_tensor:
+The true positive values.
+fp_tensor:
+The false positive values.
+tn_tensor:
+The true negative values.
+fn_tensor:
+The false negative values.
+"""
# threshold
binary_pred = torch.gt(pred, threshold)
@@ -73,38 +107,74 @@ def posneg(pred, gt, threshold):
return tp_tensor, fp_tensor, tn_tensor, fn_tensor
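A hypothetical toy call illustrating the contract (values invented for illustration):

import torch

pred = torch.tensor([0.1, 0.8, 0.6, 0.3])  # pixel-wise scores
gt = torch.tensor([0, 1, 0, 1])            # ground-truth annotations
tp, fp, tn, fn = posneg(pred, gt, 0.5)
# thresholding at 0.5 yields binary predictions [0, 1, 1, 0], so the
# returned masks flag: tp at index 1, fp at index 2, tn at 0, fn at 3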
-def sample_measures_for_threshold(pred, gt, threshold):
+def sample_measures_for_threshold(
+pred: torch.Tensor, gt: torch.Tensor, threshold: float
+) -> tuple[float, float, float, float, float, float]:
"""Calculates measures on one single sample, for a specific threshold.
Parameters
----------
-pred : torch.Tensor
-pixel-wise predictions
-gt : torch.Tensor
-ground-truth (annotations)
+pred :
+Pixel-wise predictions.
-threshold : float
-a particular threshold in which to calculate the performance
-measures
+gt :
+Ground-truth (annotations).
+threshold :
+A particular threshold at which to calculate the performance
+measures.
Returns
-------
-precision: float
-recall: float
-specificity: float
-accuracy: float
-jaccard: float
-f1_score: float
+precision : float
+P, AKA positive predictive value (PPV). It corresponds arithmetically
+to ``tp/(tp+fp)``. In the case ``tp+fp == 0``, this function returns
+zero for precision.
+recall : float
+R, AKA sensitivity, hit rate, or true positive rate (TPR). It
+corresponds arithmetically to ``tp/(tp+fn)``. In the special case
+where ``tp+fn == 0``, this function returns zero for recall.
+specificity : float
+S, AKA selectivity or true negative rate (TNR). It
+corresponds arithmetically to ``tn/(tn+fp)``. In the special case
+where ``tn+fp == 0``, this function returns zero for specificity.
+accuracy : float
+A, see `Accuracy
+<https://en.wikipedia.org/wiki/Evaluation_of_binary_classifiers>`_. It
+is the proportion of correct predictions (both true positives and true
+negatives) among the total number of pixels examined. It corresponds
+arithmetically to ``(tp+tn)/(tp+tn+fp+fn)``. This measure includes
+both true negatives and positives in the numerator, which makes it
+sensitive to data or regions without annotations.
+jaccard : float
+J, see `Jaccard Index or Similarity
+<https://en.wikipedia.org/wiki/Jaccard_index>`_. It corresponds
+arithmetically to ``tp/(tp+fp+fn)``. In the special case where
+``tp+fp+fn == 0``, this function returns zero for the Jaccard index.
+The Jaccard index depends on a TP-only numerator, similarly to the F1
+score. For regions where there are no annotations, the Jaccard index
+will always be zero, irrespective of the model output. Accuracy may be
+a better proxy if one needs to consider the true absence of
+annotations in a region as part of the measure.
+f1_score : float
+F1, see `F1-score <https://en.wikipedia.org/wiki/F1_score>`_. It
+corresponds arithmetically to ``2*P*R/(P+R)`` or ``2*tp/(2*tp+fp+fn)``.
+In the special case where ``P+R == (2*tp+fp+fn) == 0``, this function
+returns zero for the F1-score. The F1 or Dice score depends on a
+TP-only numerator, similarly to the Jaccard index. For regions where
+there are no annotations, the F1-score will always be zero,
+irrespective of the model output. Accuracy may be a better proxy if
+one needs to consider the true absence of annotations in a region as
+part of the measure.
"""
tp_tensor, fp_tensor, tn_tensor, fn_tensor = posneg(pred, gt, threshold)
# calc measures from scalars
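A worked example of the formulas documented above, with hypothetical pixel counts:

tp, fp, tn, fn = 80, 20, 890, 10  # hypothetical counts for one sample

precision = tp / (tp + fp)                  # 80/100  = 0.800
recall = tp / (tp + fn)                     # 80/90  ≈ 0.889
specificity = tn / (tn + fp)                # 890/910 ≈ 0.978
accuracy = (tp + tn) / (tp + tn + fp + fn)  # 970/1000 = 0.970
jaccard = tp / (tp + fp + fn)               # 80/110 ≈ 0.727
f1_score = 2 * tp / (2 * tp + fp + fn)      # 160/190 ≈ 0.842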
@@ -117,12 +187,12 @@ def sample_measures_for_threshold(pred, gt, threshold):
def run(
dataset,
-name,
-predictions_folder,
-output_folder=None,
-f1_thresh=None,
-eer_thresh=None,
-steps=1000,
+name: str,
+predictions_folder: str,
+output_folder: Optional[str] = None,
+f1_thresh: Optional[float] = None,
+eer_thresh: Optional[float] = None,
+steps: int = 1000,
):
"""Runs inference and calculates measures.
@@ -132,44 +202,46 @@ def run(
dataset : :py:class:`torch.utils.data.Dataset`
a dataset to iterate on
-name : str
+name:
the local name of this dataset (e.g. ``train``, or ``test``), to be
used when saving measures files.
-predictions_folder : str
+predictions_folder:
folder where predictions for the dataset images have been previously
stored
-output_folder : :py:class:`str`, Optional
+output_folder:
folder where to store results.
-f1_thresh : :py:class:`float`, Optional
+f1_thresh:
This number should come from
the training set or a separate validation set. Using a test set value
may bias your analysis. This number is also used to print the a priori
F1-score on the evaluated set.
-eer_thresh : :py:class:`float`, Optional
+eer_thresh:
This number should come from
the training set or a separate validation set. Using a test set value
may bias your analysis. This number is used to print the a priori
EER.
-steps : :py:class:`float`, Optional
+steps:
number of threshold steps to consider when evaluating thresholds.
Returns
-------
-f1_threshold : float
+maxf1_threshold : float
Threshold to achieve the highest possible F1-score for this dataset
-eer_threshold : float
+post_eer_threshold : float
Threshold achieving Equal Error Rate for this dataset
"""
-predictions_path = os.path.join(
-predictions_folder, f"predictions_{name}", "predictions.csv"
-)
+predictions_path = os.path.join(predictions_folder, name, "predictions.csv")
if not os.path.exists(predictions_path):
predictions_path = predictions_folder
......
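A hypothetical call showing how the thresholds flow through run (paths and values are placeholders):

# thresholds tuned on a validation split, then applied to the "test" split
maxf1_threshold, post_eer_threshold = run(
    dataset,  # the split to iterate on
    name="test",
    predictions_folder="results",
    output_folder="results/evaluation",
    f1_thresh=0.43,   # hypothetical, estimated on the validation set
    eer_thresh=0.39,  # hypothetical, estimated on the validation set
)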
@@ -2,42 +2,69 @@
#
# SPDX-License-Identifier: GPL-3.0-or-later
+from typing import Union
import click
from clapper.click import ConfigCommand, ResourceOption, verbosity_option
from clapper.logging import setup
+from ..data.datamodule import CachingDataModule
+from ..data.typing import DataLoader
logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
-def _validate_threshold(t, dataset):
+def _validate_threshold(
+threshold: Union[int, float, str], dataloader_dict: dict[str, DataLoader]
+):
"""Validates the user threshold selection.
Returns parsed threshold.
+Parameters
+----------
+threshold:
+This number is used to define positives and negatives from
+probability maps, and report F1-scores (a priori). It
+should either come from the training set or a separate validation set
+to avoid biasing the analysis. Optionally, if you provide a multi-set
+dataset as input, this may also be the name of an existing set from
+which the threshold will be estimated (highest F1-score) and then
+applied to the subsequent sets. This number is also used to print
+the a priori F1-score on the test set.
+dataloader_dict:
+Dictionary of set_name: dataloader, where set_name is the name of a
+dataset split and dataloader is the torch dataloader for that split.
+Returns
+-------
+The parsed threshold.
"""
-if t is None:
+if threshold is None:
return 0.5
try:
# we try to convert it to float first
-t = float(t)
-if t < 0.0 or t > 1.0:
+threshold = float(threshold)
+if threshold < 0.0 or threshold > 1.0:
raise ValueError("Float thresholds must be within range [0.0, 1.0]")
except ValueError:
# it is a bit of text - assert dataset with name is available
-if not isinstance(dataset, dict):
+if not isinstance(dataloader_dict, dict):
raise ValueError(
"Threshold should be a floating-point number "
"if your provide only a single dataset for evaluation"
)
-if t not in dataset:
+if threshold not in dataloader_dict:
raise ValueError(
f"Text thresholds should match dataset names, "
f"but {t} is not available among the datasets provided ("
f"({', '.join(dataset.keys())})"
f"but {threshold} is not available among the datasets provided ("
f"({', '.join(dataloader_dict.keys())})"
)
-return t
+return threshold
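Hypothetical calls illustrating the accepted forms (the loader objects are placeholders):

loaders = {"train": train_loader, "test": test_loader}  # hypothetical

_validate_threshold(None, loaders)     # -> 0.5 (default)
_validate_threshold("0.4", loaders)    # -> 0.4 (parsed as a float)
_validate_threshold("train", loaders)  # -> "train" (estimated on that set later)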
@click.command(
@@ -71,13 +98,9 @@ def _validate_threshold(t, dataset):
cls=ResourceOption,
)
@click.option(
"--dataset",
"--datamodule",
"-d",
help="A torch.utils.data.dataset.Dataset instance implementing a dataset "
"to be used for evaluation purposes, possibly including all pre-processing "
"pipelines required or, optionally, a dictionary mapping string keys to "
"torch.utils.data.dataset.Dataset instances. All keys that do not start "
"with an underscore (_) will be processed.",
help="A lighting data module containing the training and validation sets.",
required=True,
cls=ResourceOption,
)
@@ -109,11 +132,11 @@ def _validate_threshold(t, dataset):
)
@verbosity_option(logger=logger, cls=ResourceOption, expose_value=False)
def evaluate(
-output_folder,
-predictions_folder,
-dataset,
-threshold,
-steps,
+output_folder: str,
+predictions_folder: str,
+datamodule: CachingDataModule,
+threshold: Union[int, float, str],
+steps: int,
**_,
) -> None:
"""Evaluates a CNN on a tuberculosis prediction task.
@@ -123,17 +146,22 @@ def evaluate(
from ..engine.evaluator import run
-threshold = _validate_threshold(threshold, dataset)
+datamodule.prepare_data()
+datamodule.setup(stage="test")
+datamodule.set_chunk_size(1, 1)
-if not isinstance(dataset, dict):
-dataset = {"test": dataset}
+dataloader = datamodule.test_dataloader()
+threshold = _validate_threshold(threshold, dataloader)
if isinstance(threshold, str):
# first run evaluation for reference dataset
logger.info(f"Evaluating threshold on '{threshold}' set")
f1_threshold, eer_threshold = run(
-dataset[threshold], threshold, predictions_folder, steps=steps
+_, threshold, predictions_folder, steps=steps
)
+if (f1_threshold is not None) and (eer_threshold is not None):
+logger.info(f"Set --f1_threshold={f1_threshold:.5f}")
+logger.info(f"Set --eer_threshold={eer_threshold:.5f}")
@@ -144,12 +172,11 @@ def evaluate(
else:
raise ValueError("Threshold value is neither an int nor a float")
# now run with the chosen threshold
-for k, v in dataset.items():
+for k, v in dataloader.items():
if k.startswith("_"):
logger.info(f"Skipping dataset '{k}' (not to be evaluated)")
continue
logger.info(f"Analyzing '{k}' set...")
logger.info(f"Analyzing '{threshold}' set...")
run(
v,
k,
......
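A hypothetical invocation tying the options together (the executable name and the config identifier are assumptions, not from the diff):

# estimate the threshold on the validation set, then apply it to the test set
ptbench evaluate -vv \
    --datamodule=<datamodule-config> \
    --predictions-folder=path/to/predictions \
    --output-folder=path/to/evaluation \
    --threshold=validation \
    --steps=1000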