Commit e2d8ab7c authored by Tiago de Freitas Pereira, committed by Amir MOHAMMADI

New database interface for PAD

parent ee28a9fe
@@ -3,6 +3,7 @@ from .database import PadDatabase
from .filelist.query import FileListPadDatabase
from .filelist.models import Client
from .PadBioFileDB import HighBioDatabase, HighPadDatabase
from .csv_dataset import CSVPADDataset, CSVToSampleLoader, LSTToSampleLoader
# gets sphinx autodoc done right - don't remove it
def __appropriate__(*args):
@@ -21,6 +22,14 @@ def __appropriate__(*args):
__appropriate__(
    PadFile,
    PadDatabase,
    FileListPadDatabase,
    Client,
    HighBioDatabase,
    HighPadDatabase,
    CSVPADDataset,
    CSVToSampleLoader,
    LSTToSampleLoader,
)
__all__ = [_ for _ in dir() if not _.startswith("_")]
#!/usr/bin/env python
# vim: set fileencoding=utf-8 :
from bob.pad.base.pipelines.vanilla_pad.abstract_classes import Database
import csv
from bob.pipelines.datasets.sample_loaders import CSVBaseSampleLoader
from bob.extension.download import search_file
from bob.pipelines import DelayedSample
import bob.io.base
import os
import functools
class CSVToSampleLoader(CSVBaseSampleLoader):
    """
    Simple mechanism that converts the lines of a CSV file to
    :any:`bob.pipelines.DelayedSample` or :any:`bob.pipelines.SampleSet`.

    Each CSV line of a PAD dataset should have at least a PATH and a SUBJECT id,
    as in the example below:

    .. code-block:: text

       PATH,SUBJECT
       path_1,subject_1
    """
    def check_header(self, header):
        """
        A header should contain at least the fields `PATH` and `SUBJECT`.
        """
        header = [h.lower() for h in header]
        if "subject" not in header:
            raise ValueError("The field `subject` is not available in your dataset.")
        if "path" not in header:
            raise ValueError("The field `path` is not available in your dataset.")
    def __call__(self, f, is_bonafide=True):
        f.seek(0)
        reader = csv.reader(f)
        header = next(reader)
        self.check_header(header)
        return [
            self.convert_row_to_sample(row, header, is_bonafide=is_bonafide)
            for row in reader
        ]
def convert_row_to_sample(self, row, header=None, is_bonafide=True):
path = str(row[0])
subject = str(row[1])
kwargs = dict([[str(h).lower(), r] for h, r in zip(header[2:], row[2:])])
if self.metadata_loader is not None:
metadata = self.metadata_loader(row, header=header, is_bonafide=is_bonafide)
kwargs.update(metadata)
return DelayedSample(
functools.partial(
self.data_loader,
os.path.join(self.dataset_original_directory, path + self.extension),
),
key=path,
subject=subject,
is_bonafide=is_bonafide,
**kwargs,
)
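
# A minimal usage sketch (the file and directory names below are hypothetical;
# the constructor arguments mirror the defaults used by `CSVPADDataset` further
# down in this file):
#
#   loader = CSVToSampleLoader(
#       data_loader=bob.io.base.load,
#       metadata_loader=None,
#       dataset_original_directory="/path/to/raw/data",
#       extension=".png",
#   )
#   with open("my_protocol/dev/for_real.csv") as f:
#       samples = loader(f, is_bonafide=True)
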
class LSTToSampleLoader(CSVBaseSampleLoader):
    """
    Simple mechanism that converts the lines of an LST file to
    :any:`bob.pipelines.DelayedSample` or :any:`bob.pipelines.SampleSet`.
    """

    def __call__(self, f, is_bonafide=True):
        f.seek(0)
        reader = csv.reader(f, delimiter=" ")
        samples = []
        for row in reader:
            # Skip comment lines
            if row[0][0] == "#":
                continue
            samples.append(self.convert_row_to_sample(row, is_bonafide=is_bonafide))
        return samples
def convert_row_to_sample(self, row, header=None, is_bonafide=True):
path = str(row[0])
subject = str(row[1])
attack_type = None
if len(row) == 3:
attack_type = str(row[2])
kwargs = dict()
if self.metadata_loader is not None:
metadata = self.metadata_loader(row, header=header)
kwargs.update(metadata)
return DelayedSample(
functools.partial(
self.data_loader,
os.path.join(self.dataset_original_directory, path + self.extension),
),
key=path,
subject=subject,
is_bonafide=is_bonafide,
attack_type=attack_type,
**kwargs,
)
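
# For reference, an LST file is space-separated, supports `#` comment lines,
# and may carry an optional third column with the attack type, e.g.
# (hypothetical content):
#
#   # path subject [attack_type]
#   data/real1 SUBJECT_1
#   data/attack1 SUBJECT_1 print
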
class CSVPADDataset(Database):
    """
    Generic file-list dataset for PAD experiments.

    To create a new dataset, you need to provide a directory structure similar to the one below:

    .. code-block:: text

       my_dataset/
       my_dataset/my_protocol/train/for_real.csv
       my_dataset/my_protocol/train/for_attack.csv
       my_dataset/my_protocol/dev/for_real.csv
       my_dataset/my_protocol/dev/for_attack.csv
       my_dataset/my_protocol/eval/for_real.csv
       my_dataset/my_protocol/eval/for_attack.csv

    Each row of these CSV files should contain i) the path to the raw data and
    ii) an identifier of the subject in the image (subject).
    The structure of each CSV file should be as below:

    .. code-block:: text

       PATH,SUBJECT
       path_1,subject_1
       path_2,subject_2
       path_i,subject_j
       ...

    You might want to ship metadata within your samples (e.g. gender, age, annotations, ...).
    To do so, simply add the extra columns, as below:

    .. code-block:: text

       PATH,SUBJECT,TYPE_OF_ATTACK,GENDER,AGE
       path_1,subject_1,A,B,C
       path_2,subject_2,A,B,1
       path_i,subject_j,2,3,4
       ...

    The files `my_dataset/my_protocol/eval/for_real.csv` and `my_dataset/my_protocol/eval/for_attack.csv`
    are optional and are used when a protocol contains data for evaluation.
    Likewise, the files `my_dataset/my_protocol/train/for_real.csv` and
    `my_dataset/my_protocol/train/for_attack.csv` are used when a protocol
    contains data for training.

    Parameters
    ----------

    dataset_protocol_path: str
        Absolute path or a tarball of the dataset protocol description.

    protocol_name: str
        The name of the protocol

    csv_to_sample_loader: :any:`bob.pipelines.datasets.sample_loaders.CSVBaseSampleLoader`
        Object whose responsibility is to generate :any:`bob.pipelines.Sample`
        and/or :any:`bob.pipelines.SampleSet` from CSV rows
    """
    def __init__(
        self,
        dataset_protocol_path,
        protocol_name,
        csv_to_sample_loader=CSVToSampleLoader(
            data_loader=bob.io.base.load,
            metadata_loader=None,
            dataset_original_directory="",
            extension="",
        ),
    ):
        self.dataset_protocol_path = dataset_protocol_path
        self.protocol_name = protocol_name

        def get_paths():
            if not os.path.exists(dataset_protocol_path):
                raise ValueError(f"The path `{dataset_protocol_path}` was not found")

            # Look up each file list, handling the legacy `.lst` format
            # alongside the `.csv` one
            train_real_csv = search_file(
                dataset_protocol_path,
                [
                    os.path.join(protocol_name, "train", "for_real.lst"),
                    os.path.join(protocol_name, "train", "for_real.csv"),
                ],
            )
            train_attack_csv = search_file(
                dataset_protocol_path,
                [
                    os.path.join(protocol_name, "train", "for_attack.lst"),
                    os.path.join(protocol_name, "train", "for_attack.csv"),
                ],
            )
            dev_real_csv = search_file(
                dataset_protocol_path,
                [
                    os.path.join(protocol_name, "dev", "for_real.lst"),
                    os.path.join(protocol_name, "dev", "for_real.csv"),
                ],
            )
            dev_attack_csv = search_file(
                dataset_protocol_path,
                [
                    os.path.join(protocol_name, "dev", "for_attack.lst"),
                    os.path.join(protocol_name, "dev", "for_attack.csv"),
                ],
            )
            eval_real_csv = search_file(
                dataset_protocol_path,
                [
                    os.path.join(protocol_name, "eval", "for_real.lst"),
                    os.path.join(protocol_name, "eval", "for_real.csv"),
                ],
            )
            eval_attack_csv = search_file(
                dataset_protocol_path,
                [
                    os.path.join(protocol_name, "eval", "for_attack.lst"),
                    os.path.join(protocol_name, "eval", "for_attack.csv"),
                ],
            )

            # The minimum required is to have `dev_real_csv` and `dev_attack_csv`
            if dev_real_csv is None:
                raise ValueError(
                    f"The file `{os.path.join(protocol_name, 'dev', 'for_real.csv')}` "
                    "(or its `.lst` equivalent) is required and was not found"
                )
            if dev_attack_csv is None:
                raise ValueError(
                    f"The file `{os.path.join(protocol_name, 'dev', 'for_attack.csv')}` "
                    "(or its `.lst` equivalent) is required and was not found"
                )
            return (
                train_real_csv,
                train_attack_csv,
                dev_real_csv,
                dev_attack_csv,
                eval_real_csv,
                eval_attack_csv,
            )

        (
            self.train_real_csv,
            self.train_attack_csv,
            self.dev_real_csv,
            self.dev_attack_csv,
            self.eval_real_csv,
            self.eval_attack_csv,
        ) = get_paths()

        def get_dict_cache():
            cache = dict()
            cache["train_real_csv"] = None
            cache["train_attack_csv"] = None
            cache["dev_real_csv"] = None
            cache["dev_attack_csv"] = None
            cache["eval_real_csv"] = None
            cache["eval_attack_csv"] = None
            return cache

        self.cache = get_dict_cache()
        self.csv_to_sample_loader = csv_to_sample_loader
    def _load_samples(self, cache_key, filepointer, is_bonafide):
        # Parse each file list only once and memoize the resulting samples
        self.cache[cache_key] = (
            self.csv_to_sample_loader(filepointer, is_bonafide)
            if self.cache[cache_key] is None
            else self.cache[cache_key]
        )
        return self.cache[cache_key]
    def fit_samples(self):
        return self._load_samples(
            "train_real_csv", self.train_real_csv, is_bonafide=True
        ) + self._load_samples(
            "train_attack_csv", self.train_attack_csv, is_bonafide=False
        )

    def predict_samples(self, group="dev"):
        if group == "dev":
            return self._load_samples(
                "dev_real_csv", self.dev_real_csv, is_bonafide=True
            ) + self._load_samples(
                "dev_attack_csv", self.dev_attack_csv, is_bonafide=False
            )
        else:
            return self._load_samples(
                "eval_real_csv", self.eval_real_csv, is_bonafide=True
            ) + self._load_samples(
                "eval_attack_csv", self.eval_attack_csv, is_bonafide=False
            )
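
# A minimal usage sketch (hypothetical dataset path; `protocol1` follows the
# directory layout described in the class docstring):
#
#   dataset = CSVPADDataset("/path/to/my_dataset", "protocol1")
#   train_samples = dataset.fit_samples()
#   dev_samples = dataset.predict_samples(group="dev")
#   eval_samples = dataset.predict_samples(group="eval")
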
data/csv_dataset/protocol1/dev/for_attack.csv:

PATH,SUBJECT
data/attack10,SUBJECT_10
data/attack20,SUBJECT_10
data/attack30,SUBJECT_20

data/csv_dataset/protocol1/dev/for_real.csv:

PATH,SUBJECT
data/real10,SUBJECT_10
data/real20,SUBJECT_20

data/csv_dataset/protocol1/eval/for_attack.csv:

PATH,SUBJECT
data/attack100,SUBJECT_100
data/attack200,SUBJECT_100
data/attack300,SUBJECT_200
data/attack400,SUBJECT_200

data/csv_dataset/protocol1/eval/for_real.csv:

PATH,SUBJECT
data/real100,SUBJECT_100
data/real200,SUBJECT_200
data/real300,SUBJECT_200

data/csv_dataset/protocol1/train/for_attack.csv:

PATH,SUBJECT
data/attack1,SUBJECT_1
data/attack2,SUBJECT_1
data/attack3,SUBJECT_2

data/csv_dataset/protocol1/train/for_real.csv:

PATH,SUBJECT
data/real1,SUBJECT_1
data/real2,SUBJECT_2
@@ -23,51 +23,154 @@ Tests for the PAD Filelist database.
import os
import bob.io.base.test_utils
from bob.pad.base.database import FileListPadDatabase, CSVPADDataset, LSTToSampleLoader

example_dir = os.path.realpath(
    bob.io.base.test_utils.datafile(".", __name__, "data/example_filelist")
)
csv_example_dir = os.path.realpath(
bob.io.base.test_utils.datafile(".", __name__, "data/csv_dataset")
)
csv_example_tarball = os.path.realpath(
bob.io.base.test_utils.datafile(".", __name__, "data/csv_dataset.tar.gz")
)
def test_query():
    db = FileListPadDatabase(example_dir, "test_padfilelist")

    assert len(db.groups()) == 3  # 3 groups (dev, eval, train)

    print(db.client_ids())
    # 5 client ids for real data of train, dev and eval sets (ignore all ids that are in attacks only)
    assert len(db.client_ids()) == 5
    assert len(db.client_ids(groups="train")) == 2  # 2 client ids for train
    assert len(db.client_ids(groups="dev")) == 2  # 2 client ids for dev
    assert len(db.client_ids(groups="eval")) == 1  # 1 client id for eval

    assert len(db.objects(groups="train")) == 3  # 3 samples in the train set

    assert (
        len(db.objects(groups="dev", purposes="real")) == 2
    )  # 2 samples of real data in the dev set
    assert (
        len(db.objects(groups="dev", purposes="attack")) == 1
    )  # 1 attack in the dev set
def test_query_protocol():
    db = FileListPadDatabase(os.path.dirname(example_dir), "test_padfilelist")
    p = "example_filelist"

    assert len(db.groups(protocol=p)) == 3  # 3 groups (dev, eval, train)

    assert len(db.client_ids(protocol=p)) == 5  # 5 client ids for train, dev and eval
    assert len(db.client_ids(groups="train", protocol=p)) == 2  # 2 client ids for train
    assert len(db.client_ids(groups="dev", protocol=p)) == 2  # 2 client ids for dev
    assert len(db.client_ids(groups="eval", protocol=p)) == 1  # 1 client id for eval

    assert (
        len(db.objects(groups="train", protocol=p)) == 3
    )  # 3 samples in the train set

    assert (
        len(db.objects(groups="dev", purposes="real", protocol=p)) == 2
    )  # 2 samples of real data in the dev set
    assert (
        len(db.objects(groups="dev", purposes="attack", protocol=p)) == 1
    )  # 1 attack in the dev set
def test_driver_api():
    from bob.db.base.script.dbmanage import main

    assert (
        main(
            (
                "pad_filelist clients --list-directory=%s --self-test" % example_dir
            ).split()
        )
        == 0
    )
    assert (
        main(
            (
                "pad_filelist dumplist --list-directory=%s --self-test" % example_dir
            ).split()
        )
        == 0
    )
    assert (
        main(
            (
                "pad_filelist dumplist --list-directory=%s --purpose=real --group=dev --self-test"
                % example_dir
            ).split()
        )
        == 0
    )
    assert (
        main(
            (
                "pad_filelist checkfiles --list-directory=%s --self-test" % example_dir
            ).split()
        )
        == 0
    )
def test_csv_dataset():
    def run(path):
        dataset = CSVPADDataset(path, "protocol1")

        # Train
        assert len(dataset.fit_samples()) == 5
        # 2 out of 5 are bonafides
        assert sum([s.is_bonafide for s in dataset.fit_samples()]) == 2

        # DEV
        assert len(dataset.predict_samples()) == 5
        # 2 out of 5 are bonafides
        assert sum([s.is_bonafide for s in dataset.predict_samples()]) == 2

        # EVAL
        assert len(dataset.predict_samples(group="eval")) == 7
        # 3 out of 7 are bonafides
        assert sum([s.is_bonafide for s in dataset.predict_samples(group="eval")]) == 3

    run(csv_example_dir)
    run(csv_example_tarball)
def test_csv_dataset_lst():
    dataset = CSVPADDataset(
        example_dir,
        "",
        csv_to_sample_loader=LSTToSampleLoader(
            data_loader=bob.io.base.load,
            metadata_loader=None,
            dataset_original_directory="",
            extension="",
        ),
    )

    # Train
    assert len(dataset.fit_samples()) == 3
    # 2 out of 3 are bonafides
    assert sum([s.is_bonafide for s in dataset.fit_samples()]) == 2

    # DEV
    assert len(dataset.predict_samples()) == 3
    # 2 out of 3 are bonafides
    assert sum([s.is_bonafide for s in dataset.predict_samples()]) == 2

    # EVAL
    assert len(dataset.predict_samples(group="eval")) == 2
    # 1 out of 2 are bonafides
    assert sum([s.is_bonafide for s in dataset.predict_samples(group="eval")]) == 1