diff --git a/bob/pad/base/script/vuln_figure.py b/bob/pad/base/script/vuln_figure.py index a4c6881c3fc113e509fe1e5629316086518c8f06..7362c8481857062b77cbbca82e77d3f2b18c1263 100644 --- a/bob/pad/base/script/vuln_figure.py +++ b/bob/pad/base/script/vuln_figure.py @@ -93,9 +93,9 @@ class HistVuln(measure_figure.Hist): int(idx / self._step_print) * self._step_print if col == self._ncols - 1 or n == rest_print - 1: ax2.set_ylabel("IAPMR (%)", color='C3') - ax2.tick_params(axis='y', colors='red') - ax2.yaxis.label.set_color('red') - ax2.spines['right'].set_color('red') + ax2.tick_params(axis='y', colors='C3') + ax2.yaxis.label.set_color('C3') + ax2.spines['right'].set_color('C3') class VulnPlot(measure_figure.PlotBase): @@ -207,7 +207,7 @@ class Epc(VulnPlot): self._pdf_page.savefig(mpl.gcf()) -class Epsc(VulnPlot): +class Epsc(VulnPlot, measure_figure.GridSubplot): ''' Handles the plotting of EPSC ''' def __init__(self, ctx, scores, evaluation, func_load, @@ -221,8 +221,7 @@ class Epsc(VulnPlot): self._var_param = var_param or "omega" self._fixed_params = fixed_params or [0.5] self._titles = ctx.meta.get('titles', []) * 2 - self._legend_loc = self._legend_loc or 'upper center' - self._eval = True # always eval data with EPC + self._eval = True # always eval data with EPSC self._split = False self._nb_figs = 1 self._sampling = ctx.meta.get('sampling', 5) @@ -234,6 +233,9 @@ class Epsc(VulnPlot): raise click.BadParameter("You must provide 4 scores files:{licit," "spoof}/{dev,eval}") + self._ncols = 1 if self._iapmr else 0 + self._ncols += 1 if self._wer else 0 + def compute(self, idx, input_scores, input_names): ''' Plot EPSC for PAD''' licit_dev_neg = input_scores[0][0] @@ -252,15 +254,11 @@ class Epsc(VulnPlot): elif self.n_systems > 1: legend = 'Sys%d' % (idx + 1) - n_col = 1 if self._iapmr else 0 - n_col += 1 if self._wer else 0 - if not merge_sys or idx == 0: # axes should only be created once - mpl.figure() - self._axis1 = mpl.subplot(1, n_col, 1) - if n_col == 2: - self._axis2 = mpl.subplot(1, n_col, 2) + self._axis1 = self.create_subplot(0) + if self._ncols == 2: + self._axis2 = self.create_subplot(1) else: self._axis2 = self._axis1 points = 10 @@ -353,14 +351,7 @@ class Epsc(VulnPlot): mpl.xticks(rotation=self._x_rotation) if self._fixed_params is None or len(self._fixed_params) > 1 or \ idx == self.n_systems - 1: - # all plots share same legends - lines, labels = self._axis1.get_legend_handles_labels() - mpl.gcf().legend( - lines, labels, loc=self._legend_loc, fancybox=True, mode="expand", - framealpha=0.5, ncol=self._nlegends, bbox_to_anchor=(0., 1.12, 1., .102) - ) - mpl.tight_layout() - self._pdf_page.savefig(bbox_inches='tight') + self.finalize_one_page() class Epsc3D(Epsc):