diff --git a/bob/measure/script/figure.py b/bob/measure/script/figure.py index c88f5c4968c8937e857f858cbefd850b588a17e7..b0774456315051a2ff2bbd3f5e57379f9fbbb1c0 100644 --- a/bob/measure/script/figure.py +++ b/bob/measure/script/figure.py @@ -4,11 +4,11 @@ from __future__ import division, print_function from abc import ABCMeta, abstractmethod import math import sys -import os.path import numpy import click import matplotlib import matplotlib.pyplot as mpl +from matplotlib import gridspec from matplotlib.backends.backend_pdf import PdfPages from tabulate import tabulate from .. import (far_threshold, plot, utils, ppndf) @@ -707,7 +707,93 @@ class Epc(PlotBase): ) -class Hist(PlotBase): +class GridSubplot(PlotBase): + """A base class for plots that contain subplots and legends. This is needed + because "constrained_layout will not work on subplots created via the + subplot command." + + To use this class, use `create_subplot` in `compute` each time you need a + new axis. and call `finalize_one_page` in `compute` when a page is finished + rendering. + """ + + def __init__(self, ctx, scores, evaluation, func_load): + super(GridSubplot, self).__init__(ctx, scores, evaluation, func_load) + + # Check legend + self._legend_loc = self._legend_loc or 'upper center' + if self._legend_loc == 'best': + self._legend_loc = 'upper center' + if 'upper' not in self._legend_loc and \ + 'lower' not in self._legend_loc: + raise ValueError('Only best, (upper *), and (lower-*) legend ' + 'locations are supported!') + if 'up' in self._legend_loc: + self._legend_grid_axis_number = 0 + self._grid_axis_offset = 1 + else: + self._legend_grid_axis_number = -1 + self._grid_axis_offset = 0 + + # subplot grid + self._nrows = ctx.meta.get('n_row', 1) + self._ncols = ctx.meta.get('n_col', 1) + + # GridSpec artificial rows and cols multipliers + self._row_times = 8 + self._col_times = 2 + + def init_process(self): + super(GridSubplot, self).init_process() + self._create_grid_spec() + + def _create_grid_spec(self): + # create a compatible GridSpec + self._gs = gridspec.GridSpec( + self._nrows * self._row_times + 1, + self._ncols * self._col_times, + figure=mpl.gcf()) + + def create_subplot(self, n): + i, j = numpy.unravel_index(n, (self._nrows, self._ncols)) + i1 = i * self._row_times + self._grid_axis_offset + i2 = (i + 1) * self._row_times + self._grid_axis_offset + j1, j2 = j * self._col_times, (j + 1) * self._col_times + axis = mpl.gcf().add_subplot(self._gs[i1:i2, j1:j2]) + return axis + + def finalize_one_page(self): + # print legend on the page + self.plot_legends() + self._pdf_page.savefig(bbox_inches='tight') + mpl.clf() + mpl.figure() + self._create_grid_spec() + + def plot_legends(self): + ''' Print legend on current page''' + lines = [] + labels = [] + for ax in mpl.gcf().get_axes(): + ali, ala = ax.get_legend_handles_labels() + # avoid duplicates in legend + for li, la in zip(ali, ala): + if la not in labels: + lines.append(li) + labels.append(la) + + if self._disp_legend: + # create legend on the top or bottom axis + ax = mpl.gcf().add_subplot( + self._gs[self._legend_grid_axis_number, :]) + # right, left, or center + loc = self._legend_loc.split()[1] + ax.legend(lines, labels, loc=loc, ncol=self._nlegends) + # don't show its axis + ax.set_axis_off() + + +class Hist(GridSubplot): ''' Functional base class for histograms''' def __init__(self, ctx, scores, evaluation, func_load, nhist_per_system=2): @@ -723,9 +809,6 @@ class Hist(PlotBase): self._criterion = ctx.meta.get('criterion') # no vertical (threshold) is displayed self._no_line = ctx.meta.get('no_line', False) - # subplot grid - self._nrows = ctx.meta.get('n_row', 1) - self._ncols = ctx.meta.get('n_col', 1) # do not display dev histo self._hide_dev = ctx.meta.get('hide_dev', False) if self._hide_dev and not self._eval: @@ -734,7 +817,7 @@ class Hist(PlotBase): # dev hist are displayed next to eval hist self._nrows *= 1 if self._hide_dev or not self._eval else 2 self._nlegends = ctx.meta.get('legends_ncol', 3) - self._legend_loc = self._legend_loc or 'upper center' + # number of subplot on one page self._step_print = int(self._nrows * self._ncols) self._title_base = 'Scores' @@ -766,12 +849,13 @@ class Hist(PlotBase): self._print_subplot(idx, sys, eval_neg, eval_pos, threshold, not self._no_line, True) - def _print_subplot(self, idx, sys, neg, pos, threshold, draw_line, evaluation): + def _print_subplot(self, idx, sys, neg, pos, threshold, draw_line, + evaluation): ''' print a subplot for the given score and subplot index''' n = idx % self._step_print col = n % self._ncols sub_plot_idx = n + 1 - axis = mpl.subplot(self._nrows, self._ncols, sub_plot_idx) + axis = self.create_subplot(n) self._setup_hist(neg, pos) if col == 0: axis.set_ylabel(self._y_label) @@ -805,12 +889,7 @@ class Hist(PlotBase): # to display, save figure if self._step_print == sub_plot_idx or (is_lower and sys == self.n_systems - 1): - # print legend on the page - self.plot_legends() - mpl.tight_layout() - self._pdf_page.savefig(mpl.gcf(), bbox_inches='tight') - mpl.clf() - mpl.figure() + self.finalize_one_page() def _get_title(self, idx, dflt=None): ''' Get the histo title for the given idx''' @@ -821,25 +900,6 @@ class Hist(PlotBase): ' ', '') else title return title or '' - def plot_legends(self): - ''' Print legend on current page''' - lines = [] - labels = [] - for ax in mpl.gcf().get_axes(): - ali, ala = ax.get_legend_handles_labels() - # avoid duplicates in legend - for li, la in zip(ali, ala): - if la not in labels: - lines.append(li) - labels.append(la) - - if self._disp_legend: - mpl.gcf().legend( - lines, labels, loc=self._legend_loc, fancybox=True, - framealpha=0.5, ncol=self._nlegends, - bbox_to_anchor=(0.55, 1.1), - ) - def _get_neg_pos_thres(self, idx, input_scores, input_names): ''' Get scores and threshod for the given system at index idx''' neg_list, pos_list, _ = utils.get_fta_list(input_scores)