From 40985094e9251eec30d63f495a19de2e5f2e7530 Mon Sep 17 00:00:00 2001
From: Andre Anjos <andre.dos.anjos@gmail.com>
Date: Fri, 24 Apr 2020 15:53:16 +0200
Subject: [PATCH] [script.compare] Implement system performance tabulation

---
 bob/ip/binseg/script/compare.py    | 94 ++++++++++++++++++++----------
 bob/ip/binseg/script/experiment.py |  8 ++-
 bob/ip/binseg/test/test_cli.py     | 10 ++++
 bob/ip/binseg/utils/plot.py        | 79 ++++++++++++++-----------
 bob/ip/binseg/utils/summary.py     |  2 +-
 doc/api.rst                        |  1 +
 doc/links.rst                      |  1 +
 7 files changed, 127 insertions(+), 68 deletions(-)

diff --git a/bob/ip/binseg/script/compare.py b/bob/ip/binseg/script/compare.py
index e27b6296..c3b8d715 100644
--- a/bob/ip/binseg/script/compare.py
+++ b/bob/ip/binseg/script/compare.py
@@ -9,8 +9,10 @@ from bob.extension.scripts.click_helper import (
 )
 
 import pandas
+import tabulate
 
 from ..utils.plot import precision_recall_f1iso
+from ..utils.table import performance_table
 
 import logging
 logger = logging.getLogger(__name__)
@@ -44,7 +46,7 @@ def _validate_threshold(t, dataset):
     return t
 
 
-def _load_and_plot(data, threshold=None):
-    """Plots comparison chart of all evaluated models
+def _load(data, threshold=None):
+    """Loads performance metrics for all evaluated models
 
     Parameters
     ----------
 
     data : dict
         A dict in which keys are the names of the systems and the values are
         paths to ``metrics.csv`` style files.
 
     threshold : :py:class:`float`, :py:class:`str`, Optional
-        A value indicating which threshold to choose for plotting a "F1-score"
-        (black) dot on the various curves. If set to ``None``, then plot the
-        maximum F1-score on that curve. If set to a floating-point value, then
-        plot the F1-score that is obtained on that particular threshold. If
-        set to a string, it should match one of the keys in ``data``. It then
-        first calculate the threshold reaching the maximum F1-score on that
-        particular dataset and then applies that threshold to all other sets.
+        A value indicating which threshold to choose for selecting an
+        "F1-score". If set to ``None``, then use the maximum F1-score on that
+        metrics file. If set to a floating-point value, then use the F1-score
+        that is obtained on that particular threshold. If set to a string, it
+        should match one of the keys in ``data``. It then first calculates
+        the threshold reaching the maximum F1-score on that particular dataset
+        and then applies that threshold to all other sets.
 
 
     Returns
     -------
 
-    figure : matplotlib.figure.Figure
-        A figure, with all systems combined into a single plot.
+    data : dict
+        A dict in which keys are the names of the systems and the values are
+        dictionaries that contain two keys:
+
+        * ``df``: A :py:class:`pandas.DataFrame` with the metrics data loaded
+          from the ``metrics.csv`` file
+        * ``threshold``: A threshold to be used for summarization, depending
+          on the ``threshold`` parameter set on the input
 
     """
 
@@ -78,8 +86,8 @@ def _load_and_plot(data, threshold=None):
         metrics_path = data[threshold]
         df = pandas.read_csv(metrics_path)
 
-        maxf1 = df["f1_score"].max()
-        use_threshold = df["threshold"][df["f1_score"].idxmax()]
+        maxf1 = df.f1_score.max()
+        use_threshold = df.threshold[df.f1_score.idxmax()]
         logger.info(f"Dataset '*': threshold = {use_threshold:.3f}'")
 
     elif isinstance(threshold, float):
@@ -91,20 +99,19 @@ def _load_and_plot(data, threshold=None):
     thresholds = []
 
     # loads all data
+    retval = {}
     for name, metrics_path in data.items():
 
         logger.info(f"Loading metrics from {metrics_path}...")
         df = pandas.read_csv(metrics_path)
 
         if threshold is None:
-            use_threshold = df["threshold"][df["f1_score"].idxmax()]
+            use_threshold = df.threshold[df.f1_score.idxmax()]
             logger.info(f"Dataset '{name}': threshold = {use_threshold:.3f}'")
 
-        names.append(name)
-        dfs.append(df)
-        thresholds.append(use_threshold)
+        retval[name] = dict(df=df, threshold=use_threshold)
 
-    return precision_recall_f1iso(names, dfs, thresholds, confidence=True)
+    return retval
 
 
 @click.command(
@@ -121,15 +128,32 @@ def _load_and_plot(data, threshold=None):
     nargs=-1,
 )
 @click.option(
-    "--output",
-    "-o",
-    help="Path where write the output figure (PDF format)",
+    "--output-figure",
+    "-f",
+    help="Path where to write the output figure (any extension supported by "
+    "matplotlib is possible). If not provided, does not produce a figure.",
+    required=False,
+    default=None,
+    type=click.Path(dir_okay=False, file_okay=True),
+)
+@click.option(
+    "--table-format",
+    "-T",
+    help="The format to use for the comparison table",
     show_default=True,
     required=True,
-    default="comparison.pdf",
-    type=click.Path(),
+    default="rst",
+    type=click.Choice(tabulate.tabulate_formats),
+)
+@click.option(
+    "--output-table",
+    "-u",
+    help="Path where to write the output table. If not provided, does not "
+    "write a table to file, only to stdout.",
+    required=False,
+    default=None,
+    type=click.Path(dir_okay=False, file_okay=True),
 )
-
 @click.option(
     "--threshold",
     "-t",
@@ -145,7 +169,8 @@ def _load_and_plot(data, threshold=None):
     required=False,
 )
 @verbosity_option()
-def compare(label_path, output, threshold, **kwargs):
+def compare(label_path, output_figure, table_format, output_table, threshold,
+        **kwargs):
     """Compares multiple systems together"""
 
     # hack to get a dictionary from arguments passed to input
@@ -156,9 +181,18 @@ def compare(label_path, output, threshold, **kwargs):
 
     threshold = _validate_threshold(threshold, data)
 
-    fig = _load_and_plot(data, threshold=threshold)
-    logger.info(f"Saving plot at {output}")
-    fig.savefig(output)
-
-    # TODO: print table with all results
-    pass
+    # load all data metrics
+    data = _load(data, threshold=threshold)
+
+    if output_figure is not None:
+        logger.info(f"Creating and saving plot at {output_figure}...")
+        fig = precision_recall_f1iso(data, confidence=True)
+        fig.savefig(output_figure)
+
+    logger.info("Tabulating performance summary...")
+    table = performance_table(data, table_format)
+    click.echo(table)
+    if output_table is not None:
+        logger.info(f"Saving table at {output_table}...")
+        with open(output_table, "wt") as f:
+            f.write(table)
diff --git a/bob/ip/binseg/script/experiment.py b/bob/ip/binseg/script/experiment.py
index 54580368..e15f8588 100644
--- a/bob/ip/binseg/script/experiment.py
+++ b/bob/ip/binseg/script/experiment.py
@@ -412,7 +412,11 @@ def experiment(
             continue
         systems += [f"{k} (2nd. annot.)",
                 os.path.join(analysis_folder, k, "metrics-second-annotator.csv")]
-    output_pdf = os.path.join(output_folder, "comparison.pdf")
-    ctx.invoke(compare, label_path=systems, output=output_pdf, verbose=verbose)
+
+    output_figure = os.path.join(output_folder, "comparison.pdf")
+    output_table = os.path.join(output_folder, "comparison.rst")
+
+    ctx.invoke(compare, label_path=systems, output_figure=output_figure,
+            output_table=output_table, verbose=verbose)
 
     logger.info("Ended comparison, and the experiment - bye.")
diff --git a/bob/ip/binseg/test/test_cli.py b/bob/ip/binseg/test/test_cli.py
index 6e2d8be2..8acd7f87 100644
--- a/bob/ip/binseg/test/test_cli.py
+++ b/bob/ip/binseg/test/test_cli.py
@@ -158,6 +158,7 @@ def _check_experiment_stare(overlay):
 
     # check outcomes of the comparison phase
     assert os.path.exists(os.path.join(output_folder, "comparison.pdf"))
+    assert os.path.exists(os.path.join(output_folder, "comparison.rst"))
 
     keywords = {
         r"^Started training$": 1,
@@ -175,6 +176,9 @@ def _check_experiment_stare(overlay):
         r"^Ended evaluation$": 1,
         r"^Started comparison$": 1,
         r"^Loading metrics from": 4,
+        r"^Creating and saving plot at": 1,
+        r"^Tabulating performance summary...": 1,
+        r"^Saving table at": 1,
         r"^Ended comparison.*$": 1,
     }
     buf.seek(0)
@@ -399,14 +403,20 @@ def _check_compare(runner):
             os.path.join(output_folder, "metrics.csv"),
human)", os.path.join(output_folder, "metrics-second-annotator.csv"), + "--output-figure=comparison.pdf", + "--output-table=comparison.rst", ], ) _assert_exit_0(result) assert os.path.exists("comparison.pdf") + assert os.path.exists("comparison.rst") keywords = { r"^Loading metrics from": 2, + r"^Creating and saving plot at": 1, + r"^Tabulating performance summary...": 1, + r"^Saving table at": 1, } buf.seek(0) logging_output = buf.read() diff --git a/bob/ip/binseg/utils/plot.py b/bob/ip/binseg/utils/plot.py index 0bbea34a..9be1b0ed 100644 --- a/bob/ip/binseg/utils/plot.py +++ b/bob/ip/binseg/utils/plot.py @@ -97,35 +97,41 @@ def _precision_recall_canvas(title=None): plt.tight_layout() -def precision_recall_f1iso(label, df, threshold, confidence=True): +def precision_recall_f1iso(data, confidence=True): """Creates a precision-recall plot with confidence intervals This function creates and returns a Matplotlib figure with a - precision-recall plot containing shaded confidence intervals. The plot - will be annotated with F1-score iso-lines (in which the F1-score maintains - the same value). + precision-recall plot containing shaded confidence intervals (standard + deviation on the precision-recall measurements). The plot will be + annotated with F1-score iso-lines (in which the F1-score maintains the same + value). + + This function specially supports "second-annotator" entries by plotting a + line showing the comparison between the default annotator being analyzed + and a second "opinion". Second annotator dataframes contain a single + entry (threshold=0.5), given the nature of the binary map comparisons. Parameters ---------- - label : :py:class:`list` - A list of names to be associated to each line + data : dict + A dictionary in which keys are strings defining plot labels and values + are dictionaries with two entries: - df : :py:class:`pandas.DataFrame` - A dataframe that is produced by our evaluator engine, indexed by - integer "thresholds", containing the following columns: ``threshold`` - (sorted ascending), ``precision``, ``recall``, ``pr_upper`` (upper - precision bounds), ``pr_lower`` (lower precision bounds), ``re_upper`` - (upper recall bounds), ``re_lower`` (lower recall bounds). + * ``df``: :py:class:`pandas.DataFrame` - Dataframes with a single entry are treated specially as these are - considered "second-annotator" performances. A single dot and a line - showing the variability is drawn in these cases. + A dataframe that is produced by our evaluator engine, indexed by + integer "thresholds", containing the following columns: ``threshold`` + (sorted ascending), ``precision``, ``recall``, ``pr_upper`` (upper + precision bounds), ``pr_lower`` (lower precision bounds), + ``re_upper`` (upper recall bounds), ``re_lower`` (lower recall + bounds). - threshold : :py:class:`list` - A list of thresholds to graph with a dot for each set. Specific - threshold values do not affect "second-annotator" dataframes. + * ``threshold``: :py:class:`list` + + A threshold to graph with a dot for each set. Specific + threshold values do not affect "second-annotator" dataframes. 
 
     confidence : :py:class:`bool`, Optional
         If set, draw confidence intervals for each line, using ``*_upper`` and
@@ -160,32 +166,35 @@ def precision_recall_f1iso(label, df, threshold, confidence=True):
 
     legend = []
 
-    for kn, kdf, kt in zip(label, df, threshold):
+    for name, value in data.items():
+
+        df = value["df"]
+        threshold = value["threshold"]
 
         # plots only from the point where recall reaches its maximum,
         # otherwise, we don't see a curve...
-        max_recall = kdf["recall"].idxmax()
-        pi = kdf.precision[max_recall:]
-        ri = kdf.recall[max_recall:]
+        max_recall = df["recall"].idxmax()
+        pi = df.precision[max_recall:]
+        ri = df.recall[max_recall:]
 
         valid = (pi + ri) > 0
         f1 = 2 * (pi[valid] * ri[valid]) / (pi[valid] + ri[valid])
 
         # optimal point along the curve
-        bins = len(kdf)
-        index = int(round(bins*kt))
-        index = min(index, len(kdf)-1)  #avoids out of range indexing
+        bins = len(df)
+        index = int(round(bins*threshold))
+        index = min(index, len(df)-1)  #avoids out of range indexing
 
         # plots Recall/Precision as threshold changes
-        label = f"{kn} (F1={kdf.f1_score[index]:.4f})"
+        label = f"{name} (F1={df.f1_score[index]:.4f})"
         color = next(colorcycler)
 
-        if len(kdf) == 1:
+        if len(df) == 1:
             # plot black dot for F1-score at select threshold
-            marker, = axes.plot(kdf.recall[index], kdf.precision[index],
+            marker, = axes.plot(df.recall[index], df.precision[index],
                     marker="*", markersize=6, color=color, alpha=0.8,
                     linestyle="None")
-            line, = axes.plot(kdf.recall[index], kdf.precision[index],
+            line, = axes.plot(df.recall[index], df.precision[index],
                     linestyle="None", color=color, alpha=0.2)
             legend.append(([marker, line], label))
         else:
@@ -193,17 +202,17 @@ def precision_recall_f1iso(label, df, threshold, confidence=True):
             style = next(linecycler)
             line, = axes.plot(ri[pi > 0], pi[pi > 0], color=color,
                     linestyle=style)
-            marker, = axes.plot(kdf.recall[index], kdf.precision[index],
+            marker, = axes.plot(df.recall[index], df.precision[index],
                     marker="o", linestyle=style, markersize=4, color=color,
                     alpha=0.8)
             legend.append(([marker, line], label))
 
         if confidence:
-            pui = kdf.pr_upper[max_recall:]
-            pli = kdf.pr_lower[max_recall:]
-            rui = kdf.re_upper[max_recall:]
-            rli = kdf.re_lower[max_recall:]
+            pui = df.pr_upper[max_recall:]
+            pli = df.pr_lower[max_recall:]
+            rui = df.re_upper[max_recall:]
+            rli = df.re_lower[max_recall:]
 
             # Plot confidence
             # Upper bound
@@ -212,7 +221,7 @@ def precision_recall_f1iso(label, df, threshold, confidence=True):
             vert_y = numpy.concatenate((pui[pui > 0], pli[pli > 0][::-1]))
 
             # hacky workaround to plot 2nd human
-            if len(kdf) == 1:  #binary system, very likely
+            if len(df) == 1:  #binary system, very likely
                 logger.warning("Found 2nd human annotator - patching...")
                 p, = axes.plot(vert_x, vert_y, color=color, alpha=0.1, lw=3)
             else:
diff --git a/bob/ip/binseg/utils/summary.py b/bob/ip/binseg/utils/summary.py
index c5222f1a..7788cb7a 100644
--- a/bob/ip/binseg/utils/summary.py
+++ b/bob/ip/binseg/utils/summary.py
@@ -9,7 +9,7 @@ from torch.nn.modules.module import _addindent
 
 
 def summary(model):
-    """Counts the number of paramters in each model layer
+    """Counts the number of parameters in each model layer
 
     Parameters
     ----------
diff --git a/doc/api.rst b/doc/api.rst
index 7e680bc4..edd9b150 100644
--- a/doc/api.rst
+++ b/doc/api.rst
@@ -88,6 +88,7 @@ Toolbox
    bob.ip.binseg.utils.model_serialization
    bob.ip.binseg.utils.model_zoo
    bob.ip.binseg.utils.plot
+   bob.ip.binseg.utils.table
    bob.ip.binseg.utils.summary
 
 
diff --git a/doc/links.rst b/doc/links.rst
index ee753a18..c11297d8 100644
--- a/doc/links.rst
+++ b/doc/links.rst
@@ -7,6 +7,7 @@
 .. _installation: https://www.idiap.ch/software/bob/install
 .. _mailing list: https://www.idiap.ch/software/bob/discuss
 .. _pytorch: https://pytorch.org
+.. _tabulate: https://pypi.org/project/tabulate/
 .. _our paper: https://arxiv.org/abs/1909.03856
 
 .. Raw data websites
-- 
GitLab
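
Usage sketch: the comparison workflow introduced by this patch can also be
driven from Python. The snippet below is a minimal, hypothetical example: the
``analysis/...`` paths are made up, and it assumes ``performance_table`` (from
the new ``bob.ip.binseg.utils.table`` module, whose contents are not part of
this diff) accepts the same ``data`` layout consumed by
``precision_recall_f1iso``, as the new ``compare`` command implies::

    import pandas

    from bob.ip.binseg.utils.plot import precision_recall_f1iso
    from bob.ip.binseg.utils.table import performance_table

    # builds the layout produced by _load(): name -> dict(df=..., threshold=...)
    data = {}
    for name, path in [
        ("system-1", "analysis/system-1/metrics.csv"),  # hypothetical paths
        ("system-2", "analysis/system-2/metrics.csv"),
    ]:
        df = pandas.read_csv(path)
        # as in _load(..., threshold=None): picks the threshold that
        # maximizes the F1-score on each metrics file
        data[name] = dict(df=df, threshold=df.threshold[df.f1_score.idxmax()])

    fig = precision_recall_f1iso(data, confidence=True)
    fig.savefig("comparison.pdf")
    print(performance_table(data, "rst"))  # any tabulate-supported format

From the command-line, the equivalent would be something like ``bob binseg
compare -vv system-1 analysis/system-1/metrics.csv system-2
analysis/system-2/metrics.csv --output-figure=comparison.pdf
--output-table=comparison.rst``.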