Skip to content
Snippets Groups Projects
Commit 255c4b4d authored by Theophile GENTILHOMME's avatar Theophile GENTILHOMME
Browse files

[script][figure] Fix bug with histo legends and add comments/doc in the code

parent 4095e2ac
Branches
Tags
1 merge request!75Fix issue with histo legends
Pipeline #
...@@ -571,14 +571,18 @@ class Hist(PlotBase): ...@@ -571,14 +571,18 @@ class Hist(PlotBase):
self._thres = check_list_value( self._thres = check_list_value(
self._thres, self.n_systems, 'thresholds') self._thres, self.n_systems, 'thresholds')
self._criterion = ctx.meta.get('criterion') self._criterion = ctx.meta.get('criterion')
# no vertical (threshold) is displayed
self._no_line = ctx.meta.get('no_line', False) self._no_line = ctx.meta.get('no_line', False)
# subplot grid
self._nrows = ctx.meta.get('n_row', 1) self._nrows = ctx.meta.get('n_row', 1)
self._ncols = ctx.meta.get('n_col', 1) self._ncols = ctx.meta.get('n_col', 1)
# do not display dev histo
self._hide_dev = ctx.meta.get('hide_dev', False) self._hide_dev = ctx.meta.get('hide_dev', False)
# dev hist are displayed next to eval hist # dev hist are displayed next to eval hist
self._ncols *= 1 if self._hide_dev else 2 self._ncols *= 1 if self._hide_dev else 2
self._nlegends = ctx.meta.get('legends_ncol', 10) self._nlegends = ctx.meta.get('legends_ncol', 3)
self._legend_loc = self._legend_loc or 'upper center' self._legend_loc = self._legend_loc or 'upper center'
# number of subplot on one page
self._step_print = int(self._nrows * self._ncols) self._step_print = int(self._nrows * self._ncols)
self._title_base = 'Scores' self._title_base = 'Scores'
self._y_label = 'Probability density' self._y_label = 'Probability density'
...@@ -586,6 +590,7 @@ class Hist(PlotBase): ...@@ -586,6 +590,7 @@ class Hist(PlotBase):
self._end_setup_plot = False self._end_setup_plot = False
if self._legends is not None and len(self._legends) == self.n_systems \ if self._legends is not None and len(self._legends) == self.n_systems \
and not self._hide_dev: and not self._hide_dev:
# use same legend for dev and eval if needed
self._legends = [x for pair in zip(self._legends,self._legends) self._legends = [x for pair in zip(self._legends,self._legends)
for x in pair] for x in pair]
...@@ -605,6 +610,7 @@ class Hist(PlotBase): ...@@ -605,6 +610,7 @@ class Hist(PlotBase):
not self._no_line, dflt_title="Eval scores") not self._no_line, dflt_title="Eval scores")
def _print_subplot(self, idx, neg, pos, threshold, draw_line, dflt_title): def _print_subplot(self, idx, neg, pos, threshold, draw_line, dflt_title):
''' print a subplot for the given score and subplot index'''
n = idx % self._step_print n = idx % self._step_print
col = n % self._ncols col = n % self._ncols
sub_plot_idx = n + 1 sub_plot_idx = n + 1
...@@ -624,16 +630,20 @@ class Hist(PlotBase): ...@@ -624,16 +630,20 @@ class Hist(PlotBase):
) )
if draw_line: if draw_line:
self._lines(threshold, label, neg, pos, idx) self._lines(threshold, label, neg, pos, idx)
if sub_plot_idx == 1:
self._plot_legends()
mult = 2 if self._eval and not self._hide_dev else 1 mult = 2 if self._eval and not self._hide_dev else 1
# if it was the last subplot of the page or the last subplot
# to display, save figure
if self._step_print == sub_plot_idx or idx == self.n_systems * mult - 1: if self._step_print == sub_plot_idx or idx == self.n_systems * mult - 1:
# print legend on the page
self.plot_legends()
mpl.tight_layout() mpl.tight_layout()
self._pdf_page.savefig(mpl.gcf(), bbox_inches='tight') self._pdf_page.savefig(mpl.gcf(), bbox_inches='tight')
mpl.clf() mpl.clf()
mpl.figure() mpl.figure()
def _get_title(self, idx, dflt=None): def _get_title(self, idx, dflt=None):
''' Get the histo title for the given idx'''
title = self._legends[idx] if self._legends is not None \ title = self._legends[idx] if self._legends is not None \
and idx < len(self._legends) else dflt and idx < len(self._legends) else dflt
title = title or self._title_base title = title or self._title_base
...@@ -641,21 +651,27 @@ class Hist(PlotBase): ...@@ -641,21 +651,27 @@ class Hist(PlotBase):
' ', '') else title ' ', '') else title
return title or '' return title or ''
def _plot_legends(self): def plot_legends(self):
''' Print legend on current page'''
lines = [] lines = []
labels = [] labels = []
for ax in mpl.gcf().get_axes(): for ax in mpl.gcf().get_axes():
li, la = ax.get_legend_handles_labels() ali, ala = ax.get_legend_handles_labels()
lines += li # avoid duplicates in legend
labels += la for li, la in zip(ali, ala):
if la not in labels:
lines.append(li)
labels.append(la)
if self._disp_legend: if self._disp_legend:
mpl.gcf().legend( mpl.gcf().legend(
lines, labels, loc=self._legend_loc, fancybox=True, lines, labels, loc=self._legend_loc, fancybox=True,
framealpha=0.5, ncol=self._nlegends, framealpha=0.5, ncol=self._nlegends,
bbox_to_anchor=(0.55, 1.06), bbox_to_anchor=(0.55, 1.1),
) )
def _get_neg_pos_thres(self, idx, input_scores, input_names): def _get_neg_pos_thres(self, idx, input_scores, input_names):
''' Get scores and threshod for the given system at index idx'''
neg_list, pos_list, _ = utils.get_fta_list(input_scores) neg_list, pos_list, _ = utils.get_fta_list(input_scores)
length = len(neg_list) length = len(neg_list)
# can have several files for one system # can have several files for one system
...@@ -672,6 +688,7 @@ class Hist(PlotBase): ...@@ -672,6 +688,7 @@ class Hist(PlotBase):
return dev_neg, dev_pos, eval_neg, eval_pos, threshold return dev_neg, dev_pos, eval_neg, eval_pos, threshold
def _density_hist(self, scores, n, **kwargs): def _density_hist(self, scores, n, **kwargs):
''' Plots one density histo'''
n, bins, patches = mpl.hist( n, bins, patches = mpl.hist(
scores, density=True, scores, density=True,
bins=self._nbins[n], bins=self._nbins[n],
...@@ -681,6 +698,7 @@ class Hist(PlotBase): ...@@ -681,6 +698,7 @@ class Hist(PlotBase):
def _lines(self, threshold, label=None, neg=None, pos=None, def _lines(self, threshold, label=None, neg=None, pos=None,
idx=None, **kwargs): idx=None, **kwargs):
''' Plots vertical line at threshold '''
label = label or 'Threshold' label = label or 'Threshold'
kwargs.setdefault('color', 'C3') kwargs.setdefault('color', 'C3')
kwargs.setdefault('linestyle', '--') kwargs.setdefault('linestyle', '--')
...@@ -689,7 +707,11 @@ class Hist(PlotBase): ...@@ -689,7 +707,11 @@ class Hist(PlotBase):
mpl.axvline(x=threshold, ymin=0, ymax=1, **kwargs) mpl.axvline(x=threshold, ymin=0, ymax=1, **kwargs)
def _setup_hist(self, neg, pos): def _setup_hist(self, neg, pos):
''' This function can be overwritten in derived classes''' ''' This function can be overwritten in derived classes
Plots all the density histo required in one plot. Here negative and
positive scores densities.
'''
self._density_hist( self._density_hist(
neg[0], n=0, neg[0], n=0,
label='Negatives', alpha=0.5, color='C3' label='Negatives', alpha=0.5, color='C3'
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment