Commit 1b1382bb authored by Daniel CARRON, committed by André Anjos
Evaluation script saves more plots, combines results

parent fc3551d9
Merge request !6: Making use of LightningDataModule and simplification of data loading
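In outline, this commit moves result saving out of run() and into the CLI: run() now returns its per-subset measures and score figure, and the command combines them into single documents. A minimal sketch of that pattern (the helper evaluate_all and its arguments are illustrative, not part of the commit; run and PdfPages appear in the diff below):

# Illustrative sketch only: per-subset evaluation returns its results,
# and the caller bundles all score figures into one multi-page PDF.
from matplotlib.backends.backend_pdf import PdfPages

def evaluate_all(subsets, predictions_folder, output_pdf):
    collected = {}
    for name in subsets:
        # run() is the function changed below; it now returns four values
        pred_data, fig_score, maxf1, post_eer = run(name, predictions_folder)
        collected[name] = (pred_data, fig_score, maxf1, post_eer)
    with PdfPages(output_pdf) as pdf:
        for _, fig_score, _, _ in collected.values():
            pdf.savefig(fig_score)  # one page per subset
    return collected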
@@ -186,10 +186,8 @@ def sample_measures_for_threshold(

 def run(
-    dataset,
     name: str,
     predictions_folder: str,
-    output_folder: Optional[str | None] = None,
     f1_thresh: Optional[float] = None,
     eer_thresh: Optional[float] = None,
     steps: Optional[int] = 1000,
@@ -199,9 +197,6 @@ def run(
     Parameters
     ---------
-    dataset : py:class:`torch.utils.data.Dataset`
-        a dataset to iterate on
-
     name:
         the local name of this dataset (e.g. ``train``, or ``test``), to be
         used when saving measures files.
@@ -210,9 +205,6 @@ def run(
         folder where predictions for the dataset images has been previously
         stored
-    output_folder:
-        folder where to store results.
-
     f1_thresh:
         This number should come from
         the training set or a separate validation set. Using a test set value
@@ -238,9 +230,7 @@ def run(
     post_eer_threshold : float
         Threshold achieving Equal Error Rate for this dataset
     """
-    predictions_path = os.path.join(
-        predictions_folder, f"predictions_{name}", "predictions.csv"
-    )
+    predictions_path = os.path.join(predictions_folder, f"{name}.csv")

     if not os.path.exists(predictions_path):
         predictions_path = predictions_folder
@@ -298,12 +288,12 @@ def run(
     )
     data_df = data_df.set_index("index")

-    # Save evaluation csv
-    if output_folder is not None:
-        fullpath = os.path.join(output_folder, f"{name}.csv")
-        logger.info(f"Saving {fullpath}...")
-        os.makedirs(os.path.dirname(fullpath), exist_ok=True)
-        data_df.to_csv(fullpath)
+    """# Save evaluation csv
+    if output_folder is not None:
+        fullpath = os.path.join(output_folder, f"{name}.csv")
+        logger.info(f"Saving {fullpath}...")
+        os.makedirs(os.path.dirname(fullpath), exist_ok=True)
+        data_df.to_csv(fullpath)"""

     # Find max F1 score
     f1_scores = numpy.asarray(data_df["f1_score"])
@@ -328,42 +318,38 @@ def run(
         f"threshold {post_eer_threshold:.3f} (chosen *a posteriori*)"
     )

-    # Save score table
-    if output_folder is not None:
-        fig, axes = plt.subplots(1)
-        fig.tight_layout(pad=3.0)
-
-        # Names and bounds
-        axes.set_xlabel("Score")
-        axes.set_ylabel("Normalized counts")
-        axes.set_xlim(0.0, 1.0)
-
-        neg_weights = numpy.ones_like(neg_gt["likelihood"]) / len(
-            pred_data["likelihood"]
-        )
-        pos_weights = numpy.ones_like(pos_gt["likelihood"]) / len(
-            pred_data["likelihood"]
-        )
-
-        axes.hist(
-            [neg_gt["likelihood"], pos_gt["likelihood"]],
-            weights=[neg_weights, pos_weights],
-            bins=100,
-            color=["tab:blue", "tab:orange"],
-            label=["Negatives", "Positives"],
-        )
-        axes.legend(prop={"size": 10}, loc="upper center")
-        axes.set_title(f"Score table for {name} subset")
-
-        # we should see some of axes 1 axes
-        axes.spines["right"].set_visible(False)
-        axes.spines["top"].set_visible(False)
-        axes.spines["left"].set_position(("data", -0.015))
-
-        fullpath = os.path.join(output_folder, f"{name}_score_table.pdf")
-        fig.savefig(fullpath)
-
-    if f1_thresh is not None and eer_thresh is not None:
+    # Generate scores fig
+    fig_score, axes = plt.subplots(1)
+    fig_score.tight_layout(pad=3.0)
+
+    # Names and bounds
+    axes.set_xlabel("Score")
+    axes.set_ylabel("Normalized counts")
+    axes.set_xlim(0.0, 1.0)
+
+    neg_weights = numpy.ones_like(neg_gt["likelihood"]) / len(
+        pred_data["likelihood"]
+    )
+    pos_weights = numpy.ones_like(pos_gt["likelihood"]) / len(
+        pred_data["likelihood"]
+    )
+
+    axes.hist(
+        [neg_gt["likelihood"], pos_gt["likelihood"]],
+        weights=[neg_weights, pos_weights],
+        bins=100,
+        color=["tab:blue", "tab:orange"],
+        label=["Negatives", "Positives"],
+    )
+    axes.legend(prop={"size": 10}, loc="upper center")
+    axes.set_title(f"Score table for {name} subset")
+
+    # we should see some of axes 1 axes
+    axes.spines["right"].set_visible(False)
+    axes.spines["top"].set_visible(False)
+    axes.spines["left"].set_position(("data", -0.015))
+
+    """if f1_thresh is not None and eer_thresh is not None:
         # get the closest possible threshold we have
         index = int(round(steps * f1_thresh))
         f1_a_priori = data_df["f1_score"][index]
@@ -375,6 +361,6 @@ def run(
     )

     # Print the a priori EER threshold
-    logger.info(f"Equal error rate (chosen *a priori*) {eer_thresh:.3f}")
+    logger.info(f"Equal error rate (chosen *a priori*) {eer_thresh:.3f}")"""

-    return maxf1_threshold, post_eer_threshold
+    return pred_data, fig_score, maxf1_threshold, post_eer_threshold
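With the hunks above applied, run() no longer writes CSVs or per-subset PDFs itself; it hands the measures table and score figure back to the caller. A hedged usage sketch of the new return contract (the subset name and file paths are illustrative):

# Illustrative call against the new signature of run():
pred_data, fig_score, maxf1_threshold, post_eer_threshold = run(
    name="test",                       # subset to evaluate (illustrative)
    predictions_folder="predictions",  # now expected to contain test.csv
)
print(f"Maximum F1-score threshold: {maxf1_threshold:.3f}")
print(f"A-posteriori EER threshold: {post_eer_threshold:.3f}")
fig_score.savefig("test_scores.pdf")   # saving is now the caller's job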
@@ -2,15 +2,21 @@
 #
 # SPDX-License-Identifier: GPL-3.0-or-later

+import os
+
+from collections import defaultdict
 from typing import Union

 import click

 from clapper.click import ConfigCommand, ResourceOption, verbosity_option
 from clapper.logging import setup
+from matplotlib.backends.backend_pdf import PdfPages

 from ..data.datamodule import CachingDataModule
 from ..data.typing import DataLoader
+from ..utils.plot import precision_recall_f1iso, roc_curve
+from ..utils.table import performance_table

 logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
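The new PdfPages import is what lets the command below bundle several matplotlib figures into one file. A self-contained sketch of that pattern (the subset names are illustrative):

# Minimal PdfPages pattern: every savefig() call appends one page.
import matplotlib.pyplot as plt
from matplotlib.backends.backend_pdf import PdfPages

with PdfPages("scores.pdf") as pdf:
    for name in ("train", "validation", "test"):
        fig, axes = plt.subplots(1)
        axes.set_title(f"Score table for {name} subset")
        pdf.savefig(fig)
        plt.close(fig)  # release the figure once it is written out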
@@ -117,7 +123,7 @@ def _validate_threshold(
     "the test set F1-score a priori performance",
     default=None,
     show_default=False,
-    required=False,
+    required=True,
     cls=ResourceOption,
 )
 @click.option(
@@ -159,8 +165,10 @@ def evaluate(
     if isinstance(threshold, str):
         # first run evaluation for reference dataset
         logger.info(f"Evaluating threshold on '{threshold}' set")
-        f1_threshold, eer_threshold = run(
-            _, threshold, predictions_folder, steps=steps
-        )
+        _, _, f1_threshold, eer_threshold = run(
+            name=threshold,
+            predictions_folder=predictions_folder,
+            steps=steps,
+        )

         if (f1_threshold is not None) and (eer_threshold is not None):
...@@ -173,17 +181,72 @@ def evaluate( ...@@ -173,17 +181,72 @@ def evaluate(
else: else:
raise ValueError("Threshold value is neither an int nor a float") raise ValueError("Threshold value is neither an int nor a float")
for k, v in dataloader.items(): results_dict = { # type: ignore
"pred_data": defaultdict(dict),
"fig_score": defaultdict(dict),
"maxf1_threshold": defaultdict(dict),
"post_eer_threshold": defaultdict(dict),
}
for k in dataloader.keys():
if k.startswith("_"): if k.startswith("_"):
logger.info(f"Skipping dataset '{k}' (not to be evaluated)") logger.info(f"Skipping dataset '{k}' (not to be evaluated)")
continue continue
logger.info(f"Analyzing '{k}' set...") logger.info(f"Analyzing '{k}' set...")
run( pred_data, fig_score, maxf1_threshold, post_eer_threshold = run(
v,
k, k,
predictions_folder, predictions_folder,
output_folder,
f1_thresh=f1_threshold, f1_thresh=f1_threshold,
eer_thresh=eer_threshold, eer_thresh=eer_threshold,
steps=steps, steps=steps,
) )
results_dict["pred_data"][k] = pred_data
results_dict["fig_score"][k] = fig_score
results_dict["maxf1_threshold"][k] = maxf1_threshold
results_dict["post_eer_threshold"][k] = post_eer_threshold
if output_folder is not None:
output_scores = os.path.join(output_folder, "scores.pdf")
if output_scores is not None:
output_scores = os.path.realpath(output_scores)
logger.info(f"Creating and saving scores at {output_scores}...")
os.makedirs(os.path.dirname(output_scores), exist_ok=True)
score_pdf = PdfPages(output_scores)
for fig in results_dict["fig_score"].values():
score_pdf.savefig(fig)
score_pdf.close()
data = {}
for subset_name in dataloader.keys():
data[subset_name] = {
"df": results_dict["pred_data"][subset_name],
"threshold": results_dict["post_eer_threshold"][ # type: ignore
threshold
].item(),
}
output_figure = os.path.join(output_folder, "plots.pdf")
if output_figure is not None:
output_figure = os.path.realpath(output_figure)
logger.info(f"Creating and saving plots at {output_figure}...")
os.makedirs(os.path.dirname(output_figure), exist_ok=True)
pdf = PdfPages(output_figure)
pdf.savefig(precision_recall_f1iso(data))
pdf.savefig(roc_curve(data))
pdf.close()
output_table = os.path.join(output_folder, "table.txt")
logger.info("Tabulating performance summary...")
table = performance_table(data, "rst")
click.echo(table)
if output_table is not None:
output_table = os.path.realpath(output_table)
logger.info(f"Saving table at {output_table}...")
os.makedirs(os.path.dirname(output_table), exist_ok=True)
with open(output_table, "w") as f:
f.write(table)
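For reference, the data mapping assembled in the hunk above has one entry per subset, which is the shape handed to precision_recall_f1iso, roc_curve and performance_table; note that every subset reuses the a-posteriori EER threshold of the reference set named by threshold. A sketch with illustrative values:

# Illustrative shape of the `data` mapping built above:
data = {
    "train": {
        "df": results_dict["pred_data"]["train"],  # per-threshold measures
        "threshold": 0.42,  # reference-set EER threshold (illustrative value)
    },
    "test": {
        "df": results_dict["pred_data"]["test"],
        "threshold": 0.42,  # same reference threshold for every subset
    },
}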