diff --git a/bob/measure/script/figure.py b/bob/measure/script/figure.py index 768d0f738a2095246bbfe3a95b33e8a55629fd91..031ae9ed982d0d8fc6b5dd9df993e7813d33dd10 100644 --- a/bob/measure/script/figure.py +++ b/bob/measure/script/figure.py @@ -74,7 +74,7 @@ class MeasureBase(object): ) def run(self): - """ Generate outputs (e.g. metrics, files, pdf plots). + """Generate outputs (e.g. metrics, files, pdf plots). This function calls abstract methods :func:`~bob.measure.script.figure.MeasureBase.init_process` (before loop), :py:func:`~bob.measure.script.figure.MeasureBase.compute` @@ -117,7 +117,7 @@ class MeasureBase(object): # protected functions that need to be overwritten def init_process(self): - """ Called in :py:func:`~bob.measure.script.figure.MeasureBase`.run + """Called in :py:func:`~bob.measure.script.figure.MeasureBase`.run before iterating through the different systems. Should reimplemented in derived classes""" pass @@ -149,7 +149,7 @@ class MeasureBase(object): # Things to do after the main iterative computations are done @abstractmethod def end_process(self): - """ Called in :py:func:`~bob.measure.script.figure.MeasureBase`.run + """Called in :py:func:`~bob.measure.script.figure.MeasureBase`.run after iterating through the different systems. Should reimplemented in derived classes""" pass @@ -157,7 +157,7 @@ class MeasureBase(object): # common protected functions def _load_files(self, filepaths): - """ Load the input files and return the base names of the files + """Load the input files and return the base names of the files Returns ------- @@ -176,7 +176,7 @@ class MeasureBase(object): class Metrics(MeasureBase): - """ Compute metrics from score files + """Compute metrics from score files Attributes ---------- @@ -347,7 +347,7 @@ class Metrics(MeasureBase): return res def compute(self, idx, input_scores, input_names): - """ Compute metrics thresholds and tables (FPR, FNR, precision, recall, + """Compute metrics thresholds and tables (FPR, FNR, precision, recall, f1_score) for given system inputs""" dev_file = input_names[0] title = self._legends[idx] if self._legends is not None else None @@ -509,7 +509,7 @@ class MultiMetrics(Metrics): class PlotBase(MeasureBase): - """ Base class for plots. Regroup several options and code + """Base class for plots. Regroup several options and code shared by the different plots """ @@ -573,8 +573,8 @@ class PlotBase(MeasureBase): fig.clear() def end_process(self): - """ Set title, legend, axis labels, grid colors, save figures, drow - lines and close pdf if needed """ + """Set title, legend, axis labels, grid colors, save figures, drow + lines and close pdf if needed""" # draw vertical lines if self._far_at is not None: for (line, line_trans) in zip(self._far_at, self._trans_far_val): @@ -642,7 +642,7 @@ class Roc(PlotBase): self._min_dig = -4 if self._min_dig is None else self._min_dig def compute(self, idx, input_scores, input_names): - """ Plot ROC for dev and eval data using + """Plot ROC for dev and eval data using :py:func:`bob.measure.plot.roc`""" neg_list, pos_list, _ = utils.get_fta_list(input_scores) dev_neg, dev_pos = neg_list[0], pos_list[0] @@ -733,7 +733,7 @@ class Det(PlotBase): self._min_dig = -4 if self._min_dig is None else self._min_dig def compute(self, idx, input_scores, input_names): - """ Plot DET for dev and eval data using + """Plot DET for dev and eval data using :py:func:`bob.measure.plot.det`""" neg_list, pos_list, _ = utils.get_fta_list(input_scores) dev_neg, dev_pos = neg_list[0], pos_list[0] @@ -835,9 +835,7 @@ class Epc(PlotBase): class GridSubplot(PlotBase): - """A base class for plots that contain subplots and legends. This is needed - because "constrained_layout will not work on subplots created via the - subplot command." + """A base class for plots that contain subplots and legends. To use this class, use `create_subplot` in `compute` each time you need a new axis. and call `finalize_one_page` in `compute` when a page is finished @@ -853,23 +851,14 @@ class GridSubplot(PlotBase): self._legend_loc = "upper center" if "upper" not in self._legend_loc and "lower" not in self._legend_loc: raise ValueError( - "Only best, (upper *), and (lower-*) legend " "locations are supported!" + "Only best, upper-*, and lower-* legend locations are supported!" ) - if "up" in self._legend_loc: - self._legend_grid_axis_number = 0 - self._grid_axis_offset = 1 - else: - self._legend_grid_axis_number = -1 - self._grid_axis_offset = 0 + self._nlegends = ctx.meta.get("legends_ncol", 3) # subplot grid self._nrows = ctx.meta.get("n_row", 1) self._ncols = ctx.meta.get("n_col", 1) - # GridSpec artificial rows and cols multipliers - self._row_times = 8 - self._col_times = 2 - def init_process(self): super(GridSubplot, self).init_process() self._create_grid_spec() @@ -877,29 +866,36 @@ class GridSubplot(PlotBase): def _create_grid_spec(self): # create a compatible GridSpec self._gs = gridspec.GridSpec( - self._nrows * self._row_times + 1, - self._ncols * self._col_times, + self._nrows, + self._ncols, figure=mpl.gcf(), ) def create_subplot(self, n, shared_axis=None): i, j = numpy.unravel_index(n, (self._nrows, self._ncols)) - i1 = i * self._row_times + self._grid_axis_offset - i2 = (i + 1) * self._row_times + self._grid_axis_offset - j1, j2 = j * self._col_times, (j + 1) * self._col_times - axis = mpl.gcf().add_subplot(self._gs[i1:i2, j1:j2], sharex=shared_axis) + axis = mpl.gcf().add_subplot(self._gs[i : i + 1, j : j + 1], sharex=shared_axis) return axis def finalize_one_page(self): # print legend on the page self.plot_legends() + fig = mpl.gcf() + axes = fig.get_axes() + + LOGGER.debug("%s contains %d axes:", fig, len(axes)) + for i, ax in enumerate(axes, start=1): + LOGGER.debug("Axes %d: %s", i, ax) + self._pdf_page.savefig(bbox_inches="tight") mpl.clf() mpl.figure() self._create_grid_spec() def plot_legends(self): - """ Print legend on current page""" + """Print legend on current page""" + if not self._disp_legend: + return + lines = [] labels = [] for ax in mpl.gcf().get_axes(): @@ -910,14 +906,27 @@ class GridSubplot(PlotBase): lines.append(li) labels.append(la) - if self._disp_legend: - # create legend on the top or bottom axis - ax = mpl.gcf().add_subplot(self._gs[self._legend_grid_axis_number, :]) - # right, left, or center - loc = self._legend_loc.split()[1] - ax.legend(lines, labels, loc=loc, ncol=self._nlegends) - # don't show its axis - ax.set_axis_off() + # create legend on the top or bottom axis + fig = mpl.gcf() + if "upper" in self._legend_loc: + # Set anchor to top of figure + bbox_to_anchor = (0.0, 1.0, 1.0, 0.0) + # Legend will be anchored with its bottom side, so switch the loc + anchored_loc = self._legend_loc.replace("upper", "lower") + else: + # Set anchor to bottom of figure + bbox_to_anchor = (0.0, 0.0, 1.0, 0.0) + # Legend will be anchored with its top side, so switch the loc + anchored_loc = self._legend_loc.replace("lower", "upper") + leg = fig.legend( + lines, + labels, + loc=anchored_loc, + ncol=self._nlegends, + bbox_to_anchor=bbox_to_anchor, + ) + + return leg class Hist(GridSubplot): @@ -971,7 +980,13 @@ class Hist(GridSubplot): if not self._hide_dev or not self._eval: dev_axis = self._print_subplot( - idx, sys, dev_neg, dev_pos, threshold, not self._no_line, False, + idx, + sys, + dev_neg, + dev_pos, + threshold, + not self._no_line, + False, ) if self._eval: @@ -1084,7 +1099,7 @@ class Hist(GridSubplot): mpl.axvline(x=threshold, ymin=0, ymax=1, **kwargs) def _setup_hist(self, neg, pos): - """ This function can be overwritten in derived classes + """This function can be overwritten in derived classes Plots all the density histo required in one plot. Here negative and positive scores densities.