Commit 3f342953 authored by Amir MOHAMMADI's avatar Amir MOHAMMADI

Merge branch 'score-writer' into 'master'

Switch to CSV scores for pipelines and analysis commands

See merge request !87
parents f3417121 0937127a
Pipeline #51536 passed with stages
in 9 minutes and 17 seconds
...@@ -3,10 +3,20 @@ ...@@ -3,10 +3,20 @@
# Fri Dec 7 12:33:37 CET 2012 # Fri Dec 7 12:33:37 CET 2012
"""Utility functions for computation of EPSC curve and related measurement""" """Utility functions for computation of EPSC curve and related measurement"""
from bob.measure import far_threshold, eer_threshold, min_hter_threshold, farfrr, frr_threshold from bob.measure import (
from bob.bio.base.score.load import four_column far_threshold,
eer_threshold,
min_hter_threshold,
farfrr,
frr_threshold,
)
from bob.bio.base.score.load import _iterate_csv_score_file
from collections import defaultdict from collections import defaultdict
import re import re
import numpy
import logging
logger = logging.getLogger(__name__)
def calc_threshold(method, pos, negs, all_negs, far_value=None, is_sorted=False): def calc_threshold(method, pos, negs, all_negs, far_value=None, is_sorted=False):
...@@ -116,83 +126,117 @@ def apcer_bpcer(threshold, pos, *negs): ...@@ -116,83 +126,117 @@ def apcer_bpcer(threshold, pos, *negs):
return apcers, max(apcers), bpcer return apcers, max(apcers), bpcer
def split_csv_pad_per_pai(filename, regexps=None, regexp_column="attack_type"):
    """Returns scores for Bona-Fide samples and scores for each PAI.

    By default, the ``attack_type`` column is used as indication for each
    Presentation Attack Instrument (PAI).

    For example, with default regexps and regexp_column, if you have scores like:

        claimed_id, test_label, is_bonafide, attack_type, score
        001, bona_fide_sample_1_path, True, , 0.9
        001, print_sample_1_path, False, print, 0.6
        001, print_sample_2_path, False, print, 0.6
        001, replay_sample_1_path, False, replay, 0.2
        001, replay_sample_2_path, False, replay, 0.2
        001, mask_sample_1_path, False, mask, 0.5
        001, mask_sample_2_path, False, mask, 0.5

    this function will return 1 set of positive scores, and 3 sets of negative scores
    (for each print, replay, and mask PAIs).

    Otherwise, you can provide a list of regular expressions that match each PAI.
    For example, with regexps as ['print', 'replay', 'mask'], if you have scores like:

        claimed_id, test_label, is_bonafide, attack_type, score
        001, bona_fide_sample_1_path, True, , 0.9
        001, print_sample_1_path, False, print/1, 0.6
        001, print_sample_2_path, False, print/2, 0.6
        001, replay_sample_1_path, False, replay/1, 0.2
        001, replay_sample_2_path, False, replay/2, 0.2
        001, mask_sample_1_path, False, mask/1, 0.5
        001, mask_sample_2_path, False, mask/2, 0.5

    the function will return 3 sets of negative scores (for print, replay, and mask
    PAIs, given in regexp).

    Parameters
    ----------
    filename : str
        Path to the score file.
    regexps : List of str, optional
        A list of regular expressions that match each PAI. If not given, the values in
        the column pointed by regexp_column are used to find scores for different PAIs.
        Defaults to ``None`` (a mutable ``[]`` default is avoided on purpose).
    regexp_column : str, optional
        If a list of regular expressions are given, those patterns will be matched
        against the values in this column. default: ``attack_type``

    Returns
    -------
    tuple ([positives], {'pai_name': [negatives]})
        A tuple containing positive scores and a dict of negative scores mapping PAIs
        names to their respective scores.

    Raises
    ------
    ValueError
        If none of the given regular expressions match the values in regexp_column.
    KeyError
        If regexp_column is not a column of the CSV file.
    """
    pos = []
    negs = defaultdict(list)
    logger.debug(f"Loading CSV score file: '{filename}'")
    if regexps:
        regexps = [re.compile(pattern) for pattern in regexps]
    # NOTE(review): assumes _iterate_csv_score_file yields dict rows whose
    # "is_bonafide" field is the string "True"/"False" -- confirm in bob.bio.base.
    for row in _iterate_csv_score_file(filename):
        # if it is a Bona-Fide score
        if row["is_bonafide"].lower() == "true":
            pos.append(row["score"])
            continue
        if not regexps:
            # No patterns given: group negatives by the raw column value.
            negs[row[regexp_column]].append(row["score"])
            continue
        # regexps given and this is not a Bona-Fide score: first matching
        # pattern claims the score; its pattern string names the PAI group.
        for pattern in regexps:
            if pattern.search(row[regexp_column]):
                negs[pattern.pattern].append(row["score"])
                break
        else:  # this else is for the for loop: ``for pattern in regexps:``
            raise ValueError(
                f"No regexps: {regexps} match `{row[regexp_column]}' "
                f"from `{regexp_column}' column."
            )
    logger.debug(f"Found {len(negs)} different PAIs names: {list(negs.keys())}")
    return pos, negs
def split_csv_pad(filename):
    """Loads PAD scores from a CSV score file, splits them by attack vs bonafide.

    The CSV must contain a ``is_bonafide`` column with each field either
    ``True`` or ``False`` (case insensitive).

    Parameters
    ----------
    filename: str
        The path to a CSV file containing all the scores.

    Returns
    -------
    (attack, bonafide): Tuple of 1D-arrays
        The negative (attacks) and positives (bonafide) scores.
    """
    logger.debug(f"Loading CSV score file: '{filename}'")
    split_scores = defaultdict(list)
    for row in _iterate_csv_score_file(filename):
        if row["is_bonafide"].lower() == "true":
            split_scores["bonafide"].append(row["score"])
        else:
            split_scores["attack"].append(row["score"])
    logger.debug(
        # Trailing space matters: without it the count fuses with "and".
        f"Found {len(split_scores['attack'])} negative (attack), and "
        f"{len(split_scores['bonafide'])} positive (bonafide) scores."
    )
    # Cast the scores to numpy float64 arrays for downstream bob.measure code.
    for key, scores in split_scores.items():
        split_scores[key] = numpy.array(scores, dtype=numpy.float64)
    return split_scores["attack"], split_scores["bonafide"]
...@@ -5,16 +5,14 @@ from bob.measure.script import common_options ...@@ -5,16 +5,14 @@ from bob.measure.script import common_options
from bob.extension.scripts.click_helper import verbosity_option from bob.extension.scripts.click_helper import verbosity_option
import bob.bio.base.script.gen as bio_gen import bob.bio.base.script.gen as bio_gen
import bob.measure.script.figure as measure_figure import bob.measure.script.figure as measure_figure
from bob.bio.base.score import load
from . import pad_figure as figure from . import pad_figure as figure
from .error_utils import negatives_per_pai_and_positives from .error_utils import split_csv_pad, split_csv_pad_per_pai
from functools import partial from functools import partial
from csv import DictWriter
import numpy
import os
SCORE_FORMAT = ( SCORE_FORMAT = "Files must be in CSV format."
"Files must be 4-col or 5-col format, see "
":py:func:`bob.bio.base_legacy.score.load.four_column` and"
":py:func:`bob.bio.base_legacy.score.load.five_column`."
)
CRITERIA = ( CRITERIA = (
"eer", "eer",
"min-hter", "min-hter",
...@@ -53,7 +51,7 @@ def metrics_option( ...@@ -53,7 +51,7 @@ def metrics_option(
help="List of metrics to print. Provide a string with comma separated metric " help="List of metrics to print. Provide a string with comma separated metric "
"names. For possible values see the default value.", "names. For possible values see the default value.",
default="apcer_pais,apcer_ap,bpcer,acer,fta,fpr,fnr,hter,far,frr,precision,recall,f1_score,auc,auc-log-scale", default="apcer_pais,apcer_ap,bpcer,acer,fta,fpr,fnr,hter,far,frr,precision,recall,f1_score,auc,auc-log-scale",
**kwargs **kwargs,
): ):
"""The metrics option""" """The metrics option"""
...@@ -71,7 +69,7 @@ def metrics_option( ...@@ -71,7 +69,7 @@ def metrics_option(
help=help, help=help,
show_default=True, show_default=True,
callback=callback, callback=callback,
**kwargs **kwargs,
)(func) )(func)
return custom_metrics_option return custom_metrics_option
...@@ -80,7 +78,7 @@ def metrics_option( ...@@ -80,7 +78,7 @@ def metrics_option(
def regexps_option( def regexps_option(
help="A list of regular expressions (by repeating this option) to be used to " help="A list of regular expressions (by repeating this option) to be used to "
"categorize PAIs. Each regexp must match one type of PAI.", "categorize PAIs. Each regexp must match one type of PAI.",
**kwargs **kwargs,
): ):
def custom_regexps_option(func): def custom_regexps_option(func):
def callback(ctx, param, value): def callback(ctx, param, value):
...@@ -94,7 +92,7 @@ def regexps_option( ...@@ -94,7 +92,7 @@ def regexps_option(
multiple=True, multiple=True,
help=help, help=help,
callback=callback, callback=callback,
**kwargs **kwargs,
)(func) )(func)
return custom_regexps_option return custom_regexps_option
...@@ -102,7 +100,7 @@ def regexps_option( ...@@ -102,7 +100,7 @@ def regexps_option(
def regexp_column_option( def regexp_column_option(
help="The column in the score files to match the regular expressions against.", help="The column in the score files to match the regular expressions against.",
**kwargs **kwargs,
): ):
def custom_regexp_column_option(func): def custom_regexp_column_option(func):
def callback(ctx, param, value): def callback(ctx, param, value):
...@@ -112,35 +110,110 @@ def regexp_column_option( ...@@ -112,35 +110,110 @@ def regexp_column_option(
return click.option( return click.option(
"-rc", "-rc",
"--regexp-column", "--regexp-column",
default="real_id", default="attack_type",
type=click.Choice(("claimed_id", "real_id", "test_label")),
help=help, help=help,
show_default=True, show_default=True,
callback=callback, callback=callback,
**kwargs **kwargs,
)(func) )(func)
return custom_regexp_column_option return custom_regexp_column_option
def gen_pad_csv_scores(
    filename, mean_match, mean_attacks, n_attack_types, n_clients, n_samples
):
    """Generates a CSV file containing random scores for PAD.

    Bona-fide scores are drawn from a normal distribution centered on
    ``mean_match``; attack scores of type ``i`` are centered on
    ``mean_attacks[i % len(mean_attacks)]``. Bona-fide rows are written
    first, then one group of rows per attack type.
    """
    fieldnames = [
        "claimed_id",
        "test_label",
        "is_bonafide",
        "attack_type",
        "sample_n",
        "score",
    ]
    with open(filename, "w") as out:
        writer = DictWriter(out, fieldnames=fieldnames)
        writer.writeheader()
        # Bona-fide rows (one per client per sample).
        for client in range(n_clients):
            for sample_idx in range(n_samples):
                writer.writerow(
                    dict(
                        claimed_id=client,
                        test_label=f"client/real/{client:03d}",
                        is_bonafide="True",
                        attack_type=None,
                        sample_n=sample_idx,
                        score=numpy.random.normal(loc=mean_match),
                    )
                )
        # Attack rows, grouped by attack type; the mean list cycles if
        # there are more attack types than provided means.
        for attack_idx in range(n_attack_types):
            attack_mean = mean_attacks[attack_idx % len(mean_attacks)]
            for client in range(n_clients):
                for sample_idx in range(n_samples):
                    writer.writerow(
                        dict(
                            claimed_id=client,
                            test_label=f"client/attack/{client:03d}",
                            is_bonafide="False",
                            attack_type=f"type_{attack_idx}",
                            sample_n=sample_idx,
                            score=numpy.random.normal(loc=attack_mean),
                        )
                    )
@click.command()
@click.argument("outdir")
@click.option("-mm", "--mean-match", default=10, type=click.FLOAT, show_default=True)
@click.option(
    "-ma",
    "--mean-attacks",
    default=[-10, -6],
    type=click.FLOAT,
    show_default=True,
    multiple=True,
)
@click.option("-c", "--n-clients", default=10, type=click.INT, show_default=True)
@click.option("-s", "--n-samples", default=2, type=click.INT, show_default=True)
@click.option("-a", "--n-attacks", default=2, type=click.INT, show_default=True)
@verbosity_option()
@click.pass_context
def gen(
    ctx, outdir, mean_match, mean_attacks, n_clients, n_samples, n_attacks, **kwargs
):
    """Generate random scores.

    Generates random scores in CSV format. The scores are generated
    using Gaussian distribution whose mean is an input
    parameter. The generated scores can be used as hypothetical datasets.

    n-attacks defines the number of different type of attacks generated (like print and
    mask). When multiples attacks are present, the mean-attacks option can be set
    multiple times, specifying the mean of each attack scores distribution.

    Example:

    bob pad gen results/generated/scores-dev.csv -a 3 -ma 2 -ma 5 -ma 7 -mm 8
    """
    # Fixed seed so repeated invocations produce identical score files.
    numpy.random.seed(0)
    # One score file per group, dev first then eval (same order as before).
    for group in ("dev", "eval"):
        gen_pad_csv_scores(
            os.path.join(outdir, f"scores-{group}.csv"),
            mean_match,
            mean_attacks,
            n_attacks,
            n_clients,
            n_samples,
        )
@common_options.metrics_command( @common_options.metrics_command(
...@@ -174,7 +247,7 @@ See also ``bob pad multi-metrics``. ...@@ -174,7 +247,7 @@ See also ``bob pad multi-metrics``.
@metrics_option() @metrics_option()
def metrics(ctx, scores, evaluation, regexps, regexp_column, metrics, **kwargs): def metrics(ctx, scores, evaluation, regexps, regexp_column, metrics, **kwargs):
load_fn = partial( load_fn = partial(
negatives_per_pai_and_positives, regexps=regexps, regexp_column=regexp_column split_csv_pad_per_pai, regexps=regexps, regexp_column=regexp_column
) )
process = figure.Metrics(ctx, scores, evaluation, load_fn, metrics) process = figure.Metrics(ctx, scores, evaluation, load_fn, metrics)
process.run() process.run()
...@@ -184,7 +257,7 @@ def metrics(ctx, scores, evaluation, regexps, regexp_column, metrics, **kwargs): ...@@ -184,7 +257,7 @@ def metrics(ctx, scores, evaluation, regexps, regexp_column, metrics, **kwargs):
common_options.ROC_HELP.format(score_format=SCORE_FORMAT, command="bob pad roc") common_options.ROC_HELP.format(score_format=SCORE_FORMAT, command="bob pad roc")
) )
def roc(ctx, scores, evaluation, **kwargs): def roc(ctx, scores, evaluation, **kwargs):
process = figure.Roc(ctx, scores, evaluation, load.split) process = figure.Roc(ctx, scores, evaluation, split_csv_pad)
process.run() process.run()
...@@ -192,7 +265,7 @@ def roc(ctx, scores, evaluation, **kwargs): ...@@ -192,7 +265,7 @@ def roc(ctx, scores, evaluation, **kwargs):
common_options.DET_HELP.format(score_format=SCORE_FORMAT, command="bob pad det") common_options.DET_HELP.format(score_format=SCORE_FORMAT, command="bob pad det")
) )
def det(ctx, scores, evaluation, **kwargs): def det(ctx, scores, evaluation, **kwargs):
process = figure.Det(ctx, scores, evaluation, load.split) process = figure.Det(ctx, scores, evaluation, split_csv_pad)
process.run() process.run()
...@@ -200,7 +273,7 @@ def det(ctx, scores, evaluation, **kwargs): ...@@ -200,7 +273,7 @@ def det(ctx, scores, evaluation, **kwargs):
common_options.EPC_HELP.format(score_format=SCORE_FORMAT, command="bob pad epc") common_options.EPC_HELP.format(score_format=SCORE_FORMAT, command="bob pad epc")
) )
def epc(ctx, scores, **kwargs): def epc(ctx, scores, **kwargs):
process = measure_figure.Epc(ctx, scores, True, load.split, hter="ACER") process = measure_figure.Epc(ctx, scores, True, split_csv_pad, hter="ACER")
process.run() process.run()
...@@ -208,7 +281,7 @@ def epc(ctx, scores, **kwargs): ...@@ -208,7 +281,7 @@ def epc(ctx, scores, **kwargs):
common_options.HIST_HELP.format(score_format=SCORE_FORMAT, command="bob pad hist") common_options.HIST_HELP.format(score_format=SCORE_FORMAT, command="bob pad hist")
) )
def hist(ctx, scores, evaluation, **kwargs): def hist(ctx, scores, evaluation, **kwargs):
process = figure.Hist(ctx, scores, evaluation, load.split) process = figure.Hist(ctx, scores, evaluation, split_csv_pad)
process.run() process.run()
...@@ -250,7 +323,7 @@ def multi_metrics( ...@@ -250,7 +323,7 @@ def multi_metrics(
): ):
ctx.meta["min_arg"] = protocols_number * (2 if evaluation else 1) ctx.meta["min_arg"] = protocols_number * (2 if evaluation else 1)
load_fn = partial( load_fn = partial(
negatives_per_pai_and_positives, regexps=regexps, regexp_column=regexp_column split_csv_pad_per_pai, regexps=regexps, regexp_column=regexp_column
) )
process = figure.MultiMetrics(ctx, scores, evaluation, load_fn, metrics) process = figure.MultiMetrics(ctx, scores, evaluation, load_fn, metrics)
process.run() process.run()
...@@ -7,6 +7,8 @@ from bob.extension.scripts.click_helper import ConfigCommand ...@@ -7,6 +7,8 @@ from bob.extension.scripts.click_helper import ConfigCommand
from bob.extension.scripts.click_helper import ResourceOption from bob.extension.scripts.click_helper import ResourceOption
from bob.extension.scripts.click_helper import verbosity_option from bob.extension.scripts.click_helper import verbosity_option
from bob.pipelines.distributed import dask_get_partition_size from bob.pipelines.distributed import dask_get_partition_size
from io import StringIO
import csv
@click.command( @click.command(
...@@ -71,6 +73,14 @@ from bob.pipelines.distributed import dask_get_partition_size ...@@ -71,6 +73,14 @@ from bob.pipelines.distributed import dask_get_partition_size
help="Saves scores (and checkpoints) in this folder.", help="Saves scores (and checkpoints) in this folder.",
cls=ResourceOption, cls=ResourceOption,
) )
@click.option(
"--csv-scores/--lst-scores",
"write_metadata_scores",
default=True,
help="Choose the score file format as 'csv' with additional metadata or 'lst' 4 "
"columns. Default: --csv-scores",
cls=ResourceOption,
)
@click.option( @click.option(
"--checkpoint", "--checkpoint",
"-c", "-c",
...@@ -108,6 +118,7 @@ def vanilla_pad( ...@@ -108,6 +118,7 @@ def vanilla_pad(
dask_client, dask_client,
groups, groups,
output, output,
write_metadata_scores,
checkpoint, checkpoint,
dask_partition_size, dask_partition_size,
dask_n_workers, dask_n_workers,
...@@ -131,6 +142,10 @@ def vanilla_pad( ...@@ -131,6 +142,10 @@ def vanilla_pad(
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
log_parameters(logger) log_parameters(logger)
get_score_row = score_row_csv if write_metadata_scores else score_row_four_columns
output_file_ext = ".csv" if write_metadata_scores else ""
intermediate_file_ext = ".csv.gz" if write_metadata_scores else ".txt.gz"
os.makedirs(output, exist_ok=True) os.makedirs(output, exist_ok=True)
if checkpoint: if checkpoint:
...@@ -146,7 +161,7 @@ def vanilla_pad( ...@@ -146,7 +161,7 @@ def vanilla_pad(
predict_samples[group] = database.predict_samples(group=group) predict_samples[group] = database.predict_samples(group=group)
total_samples += len(predict_samples[group]) total_samples += len(predict_samples[group])
# Checking if the pipieline is dask-wrapped # Checking if the pipeline is dask-wrapped
first_step = pipeline[0] first_step = pipeline[0]
if not isinstance_nested(first_step, "estimator", DaskWrapper): if not isinstance_nested(first_step, "estimator", DaskWrapper):
...@@ -182,13 +197,13 @@ def vanilla_pad( ...@@ -182,13 +197,13 @@ def vanilla_pad(
logger.info(f"Running vanilla biometrics for group {group}") logger.info(f"Running vanilla biometrics for group {group}")
result = getattr(pipeline, decision_function)(predict_samples[group]) result = getattr(pipeline, decision_function)(predict_samples[group])
scores_path = os.path.join(output, f"scores-{group}") scores_path = os.path.join(output, f"scores-{group}{output_file_ext}")
if isinstance(result, dask.bag.core.Bag): if isinstance(result, dask.bag.core.Bag):
# write each partition into a zipped txt file # write each partition into a zipped txt file, one line per sample
result = result.map(pad_predicted_sample_to_score_line) result = result.map(get_score_row)
prefix, postfix = f"{output}/scores/scores-{group}-", ".txt.gz" prefix, postfix = f"{output}/scores/scores-{group}-", intermediate_file_ext
pattern = f"{prefix}*{postfix}" pattern = f"{prefix}*{postfix}"
os.makedirs(os.path.dirname(prefix), exist_ok=True) os.makedirs(os.path.dirname(prefix), exist_ok=True)
logger.info("Writing bag results into files ...") logger.info("Writing bag results into files ...")
...@@ -198,29 +213,52 @@ def vanilla_pad( ...@@ -198,29 +213,52 @@ def vanilla_pad(
) )
with open(scores_path, "w") as f: with open(scores_path, "w") as f:
csv_writer, header = None, None
# concatenate scores into one score file # concatenate scores into one score file
for path in sorted( for path in sorted(
glob(pattern), glob(pattern),
key=lambda l: int(l.replace(prefix, "").replace(postfix, "")), key=lambda l: int(l.replace(prefix, "").replace(postfix, "")),
): ):
with gzip.open(path, "rt") as f2: with gzip.open(path, "rt") as f2:
f.write(f2.read()) if write_metadata_scores:
if csv_writer is None:
# Retrieve the header from one of the _header fields
tmp_reader = csv.reader(f2)
# Reconstruct a list from the str representation
header = next(tmp_reader)[-1].strip("][").split(", ")
header = [s.strip("' ") for s in header]
csv_writer = csv.DictWriter(f, fieldnames=header)
csv_writer.writeheader()
f2.seek(0, 0)
# There is no header in the intermediary files, specify it
csv_reader = csv.DictReader(
f2, fieldnames=header + ["_header"]
)
for row in csv_reader:
# Write each element of the row, except `_header`
csv_writer.writerow(
{k: row[k] for k in row.keys() if k != "_header"}
)
else:
f.write(f2.read())
# delete intermediate score files # delete intermediate score files
os.remove(path) os.remove(path)
else: else:
with open(scores_path, "w") as f: with open(scores_path, "w") as f:
if write_metadata_scores:
csv.DictWriter(
f, fieldnames=_get_csv_columns(result[0]).keys()
).writeheader()
for sample in result: for sample in result:
f.write(pad_predicted_sample_to_score_line(sample, endl="\n")) f.write(get_score_row(sample, endl="\n"))
def pad_predicted_sample_to_score_line(sample, endl=""): def score_row_four_columns(sample, endl=""):
claimed_id, test_label, score = sample.subject, sample.key, sample.data claimed_id, test_label, score = sample.subject, sample.key, sample.data
# # use the model_label field to indicate frame number # # use the model_label field to indicate frame number
# model_label = None # model_label = getattr(sample, "frame_id", None)
# if hasattr(sample, "frame_id"):
# model_label = sample.frame_id
real_id = claimed_id if sample.is_bonafide else sample.attack_type real_id = claimed_id if sample.is_bonafide else sample.attack_type
...@@ -229,3 +267,58 @@ def pad_predicted_sample_to_score_line(sample, endl=""): ...@@ -229,3 +267,58 @@ def pad_predicted_sample_to_score_line(sample, endl=""):
return f"{claimed_id} {real_id} {test_label} {score}{endl}" return f"{claimed_id} {real_id} {test_label} {score}{endl}"
# return f"{claimed_id} {model_label} {real_id} {test_label} {score}{endl}"