Skip to content
Snippets Groups Projects
Commit cc138960 authored by Amir MOHAMMADI's avatar Amir MOHAMMADI
Browse files

Add a base class for subplot-based plots

parent 917ab50f
No related branches found
No related tags found
1 merge request!89Add a base class for subplot-based plots
Pipeline #
...@@ -4,11 +4,11 @@ from __future__ import division, print_function ...@@ -4,11 +4,11 @@ from __future__ import division, print_function
from abc import ABCMeta, abstractmethod from abc import ABCMeta, abstractmethod
import math import math
import sys import sys
import os.path
import numpy import numpy
import click import click
import matplotlib import matplotlib
import matplotlib.pyplot as mpl import matplotlib.pyplot as mpl
from matplotlib import gridspec
from matplotlib.backends.backend_pdf import PdfPages from matplotlib.backends.backend_pdf import PdfPages
from tabulate import tabulate from tabulate import tabulate
from .. import (far_threshold, plot, utils, ppndf) from .. import (far_threshold, plot, utils, ppndf)
...@@ -707,7 +707,93 @@ class Epc(PlotBase): ...@@ -707,7 +707,93 @@ class Epc(PlotBase):
) )
class Hist(PlotBase): class GridSubplot(PlotBase):
"""A base class for plots that contain subplots and legends. This is needed
because "constrained_layout will not work on subplots created via the
subplot command."
To use this class, use `create_subplot` in `compute` each time you need a
new axis. and call `finalize_one_page` in `compute` when a page is finished
rendering.
"""
def __init__(self, ctx, scores, evaluation, func_load):
super(GridSubplot, self).__init__(ctx, scores, evaluation, func_load)
# Check legend
self._legend_loc = self._legend_loc or 'upper center'
if self._legend_loc == 'best':
self._legend_loc = 'upper center'
if 'upper' not in self._legend_loc and \
'lower' not in self._legend_loc:
raise ValueError('Only best, (upper *), and (lower-*) legend '
'locations are supported!')
if 'up' in self._legend_loc:
self._legend_grid_axis_number = 0
self._grid_axis_offset = 1
else:
self._legend_grid_axis_number = -1
self._grid_axis_offset = 0
# subplot grid
self._nrows = ctx.meta.get('n_row', 1)
self._ncols = ctx.meta.get('n_col', 1)
# GridSpec artificial rows and cols multipliers
self._row_times = 8
self._col_times = 2
def init_process(self):
super(GridSubplot, self).init_process()
self._create_grid_spec()
def _create_grid_spec(self):
# create a compatible GridSpec
self._gs = gridspec.GridSpec(
self._nrows * self._row_times + 1,
self._ncols * self._col_times,
figure=mpl.gcf())
def create_subplot(self, n):
i, j = numpy.unravel_index(n, (self._nrows, self._ncols))
i1 = i * self._row_times + self._grid_axis_offset
i2 = (i + 1) * self._row_times + self._grid_axis_offset
j1, j2 = j * self._col_times, (j + 1) * self._col_times
axis = mpl.gcf().add_subplot(self._gs[i1:i2, j1:j2])
return axis
def finalize_one_page(self):
# print legend on the page
self.plot_legends()
self._pdf_page.savefig(bbox_inches='tight')
mpl.clf()
mpl.figure()
self._create_grid_spec()
def plot_legends(self):
''' Print legend on current page'''
lines = []
labels = []
for ax in mpl.gcf().get_axes():
ali, ala = ax.get_legend_handles_labels()
# avoid duplicates in legend
for li, la in zip(ali, ala):
if la not in labels:
lines.append(li)
labels.append(la)
if self._disp_legend:
# create legend on the top or bottom axis
ax = mpl.gcf().add_subplot(
self._gs[self._legend_grid_axis_number, :])
# right, left, or center
loc = self._legend_loc.split()[1]
ax.legend(lines, labels, loc=loc, ncol=self._nlegends)
# don't show its axis
ax.set_axis_off()
class Hist(GridSubplot):
''' Functional base class for histograms''' ''' Functional base class for histograms'''
def __init__(self, ctx, scores, evaluation, func_load, nhist_per_system=2): def __init__(self, ctx, scores, evaluation, func_load, nhist_per_system=2):
...@@ -723,9 +809,6 @@ class Hist(PlotBase): ...@@ -723,9 +809,6 @@ class Hist(PlotBase):
self._criterion = ctx.meta.get('criterion') self._criterion = ctx.meta.get('criterion')
# no vertical (threshold) is displayed # no vertical (threshold) is displayed
self._no_line = ctx.meta.get('no_line', False) self._no_line = ctx.meta.get('no_line', False)
# subplot grid
self._nrows = ctx.meta.get('n_row', 1)
self._ncols = ctx.meta.get('n_col', 1)
# do not display dev histo # do not display dev histo
self._hide_dev = ctx.meta.get('hide_dev', False) self._hide_dev = ctx.meta.get('hide_dev', False)
if self._hide_dev and not self._eval: if self._hide_dev and not self._eval:
...@@ -734,7 +817,7 @@ class Hist(PlotBase): ...@@ -734,7 +817,7 @@ class Hist(PlotBase):
# dev hist are displayed next to eval hist # dev hist are displayed next to eval hist
self._nrows *= 1 if self._hide_dev or not self._eval else 2 self._nrows *= 1 if self._hide_dev or not self._eval else 2
self._nlegends = ctx.meta.get('legends_ncol', 3) self._nlegends = ctx.meta.get('legends_ncol', 3)
self._legend_loc = self._legend_loc or 'upper center'
# number of subplot on one page # number of subplot on one page
self._step_print = int(self._nrows * self._ncols) self._step_print = int(self._nrows * self._ncols)
self._title_base = 'Scores' self._title_base = 'Scores'
...@@ -766,12 +849,13 @@ class Hist(PlotBase): ...@@ -766,12 +849,13 @@ class Hist(PlotBase):
self._print_subplot(idx, sys, eval_neg, eval_pos, threshold, self._print_subplot(idx, sys, eval_neg, eval_pos, threshold,
not self._no_line, True) not self._no_line, True)
def _print_subplot(self, idx, sys, neg, pos, threshold, draw_line, evaluation): def _print_subplot(self, idx, sys, neg, pos, threshold, draw_line,
evaluation):
''' print a subplot for the given score and subplot index''' ''' print a subplot for the given score and subplot index'''
n = idx % self._step_print n = idx % self._step_print
col = n % self._ncols col = n % self._ncols
sub_plot_idx = n + 1 sub_plot_idx = n + 1
axis = mpl.subplot(self._nrows, self._ncols, sub_plot_idx) axis = self.create_subplot(n)
self._setup_hist(neg, pos) self._setup_hist(neg, pos)
if col == 0: if col == 0:
axis.set_ylabel(self._y_label) axis.set_ylabel(self._y_label)
...@@ -805,12 +889,7 @@ class Hist(PlotBase): ...@@ -805,12 +889,7 @@ class Hist(PlotBase):
# to display, save figure # to display, save figure
if self._step_print == sub_plot_idx or (is_lower and sys == if self._step_print == sub_plot_idx or (is_lower and sys ==
self.n_systems - 1): self.n_systems - 1):
# print legend on the page self.finalize_one_page()
self.plot_legends()
mpl.tight_layout()
self._pdf_page.savefig(mpl.gcf(), bbox_inches='tight')
mpl.clf()
mpl.figure()
def _get_title(self, idx, dflt=None): def _get_title(self, idx, dflt=None):
''' Get the histo title for the given idx''' ''' Get the histo title for the given idx'''
...@@ -821,25 +900,6 @@ class Hist(PlotBase): ...@@ -821,25 +900,6 @@ class Hist(PlotBase):
' ', '') else title ' ', '') else title
return title or '' return title or ''
def plot_legends(self):
''' Print legend on current page'''
lines = []
labels = []
for ax in mpl.gcf().get_axes():
ali, ala = ax.get_legend_handles_labels()
# avoid duplicates in legend
for li, la in zip(ali, ala):
if la not in labels:
lines.append(li)
labels.append(la)
if self._disp_legend:
mpl.gcf().legend(
lines, labels, loc=self._legend_loc, fancybox=True,
framealpha=0.5, ncol=self._nlegends,
bbox_to_anchor=(0.55, 1.1),
)
def _get_neg_pos_thres(self, idx, input_scores, input_names): def _get_neg_pos_thres(self, idx, input_scores, input_names):
''' Get scores and threshod for the given system at index idx''' ''' Get scores and threshod for the given system at index idx'''
neg_list, pos_list, _ = utils.get_fta_list(input_scores) neg_list, pos_list, _ = utils.get_fta_list(input_scores)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment