Commit 40985094 authored by André Anjos

[script.compare] Implement system performance tabulation

parent a6588733
1 merge request: !12 Streamlining
@@ -9,8 +9,10 @@ from bob.extension.scripts.click_helper import (
)
import pandas
import tabulate
from ..utils.plot import precision_recall_f1iso
from ..utils.table import performance_table
import logging
logger = logging.getLogger(__name__)
@@ -44,7 +46,7 @@ def _validate_threshold(t, dataset):
return t
def _load_and_plot(data, threshold=None):
def _load(data, threshold=None):
"""Plots comparison chart of all evaluated models
Parameters
@@ -55,20 +57,26 @@ def _load_and_plot(data, threshold=None):
paths to ``metrics.csv`` style files.
threshold : :py:class:`float`, :py:class:`str`, Optional
A value indicating which threshold to choose for plotting a "F1-score"
(black) dot on the various curves. If set to ``None``, then plot the
maximum F1-score on that curve. If set to a floating-point value, then
plot the F1-score that is obtained on that particular threshold. If
set to a string, it should match one of the keys in ``data``. It then
first calculate the threshold reaching the maximum F1-score on that
particular dataset and then applies that threshold to all other sets.
A value indicating which threshold to choose for selecting an "F1-score".
If set to ``None``, then use the maximum F1-score on that metrics file.
If set to a floating-point value, then use the F1-score that is
obtained at that particular threshold. If set to a string, it should
match one of the keys in ``data``. It then first calculates the
threshold reaching the maximum F1-score on that particular dataset and
then applies that threshold to all other sets.
Returns
-------
figure : matplotlib.figure.Figure
A figure, with all systems combined into a single plot.
data : dict
A dict in which keys are the names of the systems and the values are
dictionaries that contain two keys:
* ``df``: A :py:class:`pandas.DataFrame` with the loaded metrics data
* ``threshold``: A threshold to be used for summarization, depending on
the ``threshold`` parameter set on the input
"""
@@ -78,8 +86,8 @@ def _load_and_plot(data, threshold=None):
metrics_path = data[threshold]
df = pandas.read_csv(metrics_path)
maxf1 = df["f1_score"].max()
use_threshold = df["threshold"][df["f1_score"].idxmax()]
maxf1 = df.f1_score.max()
use_threshold = df.threshold[df.f1_score.idxmax()]
logger.info(f"Dataset '*': threshold = {use_threshold:.3f}'")
elif isinstance(threshold, float):
@@ -91,20 +99,19 @@ def _load_and_plot(data, threshold=None):
thresholds = []
# loads all data
retval = {}
for name, metrics_path in data.items():
logger.info(f"Loading metrics from {metrics_path}...")
df = pandas.read_csv(metrics_path)
if threshold is None:
use_threshold = df["threshold"][df["f1_score"].idxmax()]
use_threshold = df.threshold[df.f1_score.idxmax()]
logger.info(f"Dataset '{name}': threshold = {use_threshold:.3f}'")
names.append(name)
dfs.append(df)
thresholds.append(use_threshold)
retval[name] = dict(df=df, threshold=use_threshold)
return precision_recall_f1iso(names, dfs, thresholds, confidence=True)
return retval
@click.command(
@@ -121,15 +128,32 @@ def _load_and_plot(data, threshold=None):
nargs=-1,
)
@click.option(
"--output",
"-o",
help="Path where write the output figure (PDF format)",
"--output-figure",
"-f",
help="Path where write the output figure (any extension supported by "
"matplotlib is possible). If not provided, does not produce a figure.",
required=False,
default=None,
type=click.Path(dir_okay=False, file_okay=True),
)
@click.option(
"--table-format",
"-T",
help="The format to use for the comparison table",
show_default=True,
required=True,
default="comparison.pdf",
type=click.Path(),
default="rst",
type=click.Choice(tabulate.tabulate_formats),
)
@click.option(
"--output-table",
"-u",
help="Path where write the output table. If not provided, does not write "
"write a table to file, only to stdout.",
required=False,
default=None,
type=click.Path(dir_okay=False, file_okay=True),
)
@click.option(
"--threshold",
"-t",
@@ -145,7 +169,8 @@ def _load_and_plot(data, threshold=None):
required=False,
)
@verbosity_option()
def compare(label_path, output, threshold, **kwargs):
def compare(label_path, output_figure, table_format, output_table, threshold,
**kwargs):
"""Compares multiple systems together"""
# hack to get a dictionary from arguments passed to input
@@ -156,9 +181,18 @@ def compare(label_path, output, threshold, **kwargs):
threshold = _validate_threshold(threshold, data)
fig = _load_and_plot(data, threshold=threshold)
logger.info(f"Saving plot at {output}")
fig.savefig(output)
# TODO: print table with all results
pass
# load all data metrics
data = _load(data, threshold=threshold)
if output_figure is not None:
logger.info(f"Creating and saving plot at {output_figure}...")
fig = precision_recall_f1iso(data, confidence=True)
fig.savefig(output_figure)
logger.info("Tabulating performance summary...")
table = performance_table(data, table_format)
click.echo(table)
if output_table is not None:
logger.info(f"Saving table at {output_table}...")
with open(output_table, "wt") as f:
f.write(table)
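The new ``performance_table`` helper from ``..utils.table`` is not part of this excerpt. Purely as a sketch, and assuming each loaded dataframe exposes the same ``precision``, ``recall`` and ``f1_score`` columns used elsewhere in this change, a tabulation compatible with the call above could look like:

import tabulate

def performance_table(data, fmt):
    """Sketch only: one row per system, at that system's chosen threshold"""
    headers = ["System", "Threshold", "Precision", "Recall", "F1-score"]
    rows = []
    for name, value in data.items():
        df = value["df"]
        # same index arithmetic used by precision_recall_f1iso() further down
        index = min(int(round(len(df) * value["threshold"])), len(df) - 1)
        rows.append([name, value["threshold"], df.precision[index],
                     df.recall[index], df.f1_score[index]])
    return tabulate.tabulate(rows, headers, tablefmt=fmt, floatfmt=".3f")

The actual implementation may differ; the point is only that the dictionary returned by ``_load`` already carries everything needed for both plotting and tabulation.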
@@ -412,7 +412,11 @@ def experiment(
continue
systems += [f"{k} (2nd. annot.)", os.path.join(analysis_folder, k,
"metrics-second-annotator.csv")]
output_pdf = os.path.join(output_folder, "comparison.pdf")
ctx.invoke(compare, label_path=systems, output=output_pdf, verbose=verbose)
output_figure = os.path.join(output_folder, "comparison.pdf")
output_table = os.path.join(output_folder, "comparison.rst")
ctx.invoke(compare, label_path=systems, output_figure=output_figure,
output_table=output_table, verbose=verbose)
logger.info("Ended comparison, and the experiment - bye.")
@@ -158,6 +158,7 @@ def _check_experiment_stare(overlay):
# check outcomes of the comparison phase
assert os.path.exists(os.path.join(output_folder, "comparison.pdf"))
assert os.path.exists(os.path.join(output_folder, "comparison.rst"))
keywords = {
r"^Started training$": 1,
@@ -175,6 +176,9 @@ def _check_experiment_stare(overlay):
r"^Ended evaluation$": 1,
r"^Started comparison$": 1,
r"^Loading metrics from": 4,
r"^Creating and saving plot at": 1,
r"^Tabulating performance summary...": 1,
r"^Saving table at": 1,
r"^Ended comparison.*$": 1,
}
buf.seek(0)
@@ -399,14 +403,20 @@ def _check_compare(runner):
os.path.join(output_folder, "metrics.csv"),
"test (2nd. human)",
os.path.join(output_folder, "metrics-second-annotator.csv"),
"--output-figure=comparison.pdf",
"--output-table=comparison.rst",
],
)
_assert_exit_0(result)
assert os.path.exists("comparison.pdf")
assert os.path.exists("comparison.rst")
keywords = {
r"^Loading metrics from": 2,
r"^Creating and saving plot at": 1,
r"^Tabulating performance summary...": 1,
r"^Saving table at": 1,
}
buf.seek(0)
logging_output = buf.read()
......
@@ -97,35 +97,41 @@ def _precision_recall_canvas(title=None):
plt.tight_layout()
def precision_recall_f1iso(label, df, threshold, confidence=True):
def precision_recall_f1iso(data, confidence=True):
"""Creates a precision-recall plot with confidence intervals
This function creates and returns a Matplotlib figure with a
precision-recall plot containing shaded confidence intervals. The plot
will be annotated with F1-score iso-lines (in which the F1-score maintains
the same value).
precision-recall plot containing shaded confidence intervals (standard
deviation on the precision-recall measurements). The plot will be
annotated with F1-score iso-lines (in which the F1-score maintains the same
value).
This function specially supports "second-annotator" entries by plotting a
line showing the comparison between the default annotator being analyzed
and a second "opinion". Second annotator dataframes contain a single
entry (threshold=0.5), given the nature of the binary map comparisons.
Parameters
----------
label : :py:class:`list`
A list of names to be associated to each line
data : dict
A dictionary in which keys are strings defining plot labels and values
are dictionaries with two entries:
df : :py:class:`pandas.DataFrame`
A dataframe that is produced by our evaluator engine, indexed by
integer "thresholds", containing the following columns: ``threshold``
(sorted ascending), ``precision``, ``recall``, ``pr_upper`` (upper
precision bounds), ``pr_lower`` (lower precision bounds), ``re_upper``
(upper recall bounds), ``re_lower`` (lower recall bounds).
* ``df``: :py:class:`pandas.DataFrame`
Dataframes with a single entry are treated specially as these are
considered "second-annotator" performances. A single dot and a line
showing the variability is drawn in these cases.
A dataframe that is produced by our evaluator engine, indexed by
integer "thresholds", containing the following columns: ``threshold``
(sorted ascending), ``precision``, ``recall``, ``pr_upper`` (upper
precision bounds), ``pr_lower`` (lower precision bounds),
``re_upper`` (upper recall bounds), ``re_lower`` (lower recall
bounds).
threshold : :py:class:`list`
A list of thresholds to graph with a dot for each set. Specific
threshold values do not affect "second-annotator" dataframes.
* ``threshold``: :py:class:`float`
A threshold to graph with a dot for each set. Specific
threshold values do not affect "second-annotator" dataframes.
confidence : :py:class:`bool`, Optional
If set, draw confidence intervals for each line, using ``*_upper`` and
@@ -160,32 +166,35 @@ def precision_recall_f1iso(label, df, threshold, confidence=True):
legend = []
for kn, kdf, kt in zip(label, df, threshold):
for name, value in data.items():
df = value["df"]
threshold = value["threshold"]
# plots only from the point where recall reaches its maximum,
# otherwise, we don't see a curve...
max_recall = kdf["recall"].idxmax()
pi = kdf.precision[max_recall:]
ri = kdf.recall[max_recall:]
max_recall = df["recall"].idxmax()
pi = df.precision[max_recall:]
ri = df.recall[max_recall:]
valid = (pi + ri) > 0
f1 = 2 * (pi[valid] * ri[valid]) / (pi[valid] + ri[valid])
# optimal point along the curve
bins = len(kdf)
index = int(round(bins*kt))
index = min(index, len(kdf)-1) #avoids out of range indexing
bins = len(df)
index = int(round(bins*threshold))
index = min(index, len(df)-1) #avoids out of range indexing
# plots Recall/Precision as threshold changes
label = f"{kn} (F1={kdf.f1_score[index]:.4f})"
label = f"{name} (F1={df.f1_score[index]:.4f})"
color = next(colorcycler)
if len(kdf) == 1:
if len(df) == 1:
# plot black dot for F1-score at select threshold
marker, = axes.plot(kdf.recall[index], kdf.precision[index],
marker, = axes.plot(df.recall[index], df.precision[index],
marker="*", markersize=6, color=color, alpha=0.8,
linestyle="None")
line, = axes.plot(kdf.recall[index], kdf.precision[index],
line, = axes.plot(df.recall[index], df.precision[index],
linestyle="None", color=color, alpha=0.2)
legend.append(([marker, line], label))
else:
@@ -193,17 +202,17 @@ def precision_recall_f1iso(label, df, threshold, confidence=True):
style = next(linecycler)
line, = axes.plot(ri[pi > 0], pi[pi > 0], color=color,
linestyle=style)
marker, = axes.plot(kdf.recall[index], kdf.precision[index],
marker, = axes.plot(df.recall[index], df.precision[index],
marker="o", linestyle=style, markersize=4,
color=color, alpha=0.8)
legend.append(([marker, line], label))
if confidence:
pui = kdf.pr_upper[max_recall:]
pli = kdf.pr_lower[max_recall:]
rui = kdf.re_upper[max_recall:]
rli = kdf.re_lower[max_recall:]
pui = df.pr_upper[max_recall:]
pli = df.pr_lower[max_recall:]
rui = df.re_upper[max_recall:]
rli = df.re_lower[max_recall:]
# Plot confidence
# Upper bound
@@ -212,7 +221,7 @@ def precision_recall_f1iso(label, df, threshold, confidence=True):
vert_y = numpy.concatenate((pui[pui > 0], pli[pli > 0][::-1]))
# hacky workaround to plot 2nd human
if len(kdf) == 1: #binary system, very likely
if len(df) == 1: #binary system, very likely
logger.warning("Found 2nd human annotator - patching...")
p, = axes.plot(vert_x, vert_y, color=color, alpha=0.1, lw=3)
else:
......
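For reference, the new calling convention of ``precision_recall_f1iso`` (mirroring what the refactored ``compare`` command does above) becomes something along these lines; the file names and the 0.5 thresholds are merely illustrative:

import pandas
from bob.ip.binseg.utils.plot import precision_recall_f1iso

data = {
    "test": dict(df=pandas.read_csv("metrics.csv"), threshold=0.5),
    "test (2nd. human)": dict(
        df=pandas.read_csv("metrics-second-annotator.csv"), threshold=0.5
    ),
}
fig = precision_recall_f1iso(data, confidence=True)  # one curve per key
fig.savefig("comparison.pdf")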
@@ -9,7 +9,7 @@ from torch.nn.modules.module import _addindent
def summary(model):
"""Counts the number of paramters in each model layer
"""Counts the number of parameters in each model layer
Parameters
----------
......
@@ -88,6 +88,7 @@ Toolbox
bob.ip.binseg.utils.model_serialization
bob.ip.binseg.utils.model_zoo
bob.ip.binseg.utils.plot
bob.ip.binseg.utils.table
bob.ip.binseg.utils.summary
......
@@ -7,6 +7,7 @@
.. _installation: https://www.idiap.ch/software/bob/install
.. _mailing list: https://www.idiap.ch/software/bob/discuss
.. _pytorch: https://pytorch.org
.. _tabulate: https://pypi.org/project/tabulate/
.. _our paper: https://arxiv.org/abs/1909.03856
.. Raw data websites
......