From 40985094e9251eec30d63f495a19de2e5f2e7530 Mon Sep 17 00:00:00 2001
From: Andre Anjos <andre.dos.anjos@gmail.com>
Date: Fri, 24 Apr 2020 15:53:16 +0200
Subject: [PATCH] [script.compare] Implement system performance tabulation

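Besides the precision-recall comparison figure, the "compare" script now
prints a table summarizing the performance of every system and can
optionally save it to a file, in any format supported by tabulate.  The
"--output" option is replaced by "--output-figure", "--output-table" and
"--table-format"; data loading is factored out of the plotting code so
the figure and the table share the same per-system dataframes and
thresholds.  The "experiment" script now saves both "comparison.pdf" and
"comparison.rst" at the end of its comparison step.

A typical invocation would be (assuming the usual "bob binseg"
command-line entry point; labels and paths below are illustrative):

    $ bob binseg compare -vv \
        system1 path/to/system1/metrics.csv \
        system2 path/to/system2/metrics.csv \
        --output-figure=comparison.pdf \
        --output-table=comparison.rst \
        --table-format=rst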
---
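Note for reviewers: the new module bob/ip/binseg/utils/table.py
(providing performance_table(), imported by compare.py and listed in
doc/api.rst) does not appear in the diffstat below.  A minimal sketch of
the interface it is expected to expose could look like the following;
everything beyond the performance_table(data, fmt) call made from
compare.py and the dict layout documented in _load() is an assumption,
not the actual module contents:

    # sketch only: assumed interface, mirroring the threshold-to-index
    # convention already used in bob/ip/binseg/utils/plot.py
    import tabulate

    def performance_table(data, fmt):
        """Tabulates the F1-score of each system at its chosen threshold

        ``data`` maps system names to dicts with ``df`` (a dataframe
        produced by the evaluator) and ``threshold`` (a float); ``fmt``
        is any entry of ``tabulate.tabulate_formats``.
        """
        headers = ["system", "threshold", "f1_score"]
        rows = []
        for name, value in data.items():
            df = value["df"]
            # translate the threshold into a row index, clamped to range
            index = int(round(len(df) * value["threshold"]))
            index = min(index, len(df) - 1)
            rows.append([name, df.threshold[index], df.f1_score[index]])
        return tabulate.tabulate(rows, headers, tablefmt=fmt, floatfmt=".3f")
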
 bob/ip/binseg/script/compare.py    | 94 ++++++++++++++++++++----------
 bob/ip/binseg/script/experiment.py |  8 ++-
 bob/ip/binseg/test/test_cli.py     | 10 ++++
 bob/ip/binseg/utils/plot.py        | 79 ++++++++++++++-----------
 bob/ip/binseg/utils/summary.py     |  2 +-
 doc/api.rst                        |  1 +
 doc/links.rst                      |  1 +
 7 files changed, 127 insertions(+), 68 deletions(-)

diff --git a/bob/ip/binseg/script/compare.py b/bob/ip/binseg/script/compare.py
index e27b6296..c3b8d715 100644
--- a/bob/ip/binseg/script/compare.py
+++ b/bob/ip/binseg/script/compare.py
@@ -9,8 +9,10 @@ from bob.extension.scripts.click_helper import (
 )
 
 import pandas
+import tabulate
 
 from ..utils.plot import precision_recall_f1iso
+from ..utils.table import performance_table
 
 import logging
 logger = logging.getLogger(__name__)
@@ -44,7 +46,7 @@ def _validate_threshold(t, dataset):
     return t
 
 
-def _load_and_plot(data, threshold=None):
-    """Plots comparison chart of all evaluated models
+def _load(data, threshold=None):
+    """Loads the performance data of all evaluated systems
 
     Parameters
@@ -55,20 +57,26 @@ def _load_and_plot(data, threshold=None):
         paths to ``metrics.csv`` style files.
 
     threshold : :py:class:`float`, :py:class:`str`, Optional
-        A value indicating which threshold to choose for plotting a "F1-score"
-        (black) dot on the various curves.  If set to ``None``, then plot the
-        maximum F1-score on that curve.  If set to a floating-point value, then
-        plot the F1-score that is obtained on that particular threshold.  If
-        set to a string, it should match one of the keys in ``data``.  It then
-        first calculate the threshold reaching the maximum F1-score on that
-        particular dataset and then applies that threshold to all other sets.
+        A value indicating which threshold to choose for selecting an
+        F1-score.  If set to ``None``, the maximum F1-score on that metrics
+        file is used.  If set to a floating-point value, the F1-score
+        obtained at that particular threshold is used.  If set to a string,
+        it should match one of the keys in ``data``: the threshold reaching
+        the maximum F1-score on that particular dataset is first calculated
+        and then applied to all other sets.
 
 
     Returns
     -------
 
-    figure : matplotlib.figure.Figure
-        A figure, with all systems combined into a single plot.
+    data : dict
+        A dict in which keys are the names of the systems and the values are
+        dictionaries that contain two keys:
+
+        * ``df``: A :py:class:`pandas.DataFrame` with the metrics data loaded
+          from the corresponding CSV file
+        * ``threshold``: A threshold to be used for summarization, depending on
+          the ``threshold`` parameter set on the input
 
     """
 
@@ -78,8 +86,8 @@ def _load_and_plot(data, threshold=None):
         metrics_path = data[threshold]
         df = pandas.read_csv(metrics_path)
 
-        maxf1 = df["f1_score"].max()
-        use_threshold = df["threshold"][df["f1_score"].idxmax()]
+        maxf1 = df.f1_score.max()
+        use_threshold = df.threshold[df.f1_score.idxmax()]
         logger.info(f"Dataset '*': threshold = {use_threshold:.3f}'")
 
     elif isinstance(threshold, float):
@@ -91,20 +99,19 @@ def _load_and_plot(data, threshold=None):
     thresholds = []
 
     # loads all data
+    retval = {}
     for name, metrics_path in data.items():
 
         logger.info(f"Loading metrics from {metrics_path}...")
         df = pandas.read_csv(metrics_path)
 
         if threshold is None:
-            use_threshold = df["threshold"][df["f1_score"].idxmax()]
+            use_threshold = df.threshold[df.f1_score.idxmax()]
             logger.info(f"Dataset '{name}': threshold = {use_threshold:.3f}'")
 
-        names.append(name)
-        dfs.append(df)
-        thresholds.append(use_threshold)
+        retval[name] = dict(df=df, threshold=use_threshold)
 
-    return precision_recall_f1iso(names, dfs, thresholds, confidence=True)
+    return retval
 
 
 @click.command(
@@ -121,15 +128,32 @@ def _load_and_plot(data, threshold=None):
         nargs=-1,
         )
 @click.option(
-    "--output",
-    "-o",
-    help="Path where write the output figure (PDF format)",
+    "--output-figure",
+    "-f",
+    help="Path where write the output figure (any extension supported by "
+    "matplotlib is possible).  If not provided, does not produce a figure.",
+    required=False,
+    default=None,
+    type=click.Path(dir_okay=False, file_okay=True),
+)
+@click.option(
+    "--table-format",
+    "-T",
+    help="The format to use for the comparison table",
     show_default=True,
     required=True,
-    default="comparison.pdf",
-    type=click.Path(),
+    default="rst",
+    type=click.Choice(tabulate.tabulate_formats),
+)
+@click.option(
+    "--output-table",
+    "-u",
+    help="Path where write the output table. If not provided, does not write "
+    "write a table to file, only to stdout.",
+    required=False,
+    default=None,
+    type=click.Path(dir_okay=False, file_okay=True),
 )
-
 @click.option(
     "--threshold",
     "-t",
@@ -145,7 +169,8 @@ def _load_and_plot(data, threshold=None):
     required=False,
 )
 @verbosity_option()
-def compare(label_path, output, threshold, **kwargs):
+def compare(label_path, output_figure, table_format, output_table, threshold,
+        **kwargs):
     """Compares multiple systems together"""
 
     # hack to get a dictionary from arguments passed to input
@@ -156,9 +181,18 @@ def compare(label_path, output, threshold, **kwargs):
 
     threshold = _validate_threshold(threshold, data)
 
-    fig = _load_and_plot(data, threshold=threshold)
-    logger.info(f"Saving plot at {output}")
-    fig.savefig(output)
-
-    # TODO: print table with all results
-    pass
+    # load the metrics data for all systems
+    data = _load(data, threshold=threshold)
+
+    if output_figure is not None:
+        logger.info(f"Creating and saving plot at {output_figure}...")
+        fig = precision_recall_f1iso(data, confidence=True)
+        fig.savefig(output_figure)
+
+    logger.info("Tabulating performance summary...")
+    table = performance_table(data, table_format)
+    click.echo(table)
+    if output_table is not None:
+        logger.info(f"Saving table at {output_table}...")
+        with open(output_table, "wt") as f:
+            f.write(table)
diff --git a/bob/ip/binseg/script/experiment.py b/bob/ip/binseg/script/experiment.py
index 54580368..e15f8588 100644
--- a/bob/ip/binseg/script/experiment.py
+++ b/bob/ip/binseg/script/experiment.py
@@ -412,7 +412,11 @@ def experiment(
                 continue
             systems += [f"{k} (2nd. annot.)", os.path.join(analysis_folder, k,
                 "metrics-second-annotator.csv")]
-    output_pdf = os.path.join(output_folder, "comparison.pdf")
-    ctx.invoke(compare, label_path=systems, output=output_pdf, verbose=verbose)
+
+    output_figure = os.path.join(output_folder, "comparison.pdf")
+    output_table = os.path.join(output_folder, "comparison.rst")
+
+    ctx.invoke(compare, label_path=systems, output_figure=output_figure,
+            output_table=output_table, verbose=verbose)
 
     logger.info("Ended comparison, and the experiment - bye.")
diff --git a/bob/ip/binseg/test/test_cli.py b/bob/ip/binseg/test/test_cli.py
index 6e2d8be2..8acd7f87 100644
--- a/bob/ip/binseg/test/test_cli.py
+++ b/bob/ip/binseg/test/test_cli.py
@@ -158,6 +158,7 @@ def _check_experiment_stare(overlay):
 
         # check outcomes of the comparison phase
         assert os.path.exists(os.path.join(output_folder, "comparison.pdf"))
+        assert os.path.exists(os.path.join(output_folder, "comparison.rst"))
 
         keywords = {
             r"^Started training$": 1,
@@ -175,6 +176,9 @@ def _check_experiment_stare(overlay):
             r"^Ended evaluation$": 1,
             r"^Started comparison$": 1,
             r"^Loading metrics from": 4,
+            r"^Creating and saving plot at": 1,
+            r"^Tabulating performance summary...": 1,
+            r"^Saving table at": 1,
             r"^Ended comparison.*$": 1,
         }
         buf.seek(0)
@@ -399,14 +403,20 @@ def _check_compare(runner):
                 os.path.join(output_folder, "metrics.csv"),
                 "test (2nd. human)",
                 os.path.join(output_folder, "metrics-second-annotator.csv"),
+                "--output-figure=comparison.pdf",
+                "--output-table=comparison.rst",
             ],
         )
         _assert_exit_0(result)
 
         assert os.path.exists("comparison.pdf")
+        assert os.path.exists("comparison.rst")
 
         keywords = {
             r"^Loading metrics from": 2,
+            r"^Creating and saving plot at": 1,
+            r"^Tabulating performance summary...": 1,
+            r"^Saving table at": 1,
         }
         buf.seek(0)
         logging_output = buf.read()
diff --git a/bob/ip/binseg/utils/plot.py b/bob/ip/binseg/utils/plot.py
index 0bbea34a..9be1b0ed 100644
--- a/bob/ip/binseg/utils/plot.py
+++ b/bob/ip/binseg/utils/plot.py
@@ -97,35 +97,41 @@ def _precision_recall_canvas(title=None):
     plt.tight_layout()
 
 
-def precision_recall_f1iso(label, df, threshold, confidence=True):
+def precision_recall_f1iso(data, confidence=True):
     """Creates a precision-recall plot with confidence intervals
 
     This function creates and returns a Matplotlib figure with a
-    precision-recall plot containing shaded confidence intervals.  The plot
-    will be annotated with F1-score iso-lines (in which the F1-score maintains
-    the same value).
+    precision-recall plot containing shaded confidence intervals (standard
+    deviation on the precision-recall measurements).  The plot will be
+    annotated with F1-score iso-lines (in which the F1-score maintains the same
+    value).
+
+    This function specially handles "second-annotator" entries, which are drawn
+    as a single dot plus a variability line, comparing the annotator under
+    analysis against a second "opinion".  Second-annotator dataframes contain a
+    single entry (threshold=0.5), given the nature of binary map comparisons.
 
 
     Parameters
     ----------
 
-    label : :py:class:`list`
-        A list of names to be associated to each line
+    data : dict
+        A dictionary in which keys are strings defining plot labels and values
+        are dictionaries with two entries:
 
-    df : :py:class:`pandas.DataFrame`
-        A dataframe that is produced by our evaluator engine, indexed by
-        integer "thresholds", containing the following columns: ``threshold``
-        (sorted ascending), ``precision``, ``recall``, ``pr_upper`` (upper
-        precision bounds), ``pr_lower`` (lower precision bounds), ``re_upper``
-        (upper recall bounds), ``re_lower`` (lower recall bounds).
+        * ``df``: :py:class:`pandas.DataFrame`
 
-        Dataframes with a single entry are treated specially as these are
-        considered "second-annotator" performances.  A single dot and a line
-        showing the variability is drawn in these cases.
+          A dataframe that is produced by our evaluator engine, indexed by
+          integer "thresholds", containing the following columns: ``threshold``
+          (sorted ascending), ``precision``, ``recall``, ``pr_upper`` (upper
+          precision bounds), ``pr_lower`` (lower precision bounds),
+          ``re_upper`` (upper recall bounds), ``re_lower`` (lower recall
+          bounds).
 
-    threshold : :py:class:`list`
-        A list of thresholds to graph with a dot for each set.  Specific
-        threshold values do not affect "second-annotator" dataframes.
+        * ``threshold``: :py:class:`float`
+
+          A threshold to graph with a dot on the precision-recall curve.
+          Specific threshold values do not affect "second-annotator" dataframes.
 
     confidence : :py:class:`bool`, Optional
         If set, draw confidence intervals for each line, using ``*_upper`` and
@@ -160,32 +166,35 @@ def precision_recall_f1iso(label, df, threshold, confidence=True):
 
         legend = []
 
-        for kn, kdf, kt in zip(label, df, threshold):
+        for name, value in data.items():
+
+            df = value["df"]
+            threshold = value["threshold"]
 
             # plots only from the point where recall reaches its maximum,
             # otherwise, we don't see a curve...
-            max_recall = kdf["recall"].idxmax()
-            pi = kdf.precision[max_recall:]
-            ri = kdf.recall[max_recall:]
+            max_recall = df["recall"].idxmax()
+            pi = df.precision[max_recall:]
+            ri = df.recall[max_recall:]
 
             valid = (pi + ri) > 0
             f1 = 2 * (pi[valid] * ri[valid]) / (pi[valid] + ri[valid])
 
             # optimal point along the curve
-            bins = len(kdf)
-            index = int(round(bins*kt))
-            index = min(index, len(kdf)-1)  #avoids out of range indexing
+            bins = len(df)
+            index = int(round(bins*threshold))
+            index = min(index, len(df)-1)  #avoids out of range indexing
 
             # plots Recall/Precision as threshold changes
-            label = f"{kn} (F1={kdf.f1_score[index]:.4f})"
+            label = f"{name} (F1={df.f1_score[index]:.4f})"
             color = next(colorcycler)
 
-            if len(kdf) == 1:
+            if len(df) == 1:
                 # plot black dot for F1-score at select threshold
-                marker, = axes.plot(kdf.recall[index], kdf.precision[index],
+                marker, = axes.plot(df.recall[index], df.precision[index],
                         marker="*", markersize=6, color=color, alpha=0.8,
                         linestyle="None")
-                line, = axes.plot(kdf.recall[index], kdf.precision[index],
+                line, = axes.plot(df.recall[index], df.precision[index],
                         linestyle="None", color=color, alpha=0.2)
                 legend.append(([marker, line], label))
             else:
@@ -193,17 +202,17 @@ def precision_recall_f1iso(label, df, threshold, confidence=True):
                 style = next(linecycler)
                 line, = axes.plot(ri[pi > 0], pi[pi > 0], color=color,
                         linestyle=style)
-                marker, = axes.plot(kdf.recall[index], kdf.precision[index],
+                marker, = axes.plot(df.recall[index], df.precision[index],
                         marker="o", linestyle=style, markersize=4,
                         color=color, alpha=0.8)
                 legend.append(([marker, line], label))
 
             if confidence:
 
-                pui = kdf.pr_upper[max_recall:]
-                pli = kdf.pr_lower[max_recall:]
-                rui = kdf.re_upper[max_recall:]
-                rli = kdf.re_lower[max_recall:]
+                pui = df.pr_upper[max_recall:]
+                pli = df.pr_lower[max_recall:]
+                rui = df.re_upper[max_recall:]
+                rli = df.re_lower[max_recall:]
 
                 # Plot confidence
                 # Upper bound
@@ -212,7 +221,7 @@ def precision_recall_f1iso(label, df, threshold, confidence=True):
                 vert_y = numpy.concatenate((pui[pui > 0], pli[pli > 0][::-1]))
 
                 # hacky workaround to plot 2nd human
-                if len(kdf) == 1:  #binary system, very likely
+                if len(df) == 1:  #binary system, very likely
                     logger.warning("Found 2nd human annotator - patching...")
                     p, = axes.plot(vert_x, vert_y, color=color, alpha=0.1, lw=3)
                 else:
diff --git a/bob/ip/binseg/utils/summary.py b/bob/ip/binseg/utils/summary.py
index c5222f1a..7788cb7a 100644
--- a/bob/ip/binseg/utils/summary.py
+++ b/bob/ip/binseg/utils/summary.py
@@ -9,7 +9,7 @@ from torch.nn.modules.module import _addindent
 
 
 def summary(model):
-    """Counts the number of paramters in each model layer
+    """Counts the number of parameters in each model layer
 
     Parameters
     ----------
diff --git a/doc/api.rst b/doc/api.rst
index 7e680bc4..edd9b150 100644
--- a/doc/api.rst
+++ b/doc/api.rst
@@ -88,6 +88,7 @@ Toolbox
    bob.ip.binseg.utils.model_serialization
    bob.ip.binseg.utils.model_zoo
    bob.ip.binseg.utils.plot
+   bob.ip.binseg.utils.table
    bob.ip.binseg.utils.summary
 
 
diff --git a/doc/links.rst b/doc/links.rst
index ee753a18..c11297d8 100644
--- a/doc/links.rst
+++ b/doc/links.rst
@@ -7,6 +7,7 @@
 .. _installation: https://www.idiap.ch/software/bob/install
 .. _mailing list: https://www.idiap.ch/software/bob/discuss
 .. _pytorch: https://pytorch.org
+.. _tabulate: https://pypi.org/project/tabulate/
 .. _our paper: https://arxiv.org/abs/1909.03856
 
 .. Raw data websites
-- 
GitLab