Skip to content
Snippets Groups Projects
Commit cc138960 authored by Amir MOHAMMADI's avatar Amir MOHAMMADI
Browse files

Add a base class for subplot-based plots

parent 917ab50f
Branches
Tags
1 merge request!89Add a base class for subplot-based plots
Pipeline #
......@@ -4,11 +4,11 @@ from __future__ import division, print_function
from abc import ABCMeta, abstractmethod
import math
import sys
import os.path
import numpy
import click
import matplotlib
import matplotlib.pyplot as mpl
from matplotlib import gridspec
from matplotlib.backends.backend_pdf import PdfPages
from tabulate import tabulate
from .. import (far_threshold, plot, utils, ppndf)
......@@ -707,7 +707,93 @@ class Epc(PlotBase):
)
class Hist(PlotBase):
class GridSubplot(PlotBase):
"""A base class for plots that contain subplots and legends. This is needed
because "constrained_layout will not work on subplots created via the
subplot command."
To use this class, use `create_subplot` in `compute` each time you need a
new axis. and call `finalize_one_page` in `compute` when a page is finished
rendering.
"""
def __init__(self, ctx, scores, evaluation, func_load):
super(GridSubplot, self).__init__(ctx, scores, evaluation, func_load)
# Check legend
self._legend_loc = self._legend_loc or 'upper center'
if self._legend_loc == 'best':
self._legend_loc = 'upper center'
if 'upper' not in self._legend_loc and \
'lower' not in self._legend_loc:
raise ValueError('Only best, (upper *), and (lower-*) legend '
'locations are supported!')
if 'up' in self._legend_loc:
self._legend_grid_axis_number = 0
self._grid_axis_offset = 1
else:
self._legend_grid_axis_number = -1
self._grid_axis_offset = 0
# subplot grid
self._nrows = ctx.meta.get('n_row', 1)
self._ncols = ctx.meta.get('n_col', 1)
# GridSpec artificial rows and cols multipliers
self._row_times = 8
self._col_times = 2
def init_process(self):
super(GridSubplot, self).init_process()
self._create_grid_spec()
def _create_grid_spec(self):
# create a compatible GridSpec
self._gs = gridspec.GridSpec(
self._nrows * self._row_times + 1,
self._ncols * self._col_times,
figure=mpl.gcf())
def create_subplot(self, n):
i, j = numpy.unravel_index(n, (self._nrows, self._ncols))
i1 = i * self._row_times + self._grid_axis_offset
i2 = (i + 1) * self._row_times + self._grid_axis_offset
j1, j2 = j * self._col_times, (j + 1) * self._col_times
axis = mpl.gcf().add_subplot(self._gs[i1:i2, j1:j2])
return axis
def finalize_one_page(self):
# print legend on the page
self.plot_legends()
self._pdf_page.savefig(bbox_inches='tight')
mpl.clf()
mpl.figure()
self._create_grid_spec()
def plot_legends(self):
''' Print legend on current page'''
lines = []
labels = []
for ax in mpl.gcf().get_axes():
ali, ala = ax.get_legend_handles_labels()
# avoid duplicates in legend
for li, la in zip(ali, ala):
if la not in labels:
lines.append(li)
labels.append(la)
if self._disp_legend:
# create legend on the top or bottom axis
ax = mpl.gcf().add_subplot(
self._gs[self._legend_grid_axis_number, :])
# right, left, or center
loc = self._legend_loc.split()[1]
ax.legend(lines, labels, loc=loc, ncol=self._nlegends)
# don't show its axis
ax.set_axis_off()
class Hist(GridSubplot):
''' Functional base class for histograms'''
def __init__(self, ctx, scores, evaluation, func_load, nhist_per_system=2):
......@@ -723,9 +809,6 @@ class Hist(PlotBase):
self._criterion = ctx.meta.get('criterion')
# no vertical (threshold) is displayed
self._no_line = ctx.meta.get('no_line', False)
# subplot grid
self._nrows = ctx.meta.get('n_row', 1)
self._ncols = ctx.meta.get('n_col', 1)
# do not display dev histo
self._hide_dev = ctx.meta.get('hide_dev', False)
if self._hide_dev and not self._eval:
......@@ -734,7 +817,7 @@ class Hist(PlotBase):
# dev hist are displayed next to eval hist
self._nrows *= 1 if self._hide_dev or not self._eval else 2
self._nlegends = ctx.meta.get('legends_ncol', 3)
self._legend_loc = self._legend_loc or 'upper center'
# number of subplot on one page
self._step_print = int(self._nrows * self._ncols)
self._title_base = 'Scores'
......@@ -766,12 +849,13 @@ class Hist(PlotBase):
self._print_subplot(idx, sys, eval_neg, eval_pos, threshold,
not self._no_line, True)
def _print_subplot(self, idx, sys, neg, pos, threshold, draw_line, evaluation):
def _print_subplot(self, idx, sys, neg, pos, threshold, draw_line,
evaluation):
''' print a subplot for the given score and subplot index'''
n = idx % self._step_print
col = n % self._ncols
sub_plot_idx = n + 1
axis = mpl.subplot(self._nrows, self._ncols, sub_plot_idx)
axis = self.create_subplot(n)
self._setup_hist(neg, pos)
if col == 0:
axis.set_ylabel(self._y_label)
......@@ -805,12 +889,7 @@ class Hist(PlotBase):
# to display, save figure
if self._step_print == sub_plot_idx or (is_lower and sys ==
self.n_systems - 1):
# print legend on the page
self.plot_legends()
mpl.tight_layout()
self._pdf_page.savefig(mpl.gcf(), bbox_inches='tight')
mpl.clf()
mpl.figure()
self.finalize_one_page()
def _get_title(self, idx, dflt=None):
''' Get the histo title for the given idx'''
......@@ -821,25 +900,6 @@ class Hist(PlotBase):
' ', '') else title
return title or ''
def plot_legends(self):
''' Print legend on current page'''
lines = []
labels = []
for ax in mpl.gcf().get_axes():
ali, ala = ax.get_legend_handles_labels()
# avoid duplicates in legend
for li, la in zip(ali, ala):
if la not in labels:
lines.append(li)
labels.append(la)
if self._disp_legend:
mpl.gcf().legend(
lines, labels, loc=self._legend_loc, fancybox=True,
framealpha=0.5, ncol=self._nlegends,
bbox_to_anchor=(0.55, 1.1),
)
def _get_neg_pos_thres(self, idx, input_scores, input_names):
''' Get scores and threshod for the given system at index idx'''
neg_list, pos_list, _ = utils.get_fta_list(input_scores)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment