Commit 4cdead48 authored by Amir MOHAMMADI's avatar Amir MOHAMMADI

Merge branch 'fixlegends' into 'master'

Fix issue with histo legends

Closes #47

See merge request !75
parents 69db6c71 255c4b4d
Pipeline #21248 passed with stages
in 23 minutes and 3 seconds
......@@ -187,7 +187,7 @@ def x_rotation_option(dflt=0, **kwargs):
return custom_x_rotation_option
def legend_ncols_option(dflt=10, **kwargs):
def legend_ncols_option(dflt=3, **kwargs):
'''Get option for number of columns for legends'''
def custom_legend_ncols_option(func):
def callback(ctx, param, value):
......
......@@ -571,14 +571,18 @@ class Hist(PlotBase):
self._thres = check_list_value(
self._thres, self.n_systems, 'thresholds')
self._criterion = ctx.meta.get('criterion')
# no vertical (threshold) is displayed
self._no_line = ctx.meta.get('no_line', False)
# subplot grid
self._nrows = ctx.meta.get('n_row', 1)
self._ncols = ctx.meta.get('n_col', 1)
# do not display dev histo
self._hide_dev = ctx.meta.get('hide_dev', False)
# dev hist are displayed next to eval hist
self._ncols *= 1 if self._hide_dev else 2
self._nlegends = ctx.meta.get('legends_ncol', 10)
self._nlegends = ctx.meta.get('legends_ncol', 3)
self._legend_loc = self._legend_loc or 'upper center'
# number of subplot on one page
self._step_print = int(self._nrows * self._ncols)
self._title_base = 'Scores'
self._y_label = 'Probability density'
......@@ -586,6 +590,7 @@ class Hist(PlotBase):
self._end_setup_plot = False
if self._legends is not None and len(self._legends) == self.n_systems \
and not self._hide_dev:
# use same legend for dev and eval if needed
self._legends = [x for pair in zip(self._legends,self._legends)
for x in pair]
......@@ -605,6 +610,7 @@ class Hist(PlotBase):
not self._no_line, dflt_title="Eval scores")
def _print_subplot(self, idx, neg, pos, threshold, draw_line, dflt_title):
''' print a subplot for the given score and subplot index'''
n = idx % self._step_print
col = n % self._ncols
sub_plot_idx = n + 1
......@@ -624,16 +630,20 @@ class Hist(PlotBase):
)
if draw_line:
self._lines(threshold, label, neg, pos, idx)
if sub_plot_idx == 1:
self._plot_legends()
mult = 2 if self._eval and not self._hide_dev else 1
# if it was the last subplot of the page or the last subplot
# to display, save figure
if self._step_print == sub_plot_idx or idx == self.n_systems * mult - 1:
# print legend on the page
self.plot_legends()
mpl.tight_layout()
self._pdf_page.savefig(mpl.gcf(), bbox_inches='tight')
mpl.clf()
mpl.figure()
def _get_title(self, idx, dflt=None):
''' Get the histo title for the given idx'''
title = self._legends[idx] if self._legends is not None \
and idx < len(self._legends) else dflt
title = title or self._title_base
......@@ -641,21 +651,27 @@ class Hist(PlotBase):
' ', '') else title
return title or ''
def _plot_legends(self):
def plot_legends(self):
''' Print legend on current page'''
lines = []
labels = []
for ax in mpl.gcf().get_axes():
li, la = ax.get_legend_handles_labels()
lines += li
labels += la
ali, ala = ax.get_legend_handles_labels()
# avoid duplicates in legend
for li, la in zip(ali, ala):
if la not in labels:
lines.append(li)
labels.append(la)
if self._disp_legend:
mpl.gcf().legend(
lines, labels, loc=self._legend_loc, fancybox=True,
framealpha=0.5, ncol=self._nlegends,
bbox_to_anchor=(0.55, 1.06),
bbox_to_anchor=(0.55, 1.1),
)
def _get_neg_pos_thres(self, idx, input_scores, input_names):
''' Get scores and threshod for the given system at index idx'''
neg_list, pos_list, _ = utils.get_fta_list(input_scores)
length = len(neg_list)
# can have several files for one system
......@@ -672,6 +688,7 @@ class Hist(PlotBase):
return dev_neg, dev_pos, eval_neg, eval_pos, threshold
def _density_hist(self, scores, n, **kwargs):
''' Plots one density histo'''
n, bins, patches = mpl.hist(
scores, density=True,
bins=self._nbins[n],
......@@ -681,6 +698,7 @@ class Hist(PlotBase):
def _lines(self, threshold, label=None, neg=None, pos=None,
idx=None, **kwargs):
''' Plots vertical line at threshold '''
label = label or 'Threshold'
kwargs.setdefault('color', 'C3')
kwargs.setdefault('linestyle', '--')
......@@ -689,7 +707,11 @@ class Hist(PlotBase):
mpl.axvline(x=threshold, ymin=0, ymax=1, **kwargs)
def _setup_hist(self, neg, pos):
''' This function can be overwritten in derived classes'''
''' This function can be overwritten in derived classes
Plots all the density histo required in one plot. Here negative and
positive scores densities.
'''
self._density_hist(
neg[0], n=0,
label='Negatives', alpha=0.5, color='C3'
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment