Commit 18ab16a4 authored by Amir MOHAMMADI's avatar Amir MOHAMMADI

Enable semilogx option in roc curves

parent b454ba50
Pipeline #20254 passed with stage
in 21 minutes and 19 seconds
......@@ -36,6 +36,15 @@ def log_values(min_step=-4, counts_per_step=4):
return [math.pow(10., i * 1. / counts_per_step) for i in range(min_step * counts_per_step, 0)] + [1.]
def _semilogx(x, y, **kwargs):
# remove points were x is 0
zero_index = x == 0
x = x[~zero_index]
y = y[~zero_index]
from matplotlib import pyplot
return pyplot.semilogx(x, y, **kwargs)
def roc(negatives, positives, npoints=100, CAR=False, **kwargs):
"""Plots Receiver Operating Characteristic (ROC) curve.
......@@ -91,10 +100,11 @@ def roc(negatives, positives, npoints=100, CAR=False, **kwargs):
if not CAR:
return pyplot.plot(out[0, :], out[1, :], **kwargs)
else:
return pyplot.semilogx(out[0, :],(1 - out[1, :]), **kwargs)
return _semilogx(out[0, :], (1 - out[1, :]), **kwargs)
def roc_for_far(negatives, positives, far_values=log_values(), **kwargs):
def roc_for_far(negatives, positives, far_values=log_values(), CAR=True,
**kwargs):
"""Plots the ROC curve for the given list of False Acceptance Rates (FAR).
This method will call ``matplotlib`` to plot the ROC curve for a system which
......@@ -127,6 +137,10 @@ def roc_for_far(negatives, positives, far_values=log_values(), **kwargs):
far_values (:py:class:`list`, optional): The values for the FAR, where the
CAR should be plotted; each value should be in range [0,1].
CAR (:py:class:`bool`, optional): If set to ``True``, it will plot the CAR
over FAR in using :py:func:`matplotlib.pyplot.semilogx`, otherwise the
FAR over FRR linearly using :py:func:`matplotlib.pyplot.plot`.
kwargs (:py:class:`dict`, optional): Extra plotting parameters, which are
passed directly to :py:func:`matplotlib.pyplot.plot`.
......@@ -142,7 +156,10 @@ def roc_for_far(negatives, positives, far_values=log_values(), **kwargs):
from matplotlib import pyplot
from . import roc_for_far as calc
out = calc(negatives, positives, far_values)
return pyplot.semilogx(out[0, :], (1 - out[1, :]), **kwargs)
if not CAR:
return pyplot.plot(out[0, :], out[1, :], **kwargs)
else:
return _semilogx(out[0, :], (1 - out[1, :]), **kwargs)
def precision_recall_curve(negatives, positives, npoints=100, **kwargs):
......@@ -453,7 +470,7 @@ def det_axis(v, **kwargs):
def cmc(cmc_scores, logx=True, **kwargs):
"""Plots the (cumulative) match characteristics and returns the maximum rank.
This function plots a CMC curve using the given CMC scores (:py:class:`list`:
This function plots a CMC curve using the given CMC scores (:py:class:`list`:
A list of tuples, where each tuple contains the
``negative`` and ``positive`` scores for one probe of the database).
......@@ -483,7 +500,7 @@ def cmc(cmc_scores, logx=True, **kwargs):
out = calc(cmc_scores)
if logx:
pyplot.semilogx(range(1, len(out) + 1), out, **kwargs)
_semilogx(range(1, len(out) + 1), out, **kwargs)
else:
pyplot.plot(range(1, len(out) + 1), out, **kwargs)
......@@ -557,6 +574,6 @@ def detection_identification_curve(cmc_scores, far_values=log_values(), rank=1,
# plot curve
if logx:
return pyplot.semilogx(far_values, rates, **kwargs)
return _semilogx(far_values, rates, **kwargs)
else:
return pyplot.plot(far_values, rates, **kwargs)
......@@ -47,12 +47,12 @@ def metrics(ctx, scores, evaluation, **kwargs):
@common_options.title_option()
@common_options.legends_option()
@common_options.no_legend_option()
@common_options.legend_loc_option(dflt='lower-right')
@common_options.legend_loc_option(dflt=None)
@common_options.sep_dev_eval_option()
@common_options.output_plot_file_option(default_out='roc.pdf')
@common_options.eval_option()
@common_options.points_curve_option()
@common_options.axes_val_option(dflt='1e-4,1,1e-4,1')
@common_options.axes_val_option()
@common_options.min_far_option()
@common_options.x_rotation_option()
@common_options.x_label_option()
......
......@@ -392,7 +392,7 @@ def legend_loc_option(dflt='best', **kwargs):
'''Get the legend location of the plot'''
def custom_legend_loc_option(func):
def callback(ctx, param, value):
ctx.meta['legend_loc'] = value.replace('-', ' ')
ctx.meta['legend_loc'] = value.replace('-', ' ') if value else value
return value
return click.option(
'-lc', '--legend-loc', default=dflt, show_default=True,
......
......@@ -40,12 +40,12 @@ class MeasureBase(object):
func_load : Function that is used to load the input files
"""
self._scores = scores
self._min_arg = 1 if 'min_arg' not in ctx.meta else ctx.meta['min_arg']
self._min_arg = ctx.meta.get('min_arg', 1)
self._ctx = ctx
self.func_load = func_load
self._legends = None if 'legends' not in ctx.meta else ctx.meta['legends']
self._legends = ctx.meta.get('legends')
self._eval = evaluation
self._min_arg = 1 if 'min_arg' not in ctx.meta else ctx.meta['min_arg']
self._min_arg = ctx.meta.get('min_arg', 1)
if len(scores) < 1 or len(scores) % self._min_arg != 0:
raise click.BadParameter(
'Number of argument must be a non-zero multiple of %d' % self._min_arg
......@@ -151,13 +151,10 @@ class Metrics(MeasureBase):
def __init__(self, ctx, scores, evaluation, func_load):
super(Metrics, self).__init__(ctx, scores, evaluation, func_load)
self._tablefmt = None if 'tablefmt' not in ctx.meta else\
ctx.meta['tablefmt']
self._criterion = None if 'criterion' not in ctx.meta else \
ctx.meta['criterion']
self._open_mode = None if 'open_mode' not in ctx.meta else\
ctx.meta['open_mode']
self._thres = None if 'thres' not in ctx.meta else ctx.meta['thres']
self._tablefmt = ctx.meta.get('tablefmt')
self._criterion = ctx.meta.get('criterion')
self._open_mode = ctx.meta.get('open_mode')
self._thres = ctx.meta.get('thres')
if self._thres is not None:
if len(self._thres) == 1:
self._thres = self._thres * self.n_systems
......@@ -166,9 +163,8 @@ class Metrics(MeasureBase):
'#thresholds must be the same as #systems (%d)'
% len(self.n_systems)
)
self._far = None if 'far_value' not in ctx.meta else \
ctx.meta['far_value']
self._log = None if 'log' not in ctx.meta else ctx.meta['log']
self._far = ctx.meta.get('far_value')
self._log = ctx.meta.get('log')
self.log_file = sys.stdout
if self._log is not None:
self.log_file = open(self._log, self._open_mode)
......@@ -271,46 +267,37 @@ class PlotBase(MeasureBase):
def __init__(self, ctx, scores, evaluation, func_load):
super(PlotBase, self).__init__(ctx, scores, evaluation, func_load)
self._output = None if 'output' not in ctx.meta else ctx.meta['output']
self._points = 100 if 'points' not in ctx.meta else ctx.meta['points']
self._split = None if 'split' not in ctx.meta else ctx.meta['split']
self._axlim = None if 'axlim' not in ctx.meta else ctx.meta['axlim']
self._disp_legend = True if 'disp_legend' not in ctx.meta else\
ctx.meta['disp_legend']
self._legend_loc = None if 'legend_loc' not in ctx.meta else\
ctx.meta['legend_loc']
self._output = ctx.meta.get('output')
self._points = ctx.meta.get('points', 100)
self._split = ctx.meta.get('split')
self._axlim = ctx.meta.get('axlim')
self._disp_legend = ctx.meta.get('disp_legend', True)
self._legend_loc = ctx.meta.get('legend_loc')
self._min_dig = None
if 'min_far_value' in ctx.meta:
self._min_dig = int(math.log10(ctx.meta['min_far_value']))
elif self._axlim is not None:
elif self._axlim is not None and self._axlim[0] is not None:
self._min_dig = int(math.log10(self._axlim[0])
if self._axlim[0] != 0 else 0)
self._clayout = None if 'clayout' not in ctx.meta else\
ctx.meta['clayout']
self._far_at = None if 'lines_at' not in ctx.meta else\
ctx.meta['lines_at']
self._clayout = ctx.meta.get('clayout')
self._far_at = ctx.meta.get('lines_at')
self._trans_far_val = self._far_at
if self._far_at is not None:
self._eval_points = {line: [] for line in self._far_at}
self._lines_val = []
self._print_fn = True if 'show_fn' not in ctx.meta else\
ctx.meta['show_fn']
self._x_rotation = None if 'x_rotation' not in ctx.meta else\
ctx.meta['x_rotation']
self._print_fn = ctx.meta.get('show_fn', True)
self._x_rotation = ctx.meta.get('x_rotation')
if 'style' in ctx.meta:
mpl.style.use(ctx.meta['style'])
self._nb_figs = 2 if self._eval and self._split else 1
self._colors = utils.get_colors(self.n_systems)
self._line_linestyles = False if 'line_linestyles' not in ctx.meta else \
ctx.meta['line_linestyles']
self._line_linestyles = ctx.meta.get('line_linestyles', False)
self._linestyles = utils.get_linestyles(
self.n_systems, self._line_linestyles)
self._states = ['Development', 'Evaluation']
self._title = None if 'title' not in ctx.meta else ctx.meta['title']
self._x_label = None if 'x_label' not in ctx.meta else\
ctx.meta['x_label']
self._y_label = None if 'y_label' not in ctx.meta else\
ctx.meta['y_label']
self._title = ctx.meta.get('title')
self._x_label = ctx.meta.get('x_label')
self._y_label = ctx.meta.get('y_label')
self._grid_color = 'silver'
self._pdf_page = None
self._end_setup_plot = True
......@@ -324,8 +311,7 @@ class PlotBase(MeasureBase):
self._ctx.meta else PdfPages(self._output)
for i in range(self._nb_figs):
fs = None if 'figsize' not in self._ctx.meta else\
self._ctx.meta['figsize']
fs = self._ctx.meta.get('figsize')
fig = mpl.figure(i + 1, figsize=fs)
fig.set_constrained_layout(self._clayout)
fig.clear()
......@@ -387,7 +373,7 @@ class PlotBase(MeasureBase):
return base + (" (%s)" % name)
def _set_axis(self):
if self._axlim is not None and None not in self._axlim:
if self._axlim is not None:
mpl.axis(self._axlim)
......@@ -399,13 +385,12 @@ class Roc(PlotBase):
self._title = self._title or 'ROC'
self._x_label = self._x_label or 'False Positive Rate'
self._y_label = self._y_label or "1 - False Negative Rate"
self._legend_loc = self._legend_loc or 'lower right'
self._semilogx = ctx.meta.get('semilogx', True)
best_legend = 'lower right' if self._semilogx else 'upper right'
self._legend_loc = self._legend_loc or best_legend
# custom defaults
if self._axlim is None:
self._axlim = [1e-4, 1.0, 0, 1.0]
if self._min_dig is not None:
self._axlim[0] = math.pow(10, self._min_dig)
self._axlim = [None, None, -0.05, 1.05]
def compute(self, idx, input_scores, input_names):
''' Plot ROC for dev and eval data using
......@@ -422,6 +407,7 @@ class Roc(PlotBase):
plot.roc_for_far(
dev_neg, dev_pos,
far_values=plot.log_values(self._min_dig or -4),
CAR=self._semilogx,
color=self._colors[idx], linestyle=self._linestyles[idx],
label=self._label('development', dev_file, idx)
)
......@@ -432,6 +418,7 @@ class Roc(PlotBase):
plot.roc_for_far(
eval_neg, eval_pos, linestyle=linestyle,
far_values=plot.log_values(self._min_dig or -4),
CAR=self._semilogx,
color=self._colors[idx],
label=self._label('eval', eval_file, idx)
)
......@@ -448,6 +435,7 @@ class Roc(PlotBase):
plot.roc_for_far(
dev_neg, dev_pos,
far_values=plot.log_values(self._min_dig or -4),
CAR=self._semilogx,
color=self._colors[idx], linestyle=self._linestyles[idx],
label=self._label('development', dev_file, idx)
)
......@@ -558,8 +546,8 @@ class Hist(PlotBase):
def __init__(self, ctx, scores, evaluation, func_load):
super(Hist, self).__init__(ctx, scores, evaluation, func_load)
self._nbins = [] if 'n_bins' not in ctx.meta else ctx.meta['n_bins']
self._thres = None if 'thres' not in ctx.meta else ctx.meta['thres']
self._nbins = ctx.meta.get('n_bins', [])
self._thres = ctx.meta.get('thres')
if self._thres is not None and len(self._thres) != self.n_systems:
if len(self._thres) == 1:
self._thres = self._thres * self.n_systems
......@@ -568,12 +556,10 @@ class Hist(PlotBase):
'#thresholds must be the same as #systems (%d)'
% self.n_systems
)
self._criterion = None if 'criterion' not in ctx.meta else \
ctx.meta['criterion']
self._nrows = 1 if 'n_row' not in ctx.meta else ctx.meta['n_row']
self._ncols = 1 if 'n_col' not in ctx.meta else ctx.meta['n_col']
self._nlegends = 10 if 'legends_ncol' not in ctx.meta else \
ctx.meta['legends_ncol']
self._criterion = ctx.meta.get('criterion')
self._nrows = ctx.meta.get('n_row', 1)
self._ncols = ctx.meta.get('n_col', 1)
self._nlegends = ctx.meta.get('legends_ncol', 10)
self._legend_loc = self._legend_loc or 'upper center'
self._step_print = int(self._nrows * self._ncols)
self._title_base = 'Scores'
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment