Commit 587f4ae9 authored by Amir MOHAMMADI's avatar Amir MOHAMMADI

Merge branch 'metrics' into 'master'

Various fixes

See merge request !167
parents 628f83b0 a0bb138d
Pipeline #21544 passed with stages
in 27 minutes and 28 seconds
......@@ -32,6 +32,8 @@ def rank_option(**kwargs):
@common_options.metrics_command(common_options.METRICS_HELP.format(
names='FtA, FAR, FRR, FMR, FMNR, HTER',
criteria=CRITERIA, score_format=SCORE_FORMAT,
hter_note='Note that FAR = FMR * (1 - FtA), FRR = FtA + FMNR * (1 - FtA) '
'and HTER = (FMR + FMNR) / 2',
command='bob bio metrics'), criteria=CRITERIA)
@common_options.cost_option()
def metrics(ctx, scores, evaluation, **kwargs):
......
......@@ -5,15 +5,17 @@ import click
import matplotlib.pyplot as mpl
import bob.measure.script.figure as measure_figure
import bob.measure
from bob.measure import plot
from bob.measure import (plot, utils)
from tabulate import tabulate
import logging
LOGGER = logging.getLogger("bob.bio.base")
class Roc(measure_figure.Roc):
def __init__(self, ctx, scores, evaluation, func_load):
super(Roc, self).__init__(ctx, scores, evaluation, func_load)
self._x_label = ctx.meta.get('x_label') or 'False Match Rate'
default_y_label = '1 - False Non Match Rate' if self._semilogx \
self._x_label = ctx.meta.get('x_label') or 'FMR'
default_y_label = '1 - FNMR' if self._semilogx \
else 'False Non Match Rate'
self._y_label = ctx.meta.get('y_label') or default_y_label
......@@ -21,8 +23,8 @@ class Roc(measure_figure.Roc):
class Det(measure_figure.Det):
def __init__(self, ctx, scores, evaluation, func_load):
super(Det, self).__init__(ctx, scores, evaluation, func_load)
self._x_label = ctx.meta.get('x_label') or 'False Match Rate (%)'
self._y_label = ctx.meta.get('y_label') or 'False Non Match Rate (%)'
self._x_label = ctx.meta.get('x_label') or 'FMR (%)'
self._y_label = ctx.meta.get('y_label') or 'FNMR (%)'
class Cmc(measure_figure.PlotBase):
......@@ -31,7 +33,7 @@ class Cmc(measure_figure.PlotBase):
def __init__(self, ctx, scores, evaluation, func_load):
super(Cmc, self).__init__(ctx, scores, evaluation, func_load)
self._semilogx = ctx.meta.get('semilogx', True)
self._titles = self._titles or ['CMC dev', 'CMC eval']
self._titles = self._titles or ['CMC dev.', 'CMC eval.']
self._x_label = self._x_label or 'Rank'
self._y_label = self._y_label or 'Identification rate'
self._max_R = 0
......@@ -42,10 +44,11 @@ class Cmc(measure_figure.PlotBase):
mpl.figure(1)
if self._eval:
linestyle = '-' if not self._split else self._linestyles[idx]
LOGGER.info("CMC dev. curve using %s", input_names[0])
rank = plot.cmc(
input_scores[0], logx=self._semilogx,
color=self._colors[idx], linestyle=linestyle,
label=self._label('dev', input_names[0], idx)
label=self._label('dev.', idx)
)
self._max_R = max(rank, self._max_R)
linestyle = '--'
......@@ -53,17 +56,19 @@ class Cmc(measure_figure.PlotBase):
mpl.figure(2)
linestyle = self._linestyles[idx]
LOGGER.info("CMC eval. curve using %s", input_names[1])
rank = plot.cmc(
input_scores[1], logx=self._semilogx,
color=self._colors[idx], linestyle=linestyle,
label=self._label('eval', input_names[1], idx)
label=self._label('eval.', idx)
)
self._max_R = max(rank, self._max_R)
else:
LOGGER.info("CMC dev. curve using %s", input_names[0])
rank = plot.cmc(
input_scores[0], logx=self._semilogx,
color=self._colors[idx], linestyle=self._linestyles[idx],
label=self._label('dev', input_names[0], idx)
label=self._label('dev.', idx)
)
self._max_R = max(rank, self._max_R)
......@@ -77,7 +82,7 @@ class Dir(measure_figure.PlotBase):
self._rank = ctx.meta.get('rank', 1)
self._titles = self._titles or ['DIR curve'] * 2
self._x_label = self._x_label or 'False Alarm Rate'
self._y_label = self._y_label or 'Detection and Identification Rate'
self._y_label = self._y_label or 'DIR'
def compute(self, idx, input_scores, input_names):
''' Plot DIR for dev and eval data using
......@@ -85,26 +90,29 @@ class Dir(measure_figure.PlotBase):
mpl.figure(1)
if self._eval:
linestyle = '-' if not self._split else self._linestyles[idx]
LOGGER.info("DIR dev. curve using %s", input_names[0])
plot.detection_identification_curve(
input_scores[0], rank=self._rank, logx=self._semilogx,
color=self._colors[idx], linestyle=linestyle,
label=self._label('dev', input_names[0], idx)
label=self._label('dev', idx)
)
linestyle = '--'
if self._split:
mpl.figure(2)
linestyle = self._linestyles[idx]
LOGGER.info("DIR eval. curve using %s", input_names[1])
plot.detection_identification_curve(
input_scores[1], rank=self._rank, logx=self._semilogx,
color=self._colors[idx], linestyle=linestyle,
label=self._label('eval', input_names[1], idx)
label=self._label('eval', idx)
)
else:
LOGGER.info("DIR dev. curve using %s", input_names[0])
plot.detection_identification_curve(
input_scores[0], rank=self._rank, logx=self._semilogx,
color=self._colors[idx], linestyle=self._linestyles[idx],
label=self._label('dev', input_names[0], idx)
label=self._label('dev', idx)
)
if self._min_dig is not None:
......@@ -114,6 +122,14 @@ class Dir(measure_figure.PlotBase):
class Metrics(measure_figure.Metrics):
''' Compute metrics from score files'''
def __init__(self, ctx, scores, evaluation, func_load,
names=('Failure to Acquire', 'False Match Rate',
'False Non Match Rate', 'False Accept Rate',
'False Reject Rate', 'Half Total Error Rate')):
super(Metrics, self).__init__(
ctx, scores, evaluation, func_load, names
)
def init_process(self):
if self._criterion == 'rr':
self._thres = [None] * self.n_systems if self._thres is None else \
......@@ -122,7 +138,7 @@ class Metrics(measure_figure.Metrics):
def compute(self, idx, input_scores, input_names):
''' Compute metrics for the given criteria'''
title = self._legends[idx] if self._legends is not None else None
headers = ['' or title, 'Development %s' % input_names[0]]
headers = ['' or title, 'Dev. %s' % input_names[0]]
if self._eval and input_scores[1] is not None:
headers.append('eval % s' % input_names[1])
if self._criterion == 'rr':
......@@ -205,12 +221,28 @@ class Metrics(measure_figure.Metrics):
tabulate(raws, headers, self._tablefmt), file=self.log_file
)
else:
self.names = (
'Failure to Acquire', 'False Match Rate',
'False Non Match Rate', 'False Accept Rate',
'False Reject Rate', 'Half Total Error Rate'
)
super(Metrics, self).compute(idx, input_scores, input_names)
title = self._legends[idx] if self._legends is not None else None
all_metrics = self._get_all_metrics(idx, input_scores, input_names)
headers = [' ' or title, 'Development']
rows = [[self.names[0], all_metrics[0][0]],
[self.names[1], all_metrics[0][1]],
[self.names[2], all_metrics[0][2]],
[self.names[3], all_metrics[0][3]],
[self.names[4], all_metrics[0][4]],
[self.names[5], all_metrics[0][5]]]
if self._eval:
# computes statistics for the eval set based on the threshold a
# priori
headers.append('Evaluation')
rows[0].append(all_metrics[1][0])
rows[1].append(all_metrics[1][1])
rows[2].append(all_metrics[1][2])
rows[3].append(all_metrics[1][3])
rows[4].append(all_metrics[1][4])
rows[5].append(all_metrics[1][5])
click.echo(tabulate(rows, headers, self._tablefmt), file=self.log_file)
class MultiMetrics(measure_figure.MultiMetrics):
......
......@@ -9,6 +9,7 @@ from click.types import FLOAT
from bob.extension.scripts.click_helper import verbosity_option
import bob.core
from bob.io.base import create_directories_safe
from bob.measure.script import common_options
logger = logging.getLogger(__name__)
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment