From e1e9a249ff4959c2ca9d23eab01ff0cec417f929 Mon Sep 17 00:00:00 2001
From: Theophile GENTILHOMME <tgentilhomme@jurasix08.idiap.ch>
Date: Mon, 4 Jun 2018 12:14:41 +0200
Subject: [PATCH] [script][figure] Fix det and roc plots

Add IAPMR second x axis
---
 bob/pad/base/script/figure.py | 73 ++++++++++++++++++++++++-----------
 1 file changed, 51 insertions(+), 22 deletions(-)

diff --git a/bob/pad/base/script/figure.py b/bob/pad/base/script/figure.py
index 63ad5fb..f09359b 100644
--- a/bob/pad/base/script/figure.py
+++ b/bob/pad/base/script/figure.py
@@ -8,7 +8,7 @@ import bob.bio.base.script.figure as bio_figure
 from tabulate import tabulate
 from bob.measure.utils import get_fta_list
 from bob.measure import (
-    far_threshold, eer_threshold, min_hter_threshold, farfrr, epc, ppndf
+    frr_threshold, far_threshold, eer_threshold, min_hter_threshold, farfrr, epc, ppndf
 )
 from bob.measure.plot import (det, det_axis, roc_for_far, log_values)
 from . import error_utils
@@ -543,6 +543,7 @@ class BaseDetRoc(PadPlot):
         self._no_spoof = no_spoof
         self._criteria = criteria or 'eer'
         self._real_data = True if real_data is None else real_data
+        self._legend_loc = None
 
 
     def compute(self, idx, input_scores, input_names):
@@ -562,6 +563,13 @@ class BaseDetRoc(PadPlot):
             label=self._label("licit", input_names[0], idx)
         )
         if not self._no_spoof and spoof_eval_neg is not None:
+            ax1 = mpl.gca()
+            ax2 = ax1.twiny()
+            ax2.set_xlabel('IAPMR', color='C3')
+            ax2.set_xticklabels(ax2.get_xticks())
+            ax2.tick_params(axis='x', colors='C3')
+            ax2.xaxis.label.set_color('C3')
+            ax2.spines['top'].set_color('C3')
             self._plot(
                 spoof_eval_neg,
                 spoof_eval_pos,
@@ -570,6 +578,7 @@ class BaseDetRoc(PadPlot):
                 linestyle=':',
                 label=self._label("spoof", input_names[3], idx)
             )
+            mpl.sca(ax1)
 
         if self._criteria is None or self._no_spoof:
             return
@@ -587,8 +596,11 @@ class BaseDetRoc(PadPlot):
         if farfrr_licit is None:
             return
         farfrr_spoof, farfrr_spoof_det = self._get_farfrr(
-            spoof_eval_neg, spoof_eval_pos, thres_baseline
+            spoof_eval_neg, spoof_eval_pos,
+            frr_threshold(spoof_eval_neg, spoof_eval_pos,
+                                      farfrr_licit[1])
         )
+
         if not self._real_data:
             mpl.axhline(
                 y=farfrr_licit_det[1],
@@ -621,14 +633,24 @@ class BaseDetRoc(PadPlot):
         )  # FAR point, spoof scenario
 
         # annotate the FAR points
-        xyannotate_licit = [
-            0.15 + farfrr_licit_det[0],
-            farfrr_licit_det[1] - 0.15,
-        ]
-        xyannotate_spoof = [
-            0.15 + farfrr_spoof_det[0],
-            farfrr_spoof_det[1] - 0.15,
-        ]
+        if farfrr_licit_det[0] < farfrr_spoof_det[0]:
+            xyannotate_licit = [
+                farfrr_licit_det[0] - 0.7,
+                farfrr_licit_det[1] - 0.4,
+            ]
+            xyannotate_spoof = [
+                0.1 + farfrr_spoof_det[0],
+                farfrr_spoof_det[1] + 0.3,
+            ]
+        else:
+            xyannotate_spoof = [
+                farfrr_licit_det[0] - 0.7,
+                farfrr_licit_det[1] - 0.4,
+            ]
+            xyannotate_licit = [
+                0.1 + farfrr_spoof_det[0],
+                farfrr_spoof_det[1] + 0.3,
+            ]
 
         if not self._real_data:
             mpl.annotate(
@@ -663,21 +685,26 @@ class BaseDetRoc(PadPlot):
         # only for plots
 
         if self._title.replace(' ', ''):
-            mpl.title(self._title)
+            mpl.title(self._title, y=1.15)
         mpl.xlabel(self._x_label)
         mpl.ylabel(self._y_label)
         mpl.grid(True, color=self._grid_color)
+        lines = []
+        labels = []
+        for ax in mpl.gcf().get_axes():
+            li, la = ax.get_legend_handles_labels()
+            lines += li
+            labels += la
+            mpl.sca(ax)
+            self._set_axis()
+            fig = mpl.gcf()
+            mpl.xticks(rotation=self._x_rotation)
+            mpl.tick_params(axis='both', which='major', labelsize=6)
         if self._disp_legend:
-            mpl.legend(loc=self._legend_loc)
-        self._set_axis()
-        fig = mpl.gcf()
-        mpl.xticks(rotation=self._x_rotation)
-        mpl.tick_params(axis='both', which='major', labelsize=4)
-        for tick in mpl.gca().xaxis.get_major_ticks():
-            tick.label.set_fontsize(6)
-        for tick in mpl.gca().yaxis.get_major_ticks():
-            tick.label.set_fontsize(6)
-
+            mpl.gca().legend(
+                lines, labels, loc=self._legend_loc, fancybox=True,
+                framealpha=0.5
+            )
         self._pdf_page.savefig(fig)
 
         # do not want to close PDF when running evaluate
@@ -704,6 +731,7 @@ class Det(BaseDetRoc):
         if not self._no_spoof:
             add = " and overlaid SPOOF scenario"
         self._title = self._title or ('DET: LICIT' + add)
+        self._legend_loc = self._legend_loc or 'upper right'
 
 
     def _set_axis(self):
@@ -713,7 +741,6 @@ class Det(BaseDetRoc):
             det_axis([0.01, 99, 0.01, 99])
 
     def _get_farfrr(self, x, y, thres):
-        # calculate test frr @ EER (licit scenario)
         points = farfrr(x, y, thres)
         return points, [ppndf(i) for i in points]
 
@@ -740,6 +767,8 @@ class RocVuln(BaseDetRoc):
         if not self._no_spoof:
             add = " and overlaid SPOOF scenario"
         self._title = self._title or ('ROC: LICIT' + add)
+        best_legend = 'lower right' if self._semilogx else 'upper right'
+        self._legend_loc = self._legend_loc or best_legend
 
 
     def _plot(self, x, y, points, **kwargs):
-- 
GitLab