Commit 903af807 authored by Yannick DAYER's avatar Yannick DAYER
Browse files

Refactor the vulnerability SampleLoader

parent f37952e0
Pipeline #51804 passed with stage
in 11 minutes and 50 seconds
......@@ -153,7 +153,7 @@ class CSVToSampleLoaderBiometrics(CSVToSampleLoader):
class CSVToSampleLoaderVulnerability(CSVToSampleLoaderBiometrics):
"""
Class that converts the lines of a CSV file, like the one below to
Class that converts the lines of a CSV file like the one below to
:any:`bob.pipelines.DelayedSample` or :any:`bob.pipelines.SampleSet` for the
vulnerability analysis framework.
......@@ -170,7 +170,8 @@ class CSVToSampleLoaderVulnerability(CSVToSampleLoaderBiometrics):
This loader creates probe samples with a ``references`` fields so that the attacks
probes are only compares against the model targeted by the attack. The zero-effort
impostor probes are compared against all the models.
This requires the database ``fetch_probes`` arguments to be set to ``False``.
This requires the database's ``fetch_probes`` argument to be set to ``False``.
Parameters
----------
......@@ -198,35 +199,42 @@ class CSVToSampleLoaderVulnerability(CSVToSampleLoaderBiometrics):
data_loader=data_loader,
extension=extension,
dataset_original_directory=dataset_original_directory,
reference_id_equal_subject_id=reference_id_equal_subject_id
reference_id_equal_subject_id=reference_id_equal_subject_id,
)
self.all_references = []
def convert_row_to_sample(self, row, header):
fields = {str(h).lower():r for h, r in zip(header, row)}
fields = {str(h).lower(): r for h, r in zip(header, row)}
if fields["reference_id"] not in self.all_references:
self.all_references.append(fields["reference_id"])
if self.reference_id_equal_subject_id:
fields["subject_id"] = fields["reference_id"]
else:
if "subject_id" not in fields:
raise ValueError(f"`subject_id` not available in {header}")
# If an attack, only compare to own reference
probe_references = []
if fields.get("attack_type", None) is None:
probe_references.extend(self.all_references)
else:
probe_references.append(fields["reference_id"])
kwargs = {k: fields[k] for k in fields.keys() - {"id","should_flip"}}
# Fields that will not be added to the Sample
fields_to_ignore = {
"id",
}
kwargs = {k: fields[k] for k in fields.keys() - fields_to_ignore}
return DelayedSample(
functools.partial(
self.data_loader,
os.path.join(self.dataset_original_directory, fields["path"] + self.extension),
os.path.join(
self.dataset_original_directory, fields["path"] + self.extension
),
),
key=fields["path"],
reference_id=fields["reference_id"],
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment