Skip to content
Snippets Groups Projects

[csv_dataset] There is no need for CSVToPADSamples

1 file
+ 21
25
Compare changes
  • Side-by-side
  • Inline
+ 21
25
@@ -4,32 +4,28 @@
from bob.bio.base.database.legacy import check_parameters_for_validity
from bob.pad.base.pipelines.abstract_classes import Database
from bob.pipelines.datasets import CSVToSamples, FileListDatabase
from bob.pipelines.datasets import FileListDatabase
class CSVToPADSamples(CSVToSamples):
"""Converts a csv file to a list of PAD samples"""
def __iter__(self):
for sample in super().__iter__():
if not hasattr(sample, "subject"):
raise RuntimeError(
"PAD samples should contain a `subject` attribute which "
"reveals the identifies the person from whom the sample is created."
)
if not hasattr(sample, "attack_type"):
raise RuntimeError(
"PAD samples should contain a `attack_type` attribute which "
"should be '' for bona fide samples and something like "
"print, replay, mask, etc. for attacks. This attribute is "
"considered the PAI type of each attack is used to compute APCER."
)
if sample.attack_type == "":
sample.attack_type = None
sample.is_bonafide = sample.attack_type is None
if not hasattr(sample, "key"):
sample.key = sample.filename
yield sample
def validate_pad_sample(sample):
if not hasattr(sample, "subject"):
raise RuntimeError(
"PAD samples should contain a `subject` attribute which "
"reveals the identifies the person from whom the sample is created."
)
if not hasattr(sample, "attack_type"):
raise RuntimeError(
"PAD samples should contain a `attack_type` attribute which "
"should be '' for bona fide samples and something like "
"print, replay, mask, etc. for attacks. This attribute is "
"considered the PAI type of each attack is used to compute APCER."
)
if sample.attack_type == "":
sample.attack_type = None
sample.is_bonafide = sample.attack_type is None
if not hasattr(sample, "key"):
sample.key = sample.filename
return sample
class FileListPadDatabase(Database, FileListDatabase):
@@ -45,7 +41,6 @@ class FileListPadDatabase(Database, FileListDatabase):
super().__init__(
dataset_protocols_path=dataset_protocols_path,
protocol=protocol,
reader_cls=CSVToPADSamples,
transformer=transformer,
**kwargs,
)
@@ -69,6 +64,7 @@ class FileListPadDatabase(Database, FileListDatabase):
(not s.is_bonafide) and "attack" in purposes
)
results = [validate_pad_sample(sample) for sample in results]
results = list(filter(_filter, results))
return results
Loading