Commit 5d9460a9 authored by Pavel KORSHUNOV's avatar Pavel KORSHUNOV
Browse files

Merge branch 'cross' into 'master'

Cross database testing evaluation
Adds a new command ``bob pad cross``

See merge request !53
parents 5a21e60d 7a8a04ba
Pipeline #25870 passed with stages
in 7 minutes and 13 seconds
......@@ -477,6 +477,8 @@ class SVM(Algorithm):
features_array = feature
features_array = features_array.astype('float64')
if not (self.machine_type == 'ONE_CLASS'): # two-class SVM case
probabilities = self.machine.predict_class_and_probabilities(
"""Prints Cross-db metrics analysis
import click
import json
import jinja2
import logging
import math
import os
import yaml
from import split
from bob.extension.scripts.click_helper import (
verbosity_option, bool_option, log_parameters)
from bob.measure import eer_threshold, farfrr
from bob.measure.script import common_options
from bob.measure.utils import get_fta
from gridtk.generator import expand
from tabulate import tabulate
logger = logging.getLogger(__name__)
$ bin/bob pad cross 'results/{{ evaluation.database }}/{{ algorithm }}/{{ evaluation.protocol }}/scores/scores-{{ group }}' \
-td replaymobile -d replaymobile -p grandtest -d oulunpu -p Protocol_1 \
-a replaymobile_frame-diff-svm \
-a replaymobile_qm-svm-64 \
-a replaymobile_lbp-svm-64 \
> replaymobile.rst &
@click.option('-d', '--database', 'databases', multiple=True, required=True,
help='Names of the evaluation databases')
@click.option('-p', '--protocol', 'protocols', multiple=True, required=True,
help='Names of the protocols of the evaluation databases')
@click.option('-a', '--algorithm', 'algorithms', multiple=True, required=True,
help='Names of the algorithms')
@click.option('-n', '--names', type=click.File('r'),
help='Name of algorithms to show in the table. Provide a path '
'to a json file maps algorithm names to names that you want to '
'see in the table.')
@click.option('-td', '--train-database', required=True,
help='The database that was used to train the algorithms.')
@click.option('-g', '--group', 'groups', multiple=True, show_default=True,
default=['train', 'dev', 'eval'])
@bool_option('sort', 's', 'whether the table should be sorted.', True)
def cross(ctx, score_jinja_template, databases, protocols, algorithms,
names, train_database, groups, sort, **kwargs):
"""Cross-db analysis metrics
names = {} if names is None else json.load(names)
env = jinja2.Environment(undefined=jinja2.StrictUndefined)
data = {
'evaluation': [{'database': db, 'protocol': proto}
for db, proto in zip(databases, protocols)],
'algorithm': algorithms,
'group': groups,
metrics = {}
for variables in expand(yaml.dump(data, Dumper=yaml.SafeDumper)):
score_path = env.from_string(score_jinja_template).render(variables)
database, protocol, algorithm, group = \
variables['evaluation']['database'], \
variables['evaluation']['protocol'], \
variables['algorithm'], variables['group']
# if algorithm name does not have train_database name in it.
if train_database not in algorithm and database != train_database:
score_path = score_path.replace(
algorithm, database + '_' + algorithm)
if not os.path.exists(score_path):
metrics[(database, protocol, algorithm, group)] = \
(float('nan'), ) * 5
(neg, pos), fta = get_fta(split(score_path))
if group == 'eval':
threshold = metrics[(database, protocol, algorithm, 'dev')][1]
threshold = eer_threshold(neg, pos)
except RuntimeError:
logger.error("Something wrong with {}".format(score_path))
far, frr = farfrr(neg, pos, threshold)
hter = (far + frr) / 2
metrics[(database, protocol, algorithm, group)] = \
(hter, threshold, fta, far, frr)
logger.debug('metrics: %s', metrics)
headers = ["Algorithms"]
for db in databases:
headers += [db + "\nEER_t", "\nEER_d", "\nAPCER", "\nBPCER", "\nACER"]
rows = []
# sort the algorithms based on HTER test, EER dev, EER train
if sort:
train_protocol = protocols[databases.index(train_database)]
def sort_key(alg):
r = []
for grp in ('eval', 'dev', 'train'):
hter = metrics[(train_database, train_protocol, alg, group)][0]
r.append(1 if math.isnan(hter) else hter)
return tuple(r)
algorithms = sorted(algorithms, key=sort_key)
for algorithm in algorithms:
name = algorithm.replace(train_database + '_', '')
name = name.replace(train_protocol + '_', '')
name = names.get(name, name)
for database, protocol in zip(databases, protocols):
cell = []
for group in groups:
hter, threshold, fta, far, frr = metrics[(
database, protocol, algorithm, group)]
if group == 'eval':
cell += [far, frr, hter]
cell += [hter]
cell = [round(c * 100, 1) for c in cell]
title = ' Trained on {} '.format(train_database)
title_line = '\n' + '=' * len(title) + '\n'
click.echo(title_line + title + title_line, file=ctx.meta['log'])
click.echo(tabulate(rows, headers, ctx.meta['tablefmt'], floatfmt=".1f"),
"""Finalizes the scores that are produced by
import click
import numpy
import logging
from bob.extension.scripts.click_helper import (
verbosity_option, log_parameters)
logger = logging.getLogger(__name__)
@click.command(name='finalize-scores', epilog='''\b
$ bin/bob pad finalize_scores /path/to/scores-dev
$ bin/bob pad finalize_scores /path/to/scores-{dev,eval}
@click.argument('scores', type=click.Path(exists=True, dir_okay=False),
@click.option('-m', '--method', default='mean',
type=click.Choice(['mean', 'min', 'max']), show_default=True,
help='The method to use when finalizing the scores.')
def finalize_scores(scores, method, **kwargs):
"""Finalizes the scores given by
When using bob.pad.base, Algorithms can produce several score values for
each unique sample. You can use this script to average (or min/max) these
scores to have one final score per sample.
The conversion is done in-place. The order of scores will change.
mean = {'mean': numpy.nanmean, 'max': numpy.nanmax, 'min': numpy.nanmin}[method]
for path in scores:
new_lines = []
with open(path) as f:
old_lines = f.readlines()
for i, line in enumerate(old_lines):
uniq, s = line.strip().rsplit(maxsplit=1)
s = float(s)
if i == 0:
last_line = uniq
last_scores = []
if uniq == last_line:
new_lines.append('{} {}\n'.format(
last_line, mean(last_scores)))
last_scores = [s]
last_line = uniq
else: # this else is for the for loop
new_lines.append('{} {}\n'.format(last_line, mean(last_scores)))
with open(path, 'w') as f:
......@@ -49,9 +49,9 @@ def convert_and_prepare_features(features):
if isinstance(
features[0], # if FrameContainer convert to 2D numpy array
return convert_list_of_frame_cont_to_array(features)
return convert_list_of_frame_cont_to_array(features).astype('float64')
return np.vstack(features)
return np.vstack(features).astype('float64')
def convert_list_of_frame_cont_to_array(frame_containers):
......@@ -256,9 +256,9 @@ def mean_std_normalize(features,
features_mean = np.mean(features, axis=0)
features_std = np.std(features, axis=0)
row_norm_list = []
for row in features: # row is a sample
......@@ -147,6 +147,8 @@ setup(
'epc = bob.pad.base.script.pad_commands:epc',
'gen = bob.pad.base.script.pad_commands:gen',
'evaluate = bob.pad.base.script.pad_commands:evaluate',
'cross = bob.pad.base.script.cross:cross',
'finalize-scores = bob.pad.base.script.finalize_scores:finalize_scores',
# bob vuln scripts
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment