diff --git a/bob/measure/script/figure.py b/bob/measure/script/figure.py
index 768d0f738a2095246bbfe3a95b33e8a55629fd91..031ae9ed982d0d8fc6b5dd9df993e7813d33dd10 100644
--- a/bob/measure/script/figure.py
+++ b/bob/measure/script/figure.py
@@ -74,7 +74,7 @@ class MeasureBase(object):
             )
 
     def run(self):
-        """ Generate outputs (e.g. metrics, files, pdf plots).
+        """Generate outputs (e.g. metrics, files, pdf plots).
         This function calls abstract methods
         :func:`~bob.measure.script.figure.MeasureBase.init_process` (before
         loop), :py:func:`~bob.measure.script.figure.MeasureBase.compute`
@@ -117,7 +117,7 @@ class MeasureBase(object):
 
     # protected functions that need to be overwritten
     def init_process(self):
-        """ Called in :py:func:`~bob.measure.script.figure.MeasureBase`.run
+        """Called in :py:func:`~bob.measure.script.figure.MeasureBase`.run
         before iterating through the different systems.
         Should reimplemented in derived classes"""
         pass
@@ -149,7 +149,7 @@ class MeasureBase(object):
     # Things to do after the main iterative computations are done
     @abstractmethod
     def end_process(self):
-        """ Called in :py:func:`~bob.measure.script.figure.MeasureBase`.run
+        """Called in :py:func:`~bob.measure.script.figure.MeasureBase`.run
         after iterating through the different systems.
         Should reimplemented in derived classes"""
         pass
@@ -157,7 +157,7 @@ class MeasureBase(object):
     # common protected functions
 
     def _load_files(self, filepaths):
-        """ Load the input files and return the base names of the files
+        """Load the input files and return the base names of the files
 
         Returns
         -------
@@ -176,7 +176,7 @@ class MeasureBase(object):
 
 
 class Metrics(MeasureBase):
-    """ Compute metrics from score files
+    """Compute metrics from score files
 
     Attributes
     ----------
@@ -347,7 +347,7 @@ class Metrics(MeasureBase):
         return res
 
     def compute(self, idx, input_scores, input_names):
-        """ Compute metrics thresholds and tables (FPR, FNR, precision, recall,
+        """Compute metrics thresholds and tables (FPR, FNR, precision, recall,
         f1_score) for given system inputs"""
         dev_file = input_names[0]
         title = self._legends[idx] if self._legends is not None else None
@@ -509,7 +509,7 @@ class MultiMetrics(Metrics):
 
 
 class PlotBase(MeasureBase):
-    """ Base class for plots. Regroup several options and code
+    """Base class for plots. Regroup several options and code
     shared by the different plots
     """
 
@@ -573,8 +573,8 @@ class PlotBase(MeasureBase):
             fig.clear()
 
     def end_process(self):
-        """ Set title, legend, axis labels, grid colors, save figures, drow
-        lines and close pdf if needed """
+        """Set title, legend, axis labels, grid colors, save figures, drow
+        lines and close pdf if needed"""
         # draw vertical lines
         if self._far_at is not None:
             for (line, line_trans) in zip(self._far_at, self._trans_far_val):
@@ -642,7 +642,7 @@ class Roc(PlotBase):
         self._min_dig = -4 if self._min_dig is None else self._min_dig
 
     def compute(self, idx, input_scores, input_names):
-        """ Plot ROC for dev and eval data using
+        """Plot ROC for dev and eval data using
         :py:func:`bob.measure.plot.roc`"""
         neg_list, pos_list, _ = utils.get_fta_list(input_scores)
         dev_neg, dev_pos = neg_list[0], pos_list[0]
@@ -733,7 +733,7 @@ class Det(PlotBase):
         self._min_dig = -4 if self._min_dig is None else self._min_dig
 
     def compute(self, idx, input_scores, input_names):
-        """ Plot DET for dev and eval data using
+        """Plot DET for dev and eval data using
         :py:func:`bob.measure.plot.det`"""
         neg_list, pos_list, _ = utils.get_fta_list(input_scores)
         dev_neg, dev_pos = neg_list[0], pos_list[0]
@@ -835,9 +835,7 @@ class Epc(PlotBase):
 
 
 class GridSubplot(PlotBase):
-    """A base class for plots that contain subplots and legends. This is needed
-    because "constrained_layout will not work on subplots created via the
-    subplot command."
+    """A base class for plots that contain subplots and legends.
 
     To use this class, use `create_subplot` in `compute` each time you need a
     new axis. and call `finalize_one_page` in `compute` when a page is finished
@@ -853,23 +851,14 @@ class GridSubplot(PlotBase):
             self._legend_loc = "upper center"
         if "upper" not in self._legend_loc and "lower" not in self._legend_loc:
             raise ValueError(
-                "Only best, (upper *), and (lower-*) legend " "locations are supported!"
+                "Only best, upper-*, and lower-* legend locations are supported!"
             )
-        if "up" in self._legend_loc:
-            self._legend_grid_axis_number = 0
-            self._grid_axis_offset = 1
-        else:
-            self._legend_grid_axis_number = -1
-            self._grid_axis_offset = 0
+        self._nlegends = ctx.meta.get("legends_ncol", 3)
 
         # subplot grid
         self._nrows = ctx.meta.get("n_row", 1)
         self._ncols = ctx.meta.get("n_col", 1)
 
-        # GridSpec artificial rows and cols multipliers
-        self._row_times = 8
-        self._col_times = 2
-
     def init_process(self):
         super(GridSubplot, self).init_process()
         self._create_grid_spec()
@@ -877,29 +866,36 @@ class GridSubplot(PlotBase):
     def _create_grid_spec(self):
         # create a compatible GridSpec
         self._gs = gridspec.GridSpec(
-            self._nrows * self._row_times + 1,
-            self._ncols * self._col_times,
+            self._nrows,
+            self._ncols,
             figure=mpl.gcf(),
         )
 
     def create_subplot(self, n, shared_axis=None):
         i, j = numpy.unravel_index(n, (self._nrows, self._ncols))
-        i1 = i * self._row_times + self._grid_axis_offset
-        i2 = (i + 1) * self._row_times + self._grid_axis_offset
-        j1, j2 = j * self._col_times, (j + 1) * self._col_times
-        axis = mpl.gcf().add_subplot(self._gs[i1:i2, j1:j2], sharex=shared_axis)
+        axis = mpl.gcf().add_subplot(self._gs[i : i + 1, j : j + 1], sharex=shared_axis)
         return axis
 
     def finalize_one_page(self):
         # print legend on the page
         self.plot_legends()
+        fig = mpl.gcf()
+        axes = fig.get_axes()
+
+        LOGGER.debug("%s contains %d axes:", fig, len(axes))
+        for i, ax in enumerate(axes, start=1):
+            LOGGER.debug("Axes %d: %s", i, ax)
+
         self._pdf_page.savefig(bbox_inches="tight")
         mpl.clf()
         mpl.figure()
         self._create_grid_spec()
 
     def plot_legends(self):
-        """ Print legend on current page"""
+        """Print legend on current page"""
+        if not self._disp_legend:
+            return
+
         lines = []
         labels = []
         for ax in mpl.gcf().get_axes():
@@ -910,14 +906,27 @@ class GridSubplot(PlotBase):
                     lines.append(li)
                     labels.append(la)
 
-        if self._disp_legend:
-            # create legend on the top or bottom axis
-            ax = mpl.gcf().add_subplot(self._gs[self._legend_grid_axis_number, :])
-            # right, left, or center
-            loc = self._legend_loc.split()[1]
-            ax.legend(lines, labels, loc=loc, ncol=self._nlegends)
-            # don't show its axis
-            ax.set_axis_off()
+        # create legend on the top or bottom axis
+        fig = mpl.gcf()
+        if "upper" in self._legend_loc:
+            # Set anchor to top of figure
+            bbox_to_anchor = (0.0, 1.0, 1.0, 0.0)
+            # Legend will be anchored with its bottom side, so switch the loc
+            anchored_loc = self._legend_loc.replace("upper", "lower")
+        else:
+            # Set anchor to bottom of figure
+            bbox_to_anchor = (0.0, 0.0, 1.0, 0.0)
+            # Legend will be anchored with its top side, so switch the loc
+            anchored_loc = self._legend_loc.replace("lower", "upper")
+        leg = fig.legend(
+            lines,
+            labels,
+            loc=anchored_loc,
+            ncol=self._nlegends,
+            bbox_to_anchor=bbox_to_anchor,
+        )
+
+        return leg
 
 
 class Hist(GridSubplot):
@@ -971,7 +980,13 @@ class Hist(GridSubplot):
 
         if not self._hide_dev or not self._eval:
             dev_axis = self._print_subplot(
-                idx, sys, dev_neg, dev_pos, threshold, not self._no_line, False,
+                idx,
+                sys,
+                dev_neg,
+                dev_pos,
+                threshold,
+                not self._no_line,
+                False,
             )
 
         if self._eval:
@@ -1084,7 +1099,7 @@ class Hist(GridSubplot):
         mpl.axvline(x=threshold, ymin=0, ymax=1, **kwargs)
 
     def _setup_hist(self, neg, pos):
-        """ This function can be overwritten in derived classes
+        """This function can be overwritten in derived classes
 
         Plots all the density histo required in one plot. Here negative and
         positive scores densities.