Commit 4cdead48 authored by Amir MOHAMMADI's avatar Amir MOHAMMADI

Merge branch 'fixlegends' into 'master'

Fix issue with histo legends

Closes #47

See merge request !75
parents 69db6c71 255c4b4d
Pipeline #21248 passed with stages
in 23 minutes and 3 seconds
...@@ -187,7 +187,7 @@ def x_rotation_option(dflt=0, **kwargs): ...@@ -187,7 +187,7 @@ def x_rotation_option(dflt=0, **kwargs):
return custom_x_rotation_option return custom_x_rotation_option
def legend_ncols_option(dflt=10, **kwargs): def legend_ncols_option(dflt=3, **kwargs):
'''Get option for number of columns for legends''' '''Get option for number of columns for legends'''
def custom_legend_ncols_option(func): def custom_legend_ncols_option(func):
def callback(ctx, param, value): def callback(ctx, param, value):
......
...@@ -571,14 +571,18 @@ class Hist(PlotBase): ...@@ -571,14 +571,18 @@ class Hist(PlotBase):
self._thres = check_list_value( self._thres = check_list_value(
self._thres, self.n_systems, 'thresholds') self._thres, self.n_systems, 'thresholds')
self._criterion = ctx.meta.get('criterion') self._criterion = ctx.meta.get('criterion')
# no vertical (threshold) is displayed
self._no_line = ctx.meta.get('no_line', False) self._no_line = ctx.meta.get('no_line', False)
# subplot grid
self._nrows = ctx.meta.get('n_row', 1) self._nrows = ctx.meta.get('n_row', 1)
self._ncols = ctx.meta.get('n_col', 1) self._ncols = ctx.meta.get('n_col', 1)
# do not display dev histo
self._hide_dev = ctx.meta.get('hide_dev', False) self._hide_dev = ctx.meta.get('hide_dev', False)
# dev hist are displayed next to eval hist # dev hist are displayed next to eval hist
self._ncols *= 1 if self._hide_dev else 2 self._ncols *= 1 if self._hide_dev else 2
self._nlegends = ctx.meta.get('legends_ncol', 10) self._nlegends = ctx.meta.get('legends_ncol', 3)
self._legend_loc = self._legend_loc or 'upper center' self._legend_loc = self._legend_loc or 'upper center'
# number of subplot on one page
self._step_print = int(self._nrows * self._ncols) self._step_print = int(self._nrows * self._ncols)
self._title_base = 'Scores' self._title_base = 'Scores'
self._y_label = 'Probability density' self._y_label = 'Probability density'
...@@ -586,6 +590,7 @@ class Hist(PlotBase): ...@@ -586,6 +590,7 @@ class Hist(PlotBase):
self._end_setup_plot = False self._end_setup_plot = False
if self._legends is not None and len(self._legends) == self.n_systems \ if self._legends is not None and len(self._legends) == self.n_systems \
and not self._hide_dev: and not self._hide_dev:
# use same legend for dev and eval if needed
self._legends = [x for pair in zip(self._legends,self._legends) self._legends = [x for pair in zip(self._legends,self._legends)
for x in pair] for x in pair]
...@@ -605,6 +610,7 @@ class Hist(PlotBase): ...@@ -605,6 +610,7 @@ class Hist(PlotBase):
not self._no_line, dflt_title="Eval scores") not self._no_line, dflt_title="Eval scores")
def _print_subplot(self, idx, neg, pos, threshold, draw_line, dflt_title): def _print_subplot(self, idx, neg, pos, threshold, draw_line, dflt_title):
''' print a subplot for the given score and subplot index'''
n = idx % self._step_print n = idx % self._step_print
col = n % self._ncols col = n % self._ncols
sub_plot_idx = n + 1 sub_plot_idx = n + 1
...@@ -624,16 +630,20 @@ class Hist(PlotBase): ...@@ -624,16 +630,20 @@ class Hist(PlotBase):
) )
if draw_line: if draw_line:
self._lines(threshold, label, neg, pos, idx) self._lines(threshold, label, neg, pos, idx)
if sub_plot_idx == 1:
self._plot_legends()
mult = 2 if self._eval and not self._hide_dev else 1 mult = 2 if self._eval and not self._hide_dev else 1
# if it was the last subplot of the page or the last subplot
# to display, save figure
if self._step_print == sub_plot_idx or idx == self.n_systems * mult - 1: if self._step_print == sub_plot_idx or idx == self.n_systems * mult - 1:
# print legend on the page
self.plot_legends()
mpl.tight_layout() mpl.tight_layout()
self._pdf_page.savefig(mpl.gcf(), bbox_inches='tight') self._pdf_page.savefig(mpl.gcf(), bbox_inches='tight')
mpl.clf() mpl.clf()
mpl.figure() mpl.figure()
def _get_title(self, idx, dflt=None): def _get_title(self, idx, dflt=None):
''' Get the histo title for the given idx'''
title = self._legends[idx] if self._legends is not None \ title = self._legends[idx] if self._legends is not None \
and idx < len(self._legends) else dflt and idx < len(self._legends) else dflt
title = title or self._title_base title = title or self._title_base
...@@ -641,21 +651,27 @@ class Hist(PlotBase): ...@@ -641,21 +651,27 @@ class Hist(PlotBase):
' ', '') else title ' ', '') else title
return title or '' return title or ''
def _plot_legends(self): def plot_legends(self):
''' Print legend on current page'''
lines = [] lines = []
labels = [] labels = []
for ax in mpl.gcf().get_axes(): for ax in mpl.gcf().get_axes():
li, la = ax.get_legend_handles_labels() ali, ala = ax.get_legend_handles_labels()
lines += li # avoid duplicates in legend
labels += la for li, la in zip(ali, ala):
if la not in labels:
lines.append(li)
labels.append(la)
if self._disp_legend: if self._disp_legend:
mpl.gcf().legend( mpl.gcf().legend(
lines, labels, loc=self._legend_loc, fancybox=True, lines, labels, loc=self._legend_loc, fancybox=True,
framealpha=0.5, ncol=self._nlegends, framealpha=0.5, ncol=self._nlegends,
bbox_to_anchor=(0.55, 1.06), bbox_to_anchor=(0.55, 1.1),
) )
def _get_neg_pos_thres(self, idx, input_scores, input_names): def _get_neg_pos_thres(self, idx, input_scores, input_names):
''' Get scores and threshod for the given system at index idx'''
neg_list, pos_list, _ = utils.get_fta_list(input_scores) neg_list, pos_list, _ = utils.get_fta_list(input_scores)
length = len(neg_list) length = len(neg_list)
# can have several files for one system # can have several files for one system
...@@ -672,6 +688,7 @@ class Hist(PlotBase): ...@@ -672,6 +688,7 @@ class Hist(PlotBase):
return dev_neg, dev_pos, eval_neg, eval_pos, threshold return dev_neg, dev_pos, eval_neg, eval_pos, threshold
def _density_hist(self, scores, n, **kwargs): def _density_hist(self, scores, n, **kwargs):
''' Plots one density histo'''
n, bins, patches = mpl.hist( n, bins, patches = mpl.hist(
scores, density=True, scores, density=True,
bins=self._nbins[n], bins=self._nbins[n],
...@@ -681,6 +698,7 @@ class Hist(PlotBase): ...@@ -681,6 +698,7 @@ class Hist(PlotBase):
def _lines(self, threshold, label=None, neg=None, pos=None, def _lines(self, threshold, label=None, neg=None, pos=None,
idx=None, **kwargs): idx=None, **kwargs):
''' Plots vertical line at threshold '''
label = label or 'Threshold' label = label or 'Threshold'
kwargs.setdefault('color', 'C3') kwargs.setdefault('color', 'C3')
kwargs.setdefault('linestyle', '--') kwargs.setdefault('linestyle', '--')
...@@ -689,7 +707,11 @@ class Hist(PlotBase): ...@@ -689,7 +707,11 @@ class Hist(PlotBase):
mpl.axvline(x=threshold, ymin=0, ymax=1, **kwargs) mpl.axvline(x=threshold, ymin=0, ymax=1, **kwargs)
def _setup_hist(self, neg, pos): def _setup_hist(self, neg, pos):
''' This function can be overwritten in derived classes''' ''' This function can be overwritten in derived classes
Plots all the density histo required in one plot. Here negative and
positive scores densities.
'''
self._density_hist( self._density_hist(
neg[0], n=0, neg[0], n=0,
label='Negatives', alpha=0.5, color='C3' label='Negatives', alpha=0.5, color='C3'
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment