figure.py 29 KB
Newer Older
1 2 3 4
'''Runs error analysis on score sets, outputs metrics and plots'''

from __future__ import division, print_function
from abc import ABCMeta, abstractmethod
5
import math
6
import sys
7
import os.path
8 9 10 11 12
import click
import matplotlib
import matplotlib.pyplot as mpl
from matplotlib.backends.backend_pdf import PdfPages
from tabulate import tabulate
13
from .. import (far_threshold, plot, utils, ppndf)
14

Amir MOHAMMADI's avatar
lint  
Amir MOHAMMADI committed
15

16 17 18 19 20 21 22 23 24 25 26 27
def check_list_value(values, desired_number, name, name2='systems'):
    """Broadcast a single option value, or validate a list of values.

    Parameters
    ----------
    values : list or tuple or None
        Values given for an option. ``None`` is returned unchanged.
    desired_number : int
        Number of values that is required.
    name : str
        Name of the option, used in the error message.
    name2 : str
        Name of the counted items, used in the error message.

    Returns
    -------
    list or tuple or None
        ``values`` itself if it is ``None`` or already holds
        ``desired_number`` items; otherwise its single item repeated
        ``desired_number`` times.

    Raises
    ------
    click.BadParameter
        If ``values`` holds neither 1 nor ``desired_number`` items.
    """
    if values is None or len(values) == desired_number:
        return values
    if len(values) != 1:
        raise click.BadParameter(
            '#{} ({}) must be either 1 value or the same as '
            '#{} ({} values)'.format(name, values, name2, desired_number))
    # a single value is shared by every item
    return values * desired_number


28 29 30 31 32 33 34 35 36 37
class MeasureBase(object):
    """Base class for metrics and plots.

    This abstract class defines the framework used to plot or compute metrics
    from the scores of one or several systems. Each system is described by a
    fixed number (``ctx.meta['min_arg']``, default 1) of consecutive input
    files.

    Attributes
    ----------
    func_load:
        Function that is used to load the input files
    """
    __metaclass__ = ABCMeta  # for python 2.7 compatibility

    def __init__(self, ctx, scores, evaluation, func_load):
        """
        Parameters
        ----------
        ctx : click context
            Click context; options are read from its ``meta`` dictionary
            (``legends``, ``min_arg``).
        scores : :any:`list`
            List of input files (e.g. dev-{1, 2, 3}, {dev,eval}-scores1
            {dev,eval}-scores2)
        evaluation : :py:class:`bool`
            True if eval data are used
        func_load : Function that is used to load the input files

        Raises
        ------
        click.BadParameter
            If the number of files is not a non-zero multiple of ``min_arg``
            or if fewer legends than systems are given.
        """
        self._scores = scores
        self._ctx = ctx
        self.func_load = func_load
        meta = ctx.meta
        self._legends = meta.get('legends')
        self._eval = evaluation
        self._min_arg = meta.get('min_arg', 1)

        n_files = len(scores)
        if n_files < 1 or n_files % self._min_arg != 0:
            raise click.BadParameter(
                'Number of argument must be a non-zero multiple of %d' % self._min_arg
            )
        self.n_systems = int(n_files / self._min_arg)

        legends = self._legends
        if legends is not None and len(legends) < self.n_systems:
            raise click.BadParameter("Number of legends must be >= to the "
                                     "number of systems")

    def run(self):
        """Generate outputs (e.g. metrics, files, pdf plots).

        Calls :func:`~bob.measure.script.figure.MeasureBase.init_process`
        once, then :py:func:`~bob.measure.script.figure.MeasureBase.compute`
        for every system, then
        :py:func:`~bob.measure.script.figure.MeasureBase.end_process`.
        """
        # init matplotlib, log files, ...
        self.init_process()

        group = self._min_arg
        for idx in range(self.n_systems):
            # Scores are given as followed:
            # SysA-dev SysA-eval ... SysA-XX  SysB-dev SysB-eval ... SysB-XX
            # so system ``idx`` owns the ``group`` files starting at
            # ``idx * group``
            start = idx * group
            chunk = self._scores[start:start + group]
            input_scores, input_names = self._load_files(chunk)
            self.compute(idx, input_scores, input_names)

        # setup final configuration, plotting properties, ...
        self.end_process()

    # protected functions that need to be overwritten
    def init_process(self):
        """Hook called by ``run`` before iterating through the systems.

        Does nothing by default; derived classes may override it."""
        pass

    # Main computations are done here in the subclasses
    @abstractmethod
    def compute(self, idx, input_scores, input_names):
        """Compute metrics or plots for one system; called by
        :py:func:`~bob.measure.script.figure.MeasureBase.run`.
        Must be implemented in derived classes.

        Parameters
        ----------
        idx : :obj:`int`
            index of the system
        input_scores: :any:`list`
            list of scores returned by the loading function
        input_names: :any:`list`
            list of base names for the input file of the system
        """
        pass

    # Things to do after the main iterative computations are done
    @abstractmethod
    def end_process(self):
        """Hook called by ``run`` after iterating through the systems.

        Must be implemented in derived classes."""
        pass

    # common protected functions

    def _load_files(self, filepaths):
        '''Load the input files and return the base names of the files

        Returns
        -------
            scores: :any:`list`:
                A list that contains the output of
                ``func_load`` for the given files
            basenames: :any:`list`:
                A list of basenames for the given files, truncated at their
                first dot
        '''
        loaded, stems = [], []
        for path in filepaths:
            stems.append(os.path.basename(path).partition(".")[0])
            loaded.append(self.func_load(path))
        return loaded, stems
152

Amir MOHAMMADI's avatar
lint  
Amir MOHAMMADI committed
153

154 155 156 157 158 159 160 161
class Metrics(MeasureBase):
    ''' Compute metrics from score files

    Attributes
    ----------
    log_file: file object
        output stream: ``sys.stdout``, or the file ``ctx.meta['log']`` opened
        with ``ctx.meta['open_mode']`` when a log path is given
    '''

    def __init__(self, ctx, scores, evaluation, func_load,
                 names=('FtA', 'FMR', 'FNMR', 'FAR', 'FRR', 'HTER')):
        super(Metrics, self).__init__(ctx, scores, evaluation, func_load)
        # row labels of the output table (FtA, FMR, FNMR, FAR, FRR, HTER)
        self.names = names
        self._tablefmt = ctx.meta.get('tablefmt')
        self._criterion = ctx.meta.get('criterion')
        self._open_mode = ctx.meta.get('open_mode')
        self._thres = ctx.meta.get('thres')
        if self._thres is not None:
            if len(self._thres) == 1:
                # one user threshold shared by every system
                self._thres = self._thres * self.n_systems
            elif len(self._thres) != self.n_systems:
                # n_systems is an int (``len`` of it raised a TypeError)
                raise click.BadParameter(
                    '#thresholds must be the same as #systems (%d)'
                    % self.n_systems
                )
        self._far = ctx.meta.get('far_value')
        self._log = ctx.meta.get('log')
        self.log_file = sys.stdout
        if self._log is not None:
            self.log_file = open(self._log, self._open_mode)

    def get_thres(self, criterion, dev_neg, dev_pos, far):
        ''' Return the threshold for ``criterion`` computed on the dev scores
        (delegates to ``utils.get_thres``)'''
        return utils.get_thres(criterion, dev_neg, dev_pos, far)

    @staticmethod
    def _rate_strings(neg, pos, fta, threshold):
        ''' Format the six table entries (FtA, FMR, FNMR, FAR, FRR, HTER) of
        one score set evaluated at ``threshold``'''
        from .. import farfrr
        fmr, fnmr = farfrr(neg, pos, threshold)
        # FMR/FNMR are scaled by the acquired fraction (1 - FtA) and the
        # failures-to-acquire are added to the rejections
        far = fmr * (1 - fta)
        frr = fta + fnmr * (1 - fta)
        hter = (far + frr) / 2.0

        ni = neg.shape[0]  # number of impostors
        fm = int(round(fmr * ni))  # number of false accepts
        nc = pos.shape[0]  # number of clients
        fnm = int(round(fnmr * nc))  # number of false rejects

        return [
            "%.1f%%" % (100 * fta),
            "%.1f%% (%d/%d)" % (100 * fmr, fm, ni),
            "%.1f%% (%d/%d)" % (100 * fnmr, fnm, nc),
            "%.1f%%" % (100 * far),
            "%.1f%%" % (100 * frr),
            "%.1f%%" % (100 * hter),
        ]

    def compute(self, idx, input_scores, input_names):
        ''' Compute metrics thresholds and tables (FAR, FMR, FNMR, HTER) for
        given system inputs

        The threshold is computed on the development set (or taken from the
        user-provided thresholds) and applied unchanged to the evaluation set.
        '''
        neg_list, pos_list, fta_list = utils.get_fta_list(input_scores)
        dev_neg, dev_pos, dev_fta = neg_list[0], pos_list[0], fta_list[0]
        dev_file = input_names[0]
        if self._eval:
            eval_neg, eval_pos, eval_fta = neg_list[1], pos_list[1], fta_list[1]
            eval_file = input_names[1]

        threshold = self.get_thres(self._criterion, dev_neg, dev_pos, self._far) \
            if self._thres is None else self._thres[idx]
        title = self._legends[idx] if self._legends is not None else None
        if self._thres is None:
            far_str = ''
            if self._criterion == 'far' and self._far is not None:
                far_str = str(self._far)
            click.echo("[Min. criterion: %s %s] Threshold on Development set `%s`: %e"
                       % (self._criterion.upper(),
                          far_str, title or dev_file,
                          threshold),
                       file=self.log_file)
        else:
            # prefer the legend over the file name, like the branch above
            click.echo("[Min. criterion: user provided] Threshold on "
                       "Development set `%s`: %e"
                       % (title or dev_file, threshold), file=self.log_file)

        # empty corner header when no legend is given (``'' or title`` always
        # evaluated to ``title``, which may be None)
        headers = [title or '', 'Development %s' % dev_file]
        dev_values = self._rate_strings(dev_neg, dev_pos, dev_fta, threshold)
        raws = [[self.names[i], value] for i, value in enumerate(dev_values)]

        if self._eval:
            # computes statistics for the eval set based on the threshold a priori
            eval_values = self._rate_strings(
                eval_neg, eval_pos, eval_fta, threshold)
            headers.append('Eval. %s' % eval_file)
            for row, value in zip(raws, eval_values):
                row.append(value)

        click.echo(tabulate(raws, headers, self._tablefmt), file=self.log_file)

    def end_process(self):
        ''' Close log file if needed'''
        if self._log is not None:
            self.log_file.close()

Amir MOHAMMADI's avatar
lint  
Amir MOHAMMADI committed
278

279 280 281 282
class PlotBase(MeasureBase):
    ''' Base class for plots. Regroup several options and code
    shared by the different plots
    '''

    def __init__(self, ctx, scores, evaluation, func_load):
        super(PlotBase, self).__init__(ctx, scores, evaluation, func_load)
        self._output = ctx.meta.get('output')
        self._points = ctx.meta.get('points', 100)
        self._split = ctx.meta.get('split')
        self._axlim = ctx.meta.get('axlim')
        self._disp_legend = ctx.meta.get('disp_legend', True)
        self._legend_loc = ctx.meta.get('legend_loc')
        # integer log10 exponent of the smallest x value, if one is known
        self._min_dig = None
        if 'min_far_value' in ctx.meta:
            self._min_dig = int(math.log10(ctx.meta['min_far_value']))
        elif self._axlim is not None and self._axlim[0] is not None:
            self._min_dig = int(math.log10(self._axlim[0])
                                if self._axlim[0] != 0 else 0)
        self._clayout = ctx.meta.get('clayout')
        self._far_at = ctx.meta.get('lines_at')
        # x positions of the vertical lines; subclasses may transform them
        self._trans_far_val = self._far_at
        if self._far_at is not None:
            self._eval_points = {line: [] for line in self._far_at}
            self._lines_val = []
        self._print_fn = ctx.meta.get('show_fn', True)
        self._x_rotation = ctx.meta.get('x_rotation')
        if 'style' in ctx.meta:
            mpl.style.use(ctx.meta['style'])
        self._nb_figs = 2 if self._eval and self._split else 1
        self._colors = utils.get_colors(self.n_systems)
        self._line_linestyles = ctx.meta.get('line_linestyles', False)
        self._linestyles = utils.get_linestyles(
            self.n_systems, self._line_linestyles)
        # one title per figure; ``titles`` may also be present but None
        self._titles = (ctx.meta.get('titles') or []) * 2
        # for compatibility
        self._title = ctx.meta.get('title')
        if not self._titles and self._title is not None:
            self._titles = [self._title] * 2

        self._x_label = ctx.meta.get('x_label')
        self._y_label = ctx.meta.get('y_label')
        self._grid_color = 'silver'
        self._pdf_page = None
        self._end_setup_plot = True

    def init_process(self):
        ''' Select the pdf backend if needed, get the output PdfPages (shared
        one from ``ctx.meta['PdfPages']`` or a new one on ``output``) and
        create/clear the figures '''
        if not hasattr(matplotlib, 'backends'):
            matplotlib.use('pdf')

        if 'PdfPages' in self._ctx.meta:
            self._pdf_page = self._ctx.meta['PdfPages']
        else:
            self._pdf_page = PdfPages(self._output)

        for i in range(self._nb_figs):
            fs = self._ctx.meta.get('figsize')
            fig = mpl.figure(i + 1, figsize=fs)
            fig.set_constrained_layout(self._clayout)
            fig.clear()

    def end_process(self):
        ''' Draw vertical lines, set title, legend, axis labels and grid,
        save figures and close the pdf if needed '''
        # draw vertical lines
        if self._far_at is not None:
            for (line, line_trans) in zip(self._far_at, self._trans_far_val):
                mpl.figure(1)
                mpl.plot(
                    [line_trans, line_trans], [-100.0, 100.], "--",
                    color='black'
                )
                if self._eval and self._split:
                    mpl.figure(2)
                    # join this line's evaluation points in increasing x order
                    points = sorted(self._eval_points[line],
                                    key=lambda point: point[0])
                    x_values = [x for x, _ in points]
                    y_values = [y for _, y in points]
                    mpl.plot(x_values,
                             y_values, '--',
                             color='black')
        # only for plots
        if self._end_setup_plot:
            for i in range(self._nb_figs):
                fig = mpl.figure(i + 1)
                title = '' if not self._titles else self._titles[i]
                mpl.title(title if title.replace(' ', '') else '')
                mpl.xlabel(self._x_label)
                mpl.ylabel(self._y_label)
                mpl.grid(True, color=self._grid_color)
                if self._disp_legend:
                    mpl.legend(loc=self._legend_loc)
                self._set_axis()
                mpl.xticks(rotation=self._x_rotation)
                self._pdf_page.savefig(fig)

        # A PdfPages created in init_process is owned here and must be closed
        # to finalize the file. A shared one (ctx.meta['PdfPages']) is kept
        # open when the caller asks for it (closef=False, e.g. evaluate).
        if 'PdfPages' not in self._ctx.meta or \
           self._ctx.meta.get('closef', True):
            self._pdf_page.close()

    # common protected functions

    def _label(self, base, name, idx):
        ''' Legend label for system ``idx``: the user legend if available,
        otherwise ``base`` followed by the system number and ``name`` '''
        if self._legends is not None and len(self._legends) > idx:
            return self._legends[idx]
        if self.n_systems > 1:
            return base + (" %d (%s)" % (idx + 1, name))
        return base + (" (%s)" % name)

    def _set_axis(self):
        ''' Apply the axis limits, if any '''
        if self._axlim is not None:
            mpl.axis(self._axlim)
394

Amir MOHAMMADI's avatar
lint  
Amir MOHAMMADI committed
395

396
class Roc(PlotBase):
    ''' Handles the plotting of ROC'''

    def __init__(self, ctx, scores, evaluation, func_load):
        super(Roc, self).__init__(ctx, scores, evaluation, func_load)
        self._titles = self._titles or ['ROC dev', 'ROC eval']
        self._x_label = self._x_label or 'False Positive Rate'
        self._y_label = self._y_label or "1 - False Negative Rate"
        self._semilogx = ctx.meta.get('semilogx', True)
        if not self._legend_loc:
            # a log x axis leaves room in the lower right corner
            self._legend_loc = 'lower right' if self._semilogx \
                else 'upper right'
        # custom defaults
        if self._axlim is None:
            self._axlim = [None, None, -0.05, 1.05]

    def compute(self, idx, input_scores, input_names):
        ''' Plot ROC for dev and eval data using
        :py:func:`bob.measure.plot.roc`

        The dev curve goes to figure 1; the eval curve goes to figure 2 when
        the plots are split, otherwise it is dashed on figure 1.'''
        neg_list, pos_list, _ = utils.get_fta_list(input_scores)
        dev_neg, dev_pos = neg_list[0], pos_list[0]
        dev_file = input_names[0]
        if self._eval:
            eval_neg, eval_pos = neg_list[1], pos_list[1]
            eval_file = input_names[1]

        mpl.figure(1)
        # the development curve is drawn the same way with or without eval
        plot.roc_for_far(
            dev_neg, dev_pos,
            far_values=plot.log_values(self._min_dig or -4),
            CAR=self._semilogx,
            color=self._colors[idx], linestyle=self._linestyles[idx],
            label=self._label('dev', dev_file, idx)
        )
        if not self._eval:
            return

        if self._split:
            mpl.figure(2)
        eval_style = self._linestyles[idx] if self._split else '--'
        plot.roc_for_far(
            eval_neg, eval_pos, linestyle=eval_style,
            far_values=plot.log_values(self._min_dig or -4),
            CAR=self._semilogx,
            color=self._colors[idx],
            label=self._label('eval', eval_file, idx)
        )
        if self._far_at is None:
            return

        from .. import farfrr
        for line in self._far_at:
            # threshold chosen on dev at this FAR, applied to the eval scores
            dev_thres = far_threshold(dev_neg, dev_pos, line)
            point_fmr, point_fnmr = farfrr(eval_neg, eval_pos, dev_thres)
            # the y axis shows 1 - FNMR
            point_car = 1 - point_fnmr
            mpl.scatter(point_fmr, point_car, c=self._colors[idx], s=30)
            self._eval_points[line].append((point_fmr, point_car))

Amir MOHAMMADI's avatar
lint  
Amir MOHAMMADI committed
459

460 461
class Det(PlotBase):
    ''' Handles the plotting of DET '''

    def __init__(self, ctx, scores, evaluation, func_load):
        super(Det, self).__init__(ctx, scores, evaluation, func_load)
        self._titles = self._titles or ['DET dev', 'DET eval']
        self._x_label = self._x_label or 'False Positive Rate (%)'
        self._y_label = self._y_label or 'False Negative Rate (%)'
        self._legend_loc = self._legend_loc or 'upper right'
        if self._far_at is not None:
            # vertical lines are drawn in the ppndf-transformed DET space
            self._trans_far_val = [ppndf(float(k)) for k in self._far_at]
        # custom defaults here
        if self._x_rotation is None:
            self._x_rotation = 50

        if self._axlim is None:
            self._axlim = [0.01, 99, 0.01, 99]
        else:
            # copy before the item assignment below: the value comes straight
            # from ctx.meta (shared with other plots) and may be a tuple
            self._axlim = list(self._axlim)

        if self._min_dig is not None:
            self._axlim[0] = math.pow(10, self._min_dig) * 100

    def compute(self, idx, input_scores, input_names):
        ''' Plot DET for dev and eval data using
        :py:func:`bob.measure.plot.det`

        The dev curve goes to figure 1; the eval curve goes to figure 2 when
        the plots are split, otherwise it is dashed on figure 1.'''
        neg_list, pos_list, _ = utils.get_fta_list(input_scores)
        dev_neg, dev_pos = neg_list[0], pos_list[0]
        dev_file = input_names[0]
        if self._eval:
            eval_neg, eval_pos = neg_list[1], pos_list[1]
            eval_file = input_names[1]

        mpl.figure(1)
        # the development curve is drawn the same way with or without eval
        plot.det(
            dev_neg, dev_pos, self._points, color=self._colors[idx],
            linestyle=self._linestyles[idx],
            label=self._label('development', dev_file, idx)
        )
        if not self._eval or eval_neg is None:
            return

        if self._split:
            mpl.figure(2)
        linestyle = '--' if not self._split else self._linestyles[idx]
        plot.det(
            eval_neg, eval_pos, self._points, color=self._colors[idx],
            linestyle=linestyle,
            label=self._label('eval', eval_file, idx)
        )
        if self._far_at is not None:
            from .. import farfrr
            for line in self._far_at:
                # threshold chosen on dev at this FAR, applied to eval scores
                thres_line = far_threshold(dev_neg, dev_pos, line)
                eval_fmr, eval_fnmr = farfrr(
                    eval_neg, eval_pos, thres_line)
                eval_fmr, eval_fnmr = ppndf(eval_fmr), ppndf(eval_fnmr)
                mpl.scatter(eval_fmr, eval_fnmr, c=self._colors[idx], s=30)
                self._eval_points[line].append((eval_fmr, eval_fnmr))

    def _set_axis(self):
        ''' Apply the DET axis limits and ticks '''
        plot.det_axis(self._axlim)
524

Amir MOHAMMADI's avatar
lint  
Amir MOHAMMADI committed
525

526 527
class Epc(PlotBase):
    ''' Handles the plotting of EPC

    Every system must be given as a (dev, eval) pair of score files; all
    curves are drawn on a single figure.
    '''

    def __init__(self, ctx, scores, evaluation, func_load, hter='HTER'):
        super(Epc, self).__init__(ctx, scores, evaluation, func_load)
        if self._min_arg != 2:
            raise click.UsageError("EPC requires dev and eval score files")
        self._titles = self._titles or ['EPC', 'EPC']
        self._x_label = self._x_label or r'$\alpha$'
        self._y_label = self._y_label or hter + ' (%)'
        self._legend_loc = self._legend_loc or 'upper center'
        # EPC always uses eval data, one figure, no vertical lines
        self._eval = True
        self._split = False
        self._nb_figs = 1
        self._far_at = None

    def compute(self, idx, input_scores, input_names):
        ''' Plot EPC using :py:func:`bob.measure.plot.epc` '''
        negatives, positives, _ = utils.get_fta_list(input_scores)
        dev_neg, dev_pos = negatives[0], positives[0]
        dev_file = input_names[0]
        if self._eval:
            eval_neg, eval_pos = negatives[1], positives[1]
            eval_file = input_names[1]

        plot.epc(
            dev_neg, dev_pos, eval_neg, eval_pos, self._points,
            color=self._colors[idx], linestyle=self._linestyles[idx],
            label=self._label('curve', dev_file + "_" + eval_file, idx)
        )

Amir MOHAMMADI's avatar
lint  
Amir MOHAMMADI committed
559

560
class Hist(PlotBase):
    ''' Functional base class for histograms'''

    def __init__(self, ctx, scores, evaluation, func_load, nhist_per_system=2):
        super(Hist, self).__init__(ctx, scores, evaluation, func_load)
        self._nbins = ctx.meta.get('n_bins', ['doane'])
        self._nhist_per_system = nhist_per_system
        self._nbins = check_list_value(
            self._nbins, nhist_per_system, 'n_bins',
            'histograms')
        self._thres = ctx.meta.get('thres')
        self._thres = check_list_value(
            self._thres, self.n_systems, 'thresholds')
        self._criterion = ctx.meta.get('criterion')
        # no vertical (threshold) is displayed
        self._no_line = ctx.meta.get('no_line', False)
        # subplot grid
        self._nrows = ctx.meta.get('n_row', 1)
        self._ncols = ctx.meta.get('n_col', 1)
        # do not display dev histo
        self._hide_dev = ctx.meta.get('hide_dev', False)
        # dev hist are displayed next to eval hist
        self._ncols *= 1 if self._hide_dev else 2
        self._nlegends = ctx.meta.get('legends_ncol', 3)
        self._legend_loc = self._legend_loc or 'upper center'
        # number of subplot on one page
        self._step_print = int(self._nrows * self._ncols)
        self._title_base = 'Scores'
        self._y_label = 'Probability density'
        self._x_label = 'Score values'
        self._end_setup_plot = False
        if self._legends is not None and len(self._legends) == self.n_systems \
           and self._eval and not self._hide_dev:
            # use same legend for dev and eval subplots (2 per system). With
            # dev data only, subplot i is system i, so no duplication.
            self._legends = [x for pair in zip(self._legends, self._legends)
                             for x in pair]

    def _n_subplots(self):
        ''' Total number of subplots: two per system (dev and eval) when both
        are displayed, one otherwise'''
        mult = 2 if self._eval and not self._hide_dev else 1
        return self.n_systems * mult

    def compute(self, idx, input_scores, input_names):
        ''' Draw histograms of negative and positive scores.'''
        dev_neg, dev_pos, eval_neg, eval_pos, threshold = \
            self._get_neg_pos_thres(idx, input_scores, input_names)
        # subplot index: dev at 2*idx and eval at 2*idx + 1 when both shown
        idx *= 1 if self._hide_dev or not self._eval else 2

        if not self._hide_dev or not self._eval:
            self._print_subplot(idx, dev_neg, dev_pos, threshold, False,
                                dflt_title="Dev scores")

        idx += 1 if self._eval and not self._hide_dev else 0
        if self._eval:
            self._print_subplot(idx, eval_neg, eval_pos, threshold,
                                not self._no_line, dflt_title="Eval scores")

    def _print_subplot(self, idx, neg, pos, threshold, draw_line, dflt_title):
        ''' print a subplot for the given score and subplot index'''
        n = idx % self._step_print
        col = n % self._ncols
        sub_plot_idx = n + 1
        axis = mpl.subplot(self._nrows, self._ncols, sub_plot_idx)
        self._setup_hist(neg, pos)
        if col == 0:
            axis.set_ylabel(self._y_label)
        total = self._n_subplots()
        # subplots still to be printed, counted from the first one of this
        # page (``idx`` is a subplot index, so count subplots, not systems)
        rest_print = total - (idx // self._step_print) * self._step_print
        # x label only when no subplot will be drawn below this one
        if n + self._ncols >= min(self._step_print, rest_print):
            axis.set_xlabel(self._x_label)
        axis.set_title(self._get_title(idx, dflt_title))
        label = "%s threshold%s" % (
            '' if self._criterion is None else
            self._criterion.upper(), ' (dev)' if self._eval else ''
        )
        if draw_line:
            self._lines(threshold, label, neg, pos, idx)

        # if it was the last subplot of the page or the last subplot
        # to display, save figure
        if self._step_print == sub_plot_idx or idx == total - 1:
            # print legend on the page
            self.plot_legends()
            mpl.tight_layout()
            self._pdf_page.savefig(mpl.gcf(), bbox_inches='tight')
            mpl.clf()
            mpl.figure()

    def _get_title(self, idx, dflt=None):
        ''' Get the histo title for the given idx: the legend if available,
        else ``dflt``, else the base title; whitespace-only titles become '' '''
        title = self._legends[idx] if self._legends is not None \
            and idx < len(self._legends) else dflt
        title = title or self._title_base
        title = '' if title is not None and not title.replace(
            ' ', '') else title
        return title or ''

    def plot_legends(self):
        ''' Print legend on current page'''
        lines = []
        labels = []
        for ax in mpl.gcf().get_axes():
            ali, ala = ax.get_legend_handles_labels()
            # avoid duplicates in legend
            for li, la in zip(ali, ala):
                if la not in labels:
                    lines.append(li)
                    labels.append(la)

        if self._disp_legend:
            mpl.gcf().legend(
                lines, labels, loc=self._legend_loc, fancybox=True,
                framealpha=0.5, ncol=self._nlegends,
                bbox_to_anchor=(0.55, 1.1),
            )

    def _get_neg_pos_thres(self, idx, input_scores, input_names):
        ''' Get scores and threshold for the given system at index idx

        Input files alternate dev/eval, so dev sets are at even positions and
        eval sets at odd positions. The threshold comes from the user values
        or is computed on the first dev set.'''
        neg_list, pos_list, _ = utils.get_fta_list(input_scores)
        length = len(neg_list)
        # can have several files for one system
        dev_neg = [neg_list[x] for x in range(0, length, 2)]
        dev_pos = [pos_list[x] for x in range(0, length, 2)]
        eval_neg = eval_pos = None
        if self._eval:
            eval_neg = [neg_list[x] for x in range(1, length, 2)]
            eval_pos = [pos_list[x] for x in range(1, length, 2)]

        threshold = utils.get_thres(
            self._criterion, dev_neg[0], dev_pos[0]
        ) if self._thres is None else self._thres[idx]
        return dev_neg, dev_pos, eval_neg, eval_pos, threshold

    def _density_hist(self, scores, n, **kwargs):
        ''' Plots one density histo, using the ``n``-th bins setting'''
        counts, bins, patches = mpl.hist(
            scores, density=True,
            bins=self._nbins[n],
            **kwargs
        )
        return (counts, bins, patches)

    def _lines(self, threshold, label=None, neg=None, pos=None,
               idx=None, **kwargs):
        ''' Plots vertical line at threshold (``neg``, ``pos`` and ``idx``
        are unused here but available to overriding classes)'''
        label = label or 'Threshold'
        kwargs.setdefault('color', 'C3')
        kwargs.setdefault('linestyle', '--')
        kwargs.setdefault('label', label)
        # plot a vertical threshold line
        mpl.axvline(x=threshold, ymin=0, ymax=1, **kwargs)

    def _setup_hist(self, neg, pos):
        ''' This function can be overwritten in derived classes

        Plots all the density histo required in one plot. Here negative and
        positive scores densities.
        '''
        self._density_hist(
            neg[0], n=0,
            label='Negatives', alpha=0.5, color='C3'
        )
        self._density_hist(
            pos[0], n=1,
            label='Positives', alpha=0.5, color='C0'
        )