Commit 0df6df77 authored by Yannick DAYER

Writing scores as CSV files with metadata

parent 743a5c71
@@ -7,6 +7,8 @@ from bob.extension.scripts.click_helper import ConfigCommand
 from bob.extension.scripts.click_helper import ResourceOption
 from bob.extension.scripts.click_helper import verbosity_option
 from bob.pipelines.distributed import dask_get_partition_size
+from io import StringIO
+import csv
 
 
 @click.command(
@@ -139,6 +141,10 @@ def vanilla_pad(
     logger = logging.getLogger(__name__)
     log_parameters(logger)
 
+    get_score_row = score_row_csv if write_metadata_scores else score_row_four_columns
+    output_file_ext = ".csv" if write_metadata_scores else ""
+    intermediate_file_ext = ".csv.gz" if write_metadata_scores else ".txt.gz"
+
     os.makedirs(output, exist_ok=True)
 
     if checkpoint:
@@ -154,7 +160,7 @@ def vanilla_pad(
         predict_samples[group] = database.predict_samples(group=group)
         total_samples += len(predict_samples[group])
 
-    # Checking if the pipieline is dask-wrapped
+    # Checking if the pipeline is dask-wrapped
     first_step = pipeline[0]
     if not isinstance_nested(first_step, "estimator", DaskWrapper):
@@ -190,13 +196,13 @@ def vanilla_pad(
         logger.info(f"Running vanilla biometrics for group {group}")
         result = getattr(pipeline, decision_function)(predict_samples[group])
-        scores_path = os.path.join(output, f"scores-{group}")
+        scores_path = os.path.join(output, f"scores-{group}{output_file_ext}")
 
         if isinstance(result, dask.bag.core.Bag):
-            # write each partition into a zipped txt file
-            result = result.map(pad_predicted_sample_to_score_line)
-            prefix, postfix = f"{output}/scores/scores-{group}-", ".txt.gz"
+            # write each partition into a zipped txt file, one line per sample
+            result = result.map(get_score_row)
+            prefix, postfix = f"{output}/scores/scores-{group}-", intermediate_file_ext
             pattern = f"{prefix}*{postfix}"
             os.makedirs(os.path.dirname(prefix), exist_ok=True)
             logger.info("Writing bag results into files ...")
@@ -206,29 +212,52 @@ def vanilla_pad(
             )
 
             with open(scores_path, "w") as f:
+                csv_writer, header = None, None
                 # concatenate scores into one score file
                 for path in sorted(
                     glob(pattern),
                     key=lambda l: int(l.replace(prefix, "").replace(postfix, "")),
                 ):
                     with gzip.open(path, "rt") as f2:
-                        f.write(f2.read())
+                        if write_metadata_scores:
+                            if csv_writer is None:
+                                # Retrieve the header from one of the _header fields
+                                tmp_reader = csv.reader(f2)
+                                # Reconstruct a list from the str representation
+                                header = next(tmp_reader)[-1].strip("][").split(", ")
+                                header = [s.strip("' ") for s in header]
+                                csv_writer = csv.DictWriter(f, fieldnames=header)
+                                csv_writer.writeheader()
+                                f2.seek(0, 0)
+                            # There is no header in the intermediary files, specify it
+                            csv_reader = csv.DictReader(
+                                f2, fieldnames=header + ["_header"]
+                            )
+                            for row in csv_reader:
+                                # Write each element of the row, except `_header`
+                                csv_writer.writerow(
+                                    {k: row[k] for k in row.keys() if k != "_header"}
+                                )
+                        else:
+                            f.write(f2.read())
                     # delete intermediate score files
                     os.remove(path)
         else:
             with open(scores_path, "w") as f:
+                if write_metadata_scores:
+                    csv.DictWriter(
+                        f, fieldnames=_get_csv_columns(result[0]).keys()
+                    ).writeheader()
                 for sample in result:
-                    f.write(pad_predicted_sample_to_score_line(sample, endl="\n"))
+                    f.write(get_score_row(sample, endl="\n"))
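For reference, a small sketch (not from this commit) of consuming the final metadata score file; the output path is hypothetical, and the exact column set depends on the samples' metadata:

import csv

with open("results/scores-dev.csv") as f:  # hypothetical output path
    for row in csv.DictReader(f):
        # the five mandatory columns always exist; extra metadata columns vary by database
        print(row["claimed_id"], row["test_label"], float(row["score"]))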
 
 
-def pad_predicted_sample_to_score_line(sample, endl=""):
+def score_row_four_columns(sample, endl=""):
     claimed_id, test_label, score = sample.subject, sample.key, sample.data
 
-    # # use the model_label field to indicate frame number
-    # model_label = None
-    # if hasattr(sample, "frame_id"):
-    #     model_label = sample.frame_id
+    # model_label = getattr(sample, "frame_id", None)
 
     real_id = claimed_id if sample.is_bonafide else sample.attack_type
@@ -237,3 +266,58 @@ def pad_predicted_sample_to_score_line(sample, endl=""):
 
     return f"{claimed_id} {real_id} {test_label} {score}{endl}"
     # return f"{claimed_id} {model_label} {real_id} {test_label} {score}{endl}"
+
+
+def _get_csv_columns(sample):
+    """Returns a dict of {csv_column_name: sample_attr_name} given a sample."""
+    # Mandatory columns and their corresponding fields
+    columns_attr = {
+        "claimed_id": "subject",
+        "test_label": "key",
+        "is_bonafide": "is_bonafide",
+        "attack_type": "attack_type",
+        "score": "data",
+    }
+    # Preventing duplicates and unwanted data
+    ignored_fields = list(columns_attr.values()) + ["annotations"]
+    # Retrieving custom metadata attribute names
+    metadata_fields = [
+        k
+        for k in sample.__dict__.keys()
+        if not k.startswith("_") and k not in ignored_fields
+    ]
+    for field in metadata_fields:
+        columns_attr[field] = field
+    return columns_attr
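To illustrate the discovery above (a sketch, attribute values invented): any public sample attribute that is not already mapped or ignored becomes an extra column named after itself:

from types import SimpleNamespace

sample = SimpleNamespace(
    subject="client01", key="video_003", data=0.42,
    is_bonafide=True, attack_type=None,
    capture_device="phone",  # custom metadata, becomes its own column
)
print(_get_csv_columns(sample))
# {'claimed_id': 'subject', 'test_label': 'key', 'is_bonafide': 'is_bonafide',
#  'attack_type': 'attack_type', 'score': 'data', 'capture_device': 'capture_device'}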
+
+
+def score_row_csv(sample, endl=""):
+    """Returns a str representing one row of a CSV for the sample.
+
+    If endl is empty, it is assumed that the row will be stored in a temporary file
+    without header, thus a `_header` column is added at the end, containing the header
+    as a list. This field can be used to reconstruct the final file.
+    """
+    columns_fields = _get_csv_columns(sample)
+    string_stream = StringIO()
+    csv_writer = csv.DictWriter(
+        string_stream,
+        fieldnames=list(columns_fields.keys()) + (["_header"] if endl == "" else []),
+    )
+    row_values = {
+        col: getattr(sample, attr, None) for col, attr in columns_fields.items()
+    }
+    if row_values["score"] is None:
+        row_values["score"] = "nan"
+    # Add a `_header` field to store the current CSV header (used in the dask Bag case)
+    if endl == "":
+        row_values["_header"] = list(columns_fields.keys())
+    csv_writer.writerow(row_values)
+    out_string = string_stream.getvalue()
+    if endl == "":
+        return out_string.rstrip()
+    else:
+        return out_string
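A sketch of the two modes, reusing the stand-in sample from the previous example; with endl == "" the row gains a trailing _header cell that the concatenation step later parses to rebuild the final header:

# final-file mode: plain row, the header is written separately
print(score_row_csv(sample, endl="\n"), end="")
# client01,video_003,True,,0.42,phone

# intermediate (dask bag) mode: column names ride along in a quoted `_header` cell
print(score_row_csv(sample))
# client01,video_003,True,,0.42,phone,"['claimed_id', 'test_label', 'is_bonafide', 'attack_type', 'score', 'capture_device']"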