Skip to content
Snippets Groups Projects
Commit 587f4ae9 authored by Amir MOHAMMADI's avatar Amir MOHAMMADI
Browse files

Merge branch 'metrics' into 'master'

Various fixes

See merge request !167
parents 628f83b0 a0bb138d
No related branches found
No related tags found
1 merge request!167Various fixes
Pipeline #
......@@ -32,6 +32,8 @@ def rank_option(**kwargs):
@common_options.metrics_command(common_options.METRICS_HELP.format(
names='FtA, FAR, FRR, FMR, FMNR, HTER',
criteria=CRITERIA, score_format=SCORE_FORMAT,
hter_note='Note that FAR = FMR * (1 - FtA), FRR = FtA + FMNR * (1 - FtA) '
'and HTER = (FMR + FMNR) / 2',
command='bob bio metrics'), criteria=CRITERIA)
@common_options.cost_option()
def metrics(ctx, scores, evaluation, **kwargs):
......
......@@ -5,15 +5,17 @@ import click
import matplotlib.pyplot as mpl
import bob.measure.script.figure as measure_figure
import bob.measure
from bob.measure import plot
from bob.measure import (plot, utils)
from tabulate import tabulate
import logging
LOGGER = logging.getLogger("bob.bio.base")
class Roc(measure_figure.Roc):
def __init__(self, ctx, scores, evaluation, func_load):
super(Roc, self).__init__(ctx, scores, evaluation, func_load)
self._x_label = ctx.meta.get('x_label') or 'False Match Rate'
default_y_label = '1 - False Non Match Rate' if self._semilogx \
self._x_label = ctx.meta.get('x_label') or 'FMR'
default_y_label = '1 - FNMR' if self._semilogx \
else 'False Non Match Rate'
self._y_label = ctx.meta.get('y_label') or default_y_label
......@@ -21,8 +23,8 @@ class Roc(measure_figure.Roc):
class Det(measure_figure.Det):
def __init__(self, ctx, scores, evaluation, func_load):
super(Det, self).__init__(ctx, scores, evaluation, func_load)
self._x_label = ctx.meta.get('x_label') or 'False Match Rate (%)'
self._y_label = ctx.meta.get('y_label') or 'False Non Match Rate (%)'
self._x_label = ctx.meta.get('x_label') or 'FMR (%)'
self._y_label = ctx.meta.get('y_label') or 'FNMR (%)'
class Cmc(measure_figure.PlotBase):
......@@ -31,7 +33,7 @@ class Cmc(measure_figure.PlotBase):
def __init__(self, ctx, scores, evaluation, func_load):
super(Cmc, self).__init__(ctx, scores, evaluation, func_load)
self._semilogx = ctx.meta.get('semilogx', True)
self._titles = self._titles or ['CMC dev', 'CMC eval']
self._titles = self._titles or ['CMC dev.', 'CMC eval.']
self._x_label = self._x_label or 'Rank'
self._y_label = self._y_label or 'Identification rate'
self._max_R = 0
......@@ -42,10 +44,11 @@ class Cmc(measure_figure.PlotBase):
mpl.figure(1)
if self._eval:
linestyle = '-' if not self._split else self._linestyles[idx]
LOGGER.info("CMC dev. curve using %s", input_names[0])
rank = plot.cmc(
input_scores[0], logx=self._semilogx,
color=self._colors[idx], linestyle=linestyle,
label=self._label('dev', input_names[0], idx)
label=self._label('dev.', idx)
)
self._max_R = max(rank, self._max_R)
linestyle = '--'
......@@ -53,17 +56,19 @@ class Cmc(measure_figure.PlotBase):
mpl.figure(2)
linestyle = self._linestyles[idx]
LOGGER.info("CMC eval. curve using %s", input_names[1])
rank = plot.cmc(
input_scores[1], logx=self._semilogx,
color=self._colors[idx], linestyle=linestyle,
label=self._label('eval', input_names[1], idx)
label=self._label('eval.', idx)
)
self._max_R = max(rank, self._max_R)
else:
LOGGER.info("CMC dev. curve using %s", input_names[0])
rank = plot.cmc(
input_scores[0], logx=self._semilogx,
color=self._colors[idx], linestyle=self._linestyles[idx],
label=self._label('dev', input_names[0], idx)
label=self._label('dev.', idx)
)
self._max_R = max(rank, self._max_R)
......@@ -77,7 +82,7 @@ class Dir(measure_figure.PlotBase):
self._rank = ctx.meta.get('rank', 1)
self._titles = self._titles or ['DIR curve'] * 2
self._x_label = self._x_label or 'False Alarm Rate'
self._y_label = self._y_label or 'Detection and Identification Rate'
self._y_label = self._y_label or 'DIR'
def compute(self, idx, input_scores, input_names):
''' Plot DIR for dev and eval data using
......@@ -85,26 +90,29 @@ class Dir(measure_figure.PlotBase):
mpl.figure(1)
if self._eval:
linestyle = '-' if not self._split else self._linestyles[idx]
LOGGER.info("DIR dev. curve using %s", input_names[0])
plot.detection_identification_curve(
input_scores[0], rank=self._rank, logx=self._semilogx,
color=self._colors[idx], linestyle=linestyle,
label=self._label('dev', input_names[0], idx)
label=self._label('dev', idx)
)
linestyle = '--'
if self._split:
mpl.figure(2)
linestyle = self._linestyles[idx]
LOGGER.info("DIR eval. curve using %s", input_names[1])
plot.detection_identification_curve(
input_scores[1], rank=self._rank, logx=self._semilogx,
color=self._colors[idx], linestyle=linestyle,
label=self._label('eval', input_names[1], idx)
label=self._label('eval', idx)
)
else:
LOGGER.info("DIR dev. curve using %s", input_names[0])
plot.detection_identification_curve(
input_scores[0], rank=self._rank, logx=self._semilogx,
color=self._colors[idx], linestyle=self._linestyles[idx],
label=self._label('dev', input_names[0], idx)
label=self._label('dev', idx)
)
if self._min_dig is not None:
......@@ -114,6 +122,14 @@ class Dir(measure_figure.PlotBase):
class Metrics(measure_figure.Metrics):
''' Compute metrics from score files'''
def __init__(self, ctx, scores, evaluation, func_load,
names=('Failure to Acquire', 'False Match Rate',
'False Non Match Rate', 'False Accept Rate',
'False Reject Rate', 'Half Total Error Rate')):
super(Metrics, self).__init__(
ctx, scores, evaluation, func_load, names
)
def init_process(self):
if self._criterion == 'rr':
self._thres = [None] * self.n_systems if self._thres is None else \
......@@ -122,7 +138,7 @@ class Metrics(measure_figure.Metrics):
def compute(self, idx, input_scores, input_names):
''' Compute metrics for the given criteria'''
title = self._legends[idx] if self._legends is not None else None
headers = ['' or title, 'Development %s' % input_names[0]]
headers = ['' or title, 'Dev. %s' % input_names[0]]
if self._eval and input_scores[1] is not None:
headers.append('eval % s' % input_names[1])
if self._criterion == 'rr':
......@@ -205,12 +221,28 @@ class Metrics(measure_figure.Metrics):
tabulate(raws, headers, self._tablefmt), file=self.log_file
)
else:
self.names = (
'Failure to Acquire', 'False Match Rate',
'False Non Match Rate', 'False Accept Rate',
'False Reject Rate', 'Half Total Error Rate'
)
super(Metrics, self).compute(idx, input_scores, input_names)
title = self._legends[idx] if self._legends is not None else None
all_metrics = self._get_all_metrics(idx, input_scores, input_names)
headers = [' ' or title, 'Development']
rows = [[self.names[0], all_metrics[0][0]],
[self.names[1], all_metrics[0][1]],
[self.names[2], all_metrics[0][2]],
[self.names[3], all_metrics[0][3]],
[self.names[4], all_metrics[0][4]],
[self.names[5], all_metrics[0][5]]]
if self._eval:
# computes statistics for the eval set based on the threshold a
# priori
headers.append('Evaluation')
rows[0].append(all_metrics[1][0])
rows[1].append(all_metrics[1][1])
rows[2].append(all_metrics[1][2])
rows[3].append(all_metrics[1][3])
rows[4].append(all_metrics[1][4])
rows[5].append(all_metrics[1][5])
click.echo(tabulate(rows, headers, self._tablefmt), file=self.log_file)
class MultiMetrics(measure_figure.MultiMetrics):
......
......@@ -9,6 +9,7 @@ from click.types import FLOAT
from bob.extension.scripts.click_helper import verbosity_option
import bob.core
from bob.io.base import create_directories_safe
from bob.measure.script import common_options
logger = logging.getLogger(__name__)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment