Commit 0bed6e1d authored by Daniel CARRON, committed by André Anjos

Functional evaluation scripts

parent 0e3f0f1c
1 merge request: !6 "Making use of LightningDataModule and simplification of data loading"
@@ -407,7 +407,7 @@ class PredictionsWriter(lightning.pytorch.callbacks.BasePredictionWriter):
         ]
         logfile = os.path.join(
-            self.output_dir, f"predictions_{dataloader_name}_set.csv"
+            self.output_dir, f"predictions_{dataloader_name}", "predictions.csv"
         )
         os.makedirs(os.path.dirname(logfile), exist_ok=True)
...
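
Aside (not part of the commit): the change above moves per-split prediction logs from one flat CSV per split into one sub-directory per split. A minimal sketch of the resulting path handling, using hypothetical values for output_dir and dataloader_name:

    import os

    # hypothetical values, for illustration only
    output_dir = "results"
    dataloader_name = "test"

    # old layout: results/predictions_test_set.csv (one flat file per split)
    old_logfile = os.path.join(output_dir, f"predictions_{dataloader_name}_set.csv")

    # new layout: results/predictions_test/predictions.csv (one directory per split)
    new_logfile = os.path.join(
        output_dir, f"predictions_{dataloader_name}", "predictions.csv"
    )
    os.makedirs(os.path.dirname(new_logfile), exist_ok=True)
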
@@ -8,6 +8,8 @@ import logging
 import os
 import re
 
+from typing import Iterable, Optional
+
 import matplotlib.pyplot as plt
 import numpy
 import pandas as pd
@@ -20,22 +22,22 @@ from ..utils.measure import base_measures, get_centered_maxf1
 logger = logging.getLogger(__name__)
 
 
-def eer_threshold(neg, pos) -> float:
+def eer_threshold(neg: Iterable[float], pos: Iterable[float]) -> float:
     """Evaluates the EER threshold from negative and positive scores.
 
     Parameters
     ----------
 
-    neg : typing.Iterable[float]
+    neg :
         Negative scores
 
-    pos : typing.Iterable[float]
+    pos :
         Positive scores
 
     Returns:
-        Threshold
+        The EER threshold value.
     """
 
     from scipy.interpolate import interp1d
     from scipy.optimize import brentq
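
Aside (not part of the commit): the body of eer_threshold is not shown in this diff; only its scipy imports and the final interp1d call are visible. A minimal sketch of how such a threshold is commonly computed, assuming scikit-learn's roc_curve is available (an assumption, since the module's real implementation may differ):

    import numpy
    from scipy.interpolate import interp1d
    from scipy.optimize import brentq
    from sklearn.metrics import roc_curve  # assumption: not shown in the diff


    def eer_threshold_sketch(neg, pos) -> float:
        # label negatives as 0 and positives as 1, then build a ROC curve
        scores = numpy.concatenate([numpy.asarray(neg), numpy.asarray(pos)])
        labels = numpy.concatenate([numpy.zeros(len(neg)), numpy.ones(len(pos))])
        fpr, tpr, thresholds = roc_curve(labels, scores)

        # the EER is the operating point where FPR equals 1 - TPR
        eer = brentq(lambda x: 1.0 - x - interp1d(fpr, tpr)(x), 0.0, 1.0)

        # map that operating point back to a decision threshold
        return float(interp1d(fpr, thresholds)(eer))
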
@@ -49,8 +51,40 @@ def eer_threshold(neg, pos) -> float:
     return interp1d(fpr, thresholds)(eer)
 
 
-def posneg(pred, gt, threshold):
-    """Calculates true and false positives and negatives."""
+def posneg(
+    pred, gt, threshold
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+    """Calculates true and false positives and negatives.
+
+    Parameters
+    ----------
+
+    pred :
+        Pixel-wise predictions.
+
+    gt :
+        Ground-truth (annotations).
+
+    threshold :
+        A particular threshold in which to calculate the performance
+        measures.
+
+    Returns
+    -------
+
+    tp_tensor:
+        The true positive values.
+
+    fp_tensor:
+        The false positive values.
+
+    tn_tensor:
+        The true negative values.
+
+    fn_tensor:
+        The false negative values.
+    """
 
     # threshold
     binary_pred = torch.gt(pred, threshold)
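
Aside (not part of the commit): only the first line of posneg's body (the torch.gt thresholding) is visible here. A plausible sketch of how the four tensors documented above can be derived, assuming the ground-truth is also binarized (an assumption):

    import torch


    def posneg_sketch(pred: torch.Tensor, gt: torch.Tensor, threshold: float):
        # binarize predictions at the requested threshold
        binary_pred = torch.gt(pred, threshold)
        # assumption: ground-truth annotations are binarized the same way
        binary_gt = torch.gt(gt, 0.5)

        tp_tensor = binary_pred & binary_gt    # predicted positive, annotated positive
        fp_tensor = binary_pred & ~binary_gt   # predicted positive, annotated negative
        tn_tensor = ~binary_pred & ~binary_gt  # predicted negative, annotated negative
        fn_tensor = ~binary_pred & binary_gt   # predicted negative, annotated positive
        return tp_tensor, fp_tensor, tn_tensor, fn_tensor
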
@@ -73,38 +107,74 @@ def posneg(pred, gt, threshold):
     return tp_tensor, fp_tensor, tn_tensor, fn_tensor
 
 
-def sample_measures_for_threshold(pred, gt, threshold):
+def sample_measures_for_threshold(
+    pred: torch.Tensor, gt: torch.Tensor, threshold: float
+) -> tuple[float, float, float, float, float]:
     """Calculates measures on one single sample, for a specific threshold.
 
     Parameters
     ----------
 
-    pred : torch.Tensor
-        pixel-wise predictions
+    pred :
+        Pixel-wise predictions.
 
-    gt : torch.Tensor
-        ground-truth (annotations)
+    gt :
+        Ground-truth (annotations).
 
-    threshold : float
-        a particular threshold in which to calculate the performance
-        measures
+    threshold :
+        A particular threshold in which to calculate the performance
+        measures.
 
     Returns
     -------
 
-    precision: float
-
-    recall: float
-
-    specificity: float
-
-    accuracy: float
-
-    jaccard: float
-
-    f1_score: float
+    precision : float
+        P, AKA positive predictive value (PPV). It corresponds arithmetically
+        to ``tp/(tp+fp)``. In the case ``tp+fp == 0``, this function returns
+        zero for precision.
+
+    recall : float
+        R, AKA sensitivity, hit rate, or true positive rate (TPR). It
+        corresponds arithmetically to ``tp/(tp+fn)``. In the special case
+        where ``tp+fn == 0``, this function returns zero for recall.
+
+    specificity : float
+        S, AKA selectivity or true negative rate (TNR). It
+        corresponds arithmetically to ``tn/(tn+fp)``. In the special case
+        where ``tn+fp == 0``, this function returns zero for specificity.
+
+    accuracy : float
+        A, see `Accuracy
+        <https://en.wikipedia.org/wiki/Evaluation_of_binary_classifiers>`_. It
+        is the proportion of correct predictions (both true positives and true
+        negatives) among the total number of pixels examined. It corresponds
+        arithmetically to ``(tp+tn)/(tp+tn+fp+fn)``. This measure includes
+        both true-negatives and positives in the numerator, which makes it
+        sensitive to data or regions without annotations.
+
+    jaccard : float
+        J, see `Jaccard Index or Similarity
+        <https://en.wikipedia.org/wiki/Jaccard_index>`_. It corresponds
+        arithmetically to ``tp/(tp+fp+fn)``. In the special case where
+        ``tn+fp+fn == 0``, this function returns zero for the Jaccard index.
+        The Jaccard index depends on a TP-only numerator, similarly to the F1
+        score. For regions where there are no annotations, the Jaccard index
+        will always be zero, irrespective of the model output. Accuracy may be
+        a better proxy if one needs to consider the true absence of
+        annotations in a region as part of the measure.
+
+    f1_score : float
+        F1, see `F1-score <https://en.wikipedia.org/wiki/F1_score>`_. It
+        corresponds arithmetically to ``2*P*R/(P+R)`` or ``2*tp/(2*tp+fp+fn)``.
+        In the special case where ``P+R == (2*tp+fp+fn) == 0``, this function
+        returns zero for the F1-score. The F1 or Dice score depends on a
+        TP-only numerator, similarly to the Jaccard index. For regions where
+        there are no annotations, the F1-score will always be zero,
+        irrespective of the model output. Accuracy may be a better proxy if
+        one needs to consider the true absence of annotations in a region as
+        part of the measure.
     """
 
     tp_tensor, fp_tensor, tn_tensor, fn_tensor = posneg(pred, gt, threshold)
 
     # calc measures from scalars
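
Aside (not part of the commit): the actual computation is delegated to base_measures from ..utils.measure (see the imports at the top of this file), but the formulas spelled out in the docstring above reduce to the following arithmetic over the tp/fp/tn/fn counts:

    def measures_sketch(tp: int, fp: int, tn: int, fn: int):
        # each expression guards against an empty denominator, returning zero
        # as described in the docstring above
        precision = tp / (tp + fp) if (tp + fp) else 0.0
        recall = tp / (tp + fn) if (tp + fn) else 0.0
        specificity = tn / (tn + fp) if (tn + fp) else 0.0
        total = tp + tn + fp + fn
        accuracy = (tp + tn) / total if total else 0.0
        jaccard = tp / (tp + fp + fn) if (tp + fp + fn) else 0.0
        f1_score = (2 * tp) / (2 * tp + fp + fn) if (2 * tp + fp + fn) else 0.0
        return precision, recall, specificity, accuracy, jaccard, f1_score


    # example with hypothetical counts: 80 TP, 10 FP, 900 TN, 20 FN
    print(measures_sketch(80, 10, 900, 20))
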
@@ -117,12 +187,12 @@ def sample_measures_for_threshold(pred, gt, threshold):
 
 def run(
     dataset,
-    name,
-    predictions_folder,
-    output_folder=None,
-    f1_thresh=None,
-    eer_thresh=None,
-    steps=1000,
+    name: str,
+    predictions_folder: str,
+    output_folder: Optional[str | None] = None,
+    f1_thresh: Optional[float] = None,
+    eer_thresh: Optional[float] = None,
+    steps: Optional[int] = 1000,
 ):
     """Runs inference and calculates measures.
 
@@ -132,44 +202,46 @@ def run(
     dataset : py:class:`torch.utils.data.Dataset`
         a dataset to iterate on
 
-    name : str
+    name:
         the local name of this dataset (e.g. ``train``, or ``test``), to be
         used when saving measures files.
 
-    predictions_folder : str
+    predictions_folder:
         folder where predictions for the dataset images have been previously
         stored
 
-    output_folder : :py:class:`str`, Optional
+    output_folder:
         folder where to store results.
 
-    f1_thresh : :py:class:`float`, Optional
+    f1_thresh:
         This number should come from
         the training set or a separate validation set. Using a test set value
        may bias your analysis. This number is also used to print the a priori
        F1-score on the evaluated set.
 
-    eer_thresh : :py:class:`float`, Optional
+    eer_thresh:
        This number should come from
        the training set or a separate validation set. Using a test set value
        may bias your analysis. This number is used to print the a priori
        EER.
 
-    steps : :py:class:`float`, Optional
+    steps:
        number of threshold steps to consider when evaluating thresholds.
 
     Returns
     -------
 
-    f1_threshold : float
+    maxf1_threshold : float
        Threshold to achieve the highest possible F1-score for this dataset
 
-    eer_threshold : float
+    post_eer_threshold : float
        Threshold achieving Equal Error Rate for this dataset
     """
 
-    predictions_path = os.path.join(predictions_folder, name, "predictions.csv")
+    predictions_path = os.path.join(
+        predictions_folder, f"predictions_{name}", "predictions.csv"
+    )
 
     if not os.path.exists(predictions_path):
         predictions_path = predictions_folder
...
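
Aside (not part of the commit): a hypothetical usage sketch of run as documented above, estimating thresholds on a validation split and reusing them a priori on the test split (split names, folder paths and the dataset variable are assumptions for illustration):

    # `run` is the function documented above; `dataset` is assumed to be a
    # dict mapping split names to torch datasets
    maxf1_threshold, post_eer_threshold = run(
        dataset["validation"],
        "validation",
        predictions_folder="results/predictions",
        output_folder="results/evaluation",
    )

    # apply the validation-set thresholds a priori on the test split, so the
    # reported F1-score and EER are not tuned on test data
    run(
        dataset["test"],
        "test",
        predictions_folder="results/predictions",
        output_folder="results/evaluation",
        f1_thresh=maxf1_threshold,
        eer_thresh=post_eer_threshold,
    )
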
@@ -2,42 +2,69 @@
 #
 # SPDX-License-Identifier: GPL-3.0-or-later
 
+from typing import Union
+
 import click
 
 from clapper.click import ConfigCommand, ResourceOption, verbosity_option
 from clapper.logging import setup
 
+from ..data.datamodule import CachingDataModule
+from ..data.typing import DataLoader
+
 logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
 
 
-def _validate_threshold(t, dataset):
+def _validate_threshold(
+    threshold: Union[int, float, str], dataloader_dict: dict[str, DataLoader]
+):
     """Validates the user threshold selection.
 
-    Returns parsed threshold.
+    Parameters
+    ----------
+
+    threshold:
+        This number is used to define positives and negatives from
+        probability maps, and report F1-scores (a priori). It
+        should either come from the training set or a separate validation set
+        to avoid biasing the analysis. Optionally, if you provide a multi-set
+        dataset as input, this may also be the name of an existing set from
+        which the threshold will be estimated (highest F1-score) and then
+        applied to the subsequent sets. This number is also used to print
+        the test set F1-score a priori performance.
+
+    dataloader_dict:
+        Dictionary of ``set_name: dataloader``, where ``set_name`` is the name
+        of a dataset split and ``dataloader`` is the torch dataloader for that
+        split.
+
+    Returns
+    -------
+
+    The parsed threshold.
     """
 
-    if t is None:
+    if threshold is None:
         return 0.5
 
     try:
         # we try to convert it to float first
-        t = float(t)
-        if t < 0.0 or t > 1.0:
+        threshold = float(threshold)
+        if threshold < 0.0 or threshold > 1.0:
             raise ValueError("Float thresholds must be within range [0.0, 1.0]")
 
     except ValueError:
         # it is a bit of text - assert dataset with name is available
-        if not isinstance(dataset, dict):
+        if not isinstance(dataloader_dict, dict):
             raise ValueError(
                 "Threshold should be a floating-point number "
                 "if you provide only a single dataset for evaluation"
             )
 
-        if t not in dataset:
+        if threshold not in dataloader_dict:
             raise ValueError(
                 f"Text thresholds should match dataset names, "
-                f"but {t} is not available among the datasets provided ("
-                f"({', '.join(dataset.keys())})"
+                f"but {threshold} is not available among the datasets provided ("
+                f"({', '.join(dataloader_dict.keys())})"
             )
 
-    return t
+    return threshold
 
 
 @click.command(
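
Aside (not part of the commit): a simplified, self-contained reproduction of the accepted threshold forms, for illustration only (the real function above receives a dict of torch dataloaders and differs in how it reports errors):

    def validate_threshold_sketch(threshold, dataloader_dict):
        # None selects the default cut-off of 0.5
        if threshold is None:
            return 0.5
        try:
            value = float(threshold)
        except ValueError:
            # non-numeric text must name one of the available splits
            if threshold not in dataloader_dict:
                raise ValueError(f"unknown split name: {threshold}")
            return threshold
        if not 0.0 <= value <= 1.0:
            raise ValueError("Float thresholds must be within range [0.0, 1.0]")
        return value


    assert validate_threshold_sketch(None, {}) == 0.5
    assert validate_threshold_sketch("0.25", {}) == 0.25
    assert validate_threshold_sketch("validation", {"validation": None}) == "validation"
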
@@ -71,13 +98,9 @@ def _validate_threshold(t, dataset):
     cls=ResourceOption,
 )
 @click.option(
-    "--dataset",
+    "--datamodule",
     "-d",
-    help="A torch.utils.data.dataset.Dataset instance implementing a dataset "
-    "to be used for evaluation purposes, possibly including all pre-processing "
-    "pipelines required or, optionally, a dictionary mapping string keys to "
-    "torch.utils.data.dataset.Dataset instances. All keys that do not start "
-    "with an underscore (_) will be processed.",
+    help="A lightning data module containing the training and validation sets.",
     required=True,
     cls=ResourceOption,
 )
@@ -109,11 +132,11 @@ def _validate_threshold(t, dataset):
 )
 @verbosity_option(logger=logger, cls=ResourceOption, expose_value=False)
 def evaluate(
-    output_folder,
-    predictions_folder,
-    dataset,
-    threshold,
-    steps,
+    output_folder: str,
+    predictions_folder: str,
+    datamodule: CachingDataModule,
+    threshold: Union[int, float, str],
+    steps: int,
     **_,
 ) -> None:
     """Evaluates a CNN on a tuberculosis prediction task.
@@ -123,17 +146,22 @@ def evaluate(
 
     from ..engine.evaluator import run
 
-    threshold = _validate_threshold(threshold, dataset)
+    datamodule.prepare_data()
+    datamodule.setup(stage="test")
+    datamodule.set_chunk_size(1, 1)
 
-    if not isinstance(dataset, dict):
-        dataset = {"test": dataset}
+    dataloader = datamodule.test_dataloader()
+
+    threshold = _validate_threshold(threshold, dataloader)
 
     if isinstance(threshold, str):
         # first run evaluation for reference dataset
         logger.info(f"Evaluating threshold on '{threshold}' set")
         f1_threshold, eer_threshold = run(
-            dataset[threshold], threshold, predictions_folder, steps=steps
+            _, threshold, predictions_folder, steps=steps
         )
 
         if (f1_threshold is not None) and (eer_threshold is not None):
             logger.info(f"Set --f1_threshold={f1_threshold:.5f}")
             logger.info(f"Set --eer_threshold={eer_threshold:.5f}")
@@ -144,12 +172,11 @@ def evaluate(
     else:
         raise ValueError("Threshold value is neither an int nor a float")
 
-    # now run with the
-    for k, v in dataset.items():
+    for k, v in dataloader.items():
         if k.startswith("_"):
             logger.info(f"Skipping dataset '{k}' (not to be evaluated)")
             continue
-        logger.info(f"Analyzing '{k}' set...")
+        logger.info(f"Analyzing '{threshold}' set...")
         run(
             v,
             k,
...