Commit c1227609 authored by Amir MOHAMMADI

Implement a better file csv interface

parent 932cdf39
from .file import PadFile
from .database import PadDatabase
from .filelist.query import FileListPadDatabase  # removed by this commit
from .filelist.models import Client  # removed by this commit
from .csv_dataset import FileListPadDatabase  # added by this commit
from .PadBioFileDB import HighBioDatabase, HighPadDatabase
from .csv_dataset import CSVPADDataset, CSVToSampleLoader, LSTToSampleLoader  # removed by this commit
# gets sphinx autodoc done right - don't remove it
def __appropriate__(*args):
@@ -25,11 +23,7 @@
__appropriate__(
    PadFile,
    PadDatabase,
    FileListPadDatabase,
    Client,  # removed by this commit
    HighBioDatabase,
    HighPadDatabase,
    CSVPADDataset,  # removed by this commit
    CSVToSampleLoader,  # removed by this commit
    LSTToSampleLoader,  # removed by this commit
)
__all__ = [_ for _ in dir() if not _.startswith("_")]
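
# The public import path is unchanged by this commit; only the backing module
# differs. Usage sketch (illustrative, not part of the commit):
#
#     from bob.pad.base.database import FileListPadDatabase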

# csv_dataset.py
@@ -2,324 +2,73 @@
# vim: set fileencoding=utf-8 :

# ---- removed by this commit: the legacy loaders and CSVPADDataset ----
import csv
import functools
import os

import bob.io.base
from bob.extension.download import search_file
from bob.pad.base.pipelines.vanilla_pad.abstract_classes import Database
from bob.pipelines import DelayedSample
from bob.pipelines.datasets.sample_loaders import CSVBaseSampleLoader

class CSVToSampleLoader(CSVBaseSampleLoader):
    """
    Simple mechanism that converts the lines of a CSV file to
    :any:`bob.pipelines.DelayedSample` or :any:`bob.pipelines.SampleSet`.

    Each CSV line of a PAD dataset should have at least a PATH and a SUBJECT
    id, as in the example below:

    .. code-block:: text

        PATH,SUBJECT
        path_1,subject_1
    """
    def check_header(self, header):
        """
        A header should have at least "subject" and "path".
        """
        header = [h.lower() for h in header]
        if "subject" not in header:
            raise ValueError("The field `subject` is not available in your dataset.")
        if "path" not in header:
            raise ValueError("The field `path` is not available in your dataset.")
    def __call__(self, f, is_bonafide=True):
        f.seek(0)
        reader = csv.reader(f)
        header = next(reader)
        self.check_header(header)
        return [
            self.convert_row_to_sample(row, header, is_bonafide=is_bonafide)
            for row in reader
        ]
    def convert_row_to_sample(self, row, header=None, is_bonafide=True):
        path = str(row[0])
        subject = str(row[1])
        # any extra CSV columns become sample metadata (lower-cased names)
        kwargs = dict([[str(h).lower(), r] for h, r in zip(header[2:], row[2:])])
        if self.metadata_loader is not None:
            metadata = self.metadata_loader(row, header=header)
            kwargs.update(metadata)
        return DelayedSample(
            functools.partial(
                self.data_loader,
                os.path.join(self.dataset_original_directory, path + self.extension),
            ),
            key=path,
            subject=subject,
            is_bonafide=is_bonafide,
            **kwargs,
        )
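
# Usage sketch (illustrative, not part of the commit; the CSV content, paths
# and extension below are hypothetical):
import io

csv_content = "PATH,SUBJECT,GENDER\nsample_1,subject_1,m\nsample_2,subject_2,f\n"
loader = CSVToSampleLoader(
    data_loader=bob.io.base.load,  # only called when sample.data is accessed
    metadata_loader=None,
    dataset_original_directory="/path/to/raw/data",  # hypothetical
    extension=".png",  # hypothetical
)
samples = loader(io.StringIO(csv_content), is_bonafide=True)
# extra CSV columns become sample attributes (lower-cased header names)
assert samples[0].subject == "subject_1" and samples[0].gender == "m"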
class LSTToSampleLoader(CSVBaseSampleLoader):
    """
    Simple mechanism that converts the lines of an LST file to
    :any:`bob.pipelines.DelayedSample` or :any:`bob.pipelines.SampleSet`.
    """

    def __call__(self, f, is_bonafide=True):
        f.seek(0)
        reader = csv.reader(f, delimiter=" ")
        samples = []
        for row in reader:
            if row[0].startswith("#"):  # skip comment lines
                continue
            samples.append(self.convert_row_to_sample(row, is_bonafide=is_bonafide))
        return samples
    def convert_row_to_sample(self, row, header=None, is_bonafide=True):
        path = str(row[0])
        subject = str(row[1])
        attack_type = None
        if len(row) == 3:
            attack_type = str(row[2])
        kwargs = dict()
        if self.metadata_loader is not None:
            metadata = self.metadata_loader(row, header=header)
            kwargs.update(metadata)
        return DelayedSample(
            functools.partial(
                self.data_loader,
                os.path.join(self.dataset_original_directory, path + self.extension),
            ),
            key=path,
            subject=subject,
            is_bonafide=is_bonafide,
            attack_type=attack_type,
            **kwargs,
        )
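
# Usage sketch (illustrative, not part of the commit): the LST flavor is
# whitespace-separated, `path subject [attack_type]`, with `#` comments.
# The file content, paths and extension below are hypothetical:
lst_content = "# comment line\nattack_1 subject_1 print\nattack_2 subject_2 replay\n"
lst_loader = LSTToSampleLoader(
    data_loader=bob.io.base.load,
    metadata_loader=None,
    dataset_original_directory="/path/to/raw/data",  # hypothetical
    extension=".png",  # hypothetical
)
samples = lst_loader(io.StringIO(lst_content), is_bonafide=False)
assert samples[0].attack_type == "print"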
class CSVPADDataset(Database):
    """
    Generic file-list dataset for PAD experiments.

    To create a new dataset, you need to provide a directory structure similar
    to the one below:

    .. code-block:: text

        my_dataset/
        my_dataset/my_protocol/train/for_real.csv
        my_dataset/my_protocol/train/for_attack.csv
        my_dataset/my_protocol/dev/for_real.csv
        my_dataset/my_protocol/dev/for_attack.csv
        my_dataset/my_protocol/eval/for_real.csv
        my_dataset/my_protocol/eval/for_attack.csv

    In each row, these csv files should contain i) the path to the raw data
    and ii) an identifier of the subject in the image (subject). The structure
    of each CSV file should be as below:

    .. code-block:: text

        PATH,SUBJECT
        path_1,subject_1
        path_2,subject_2
        path_i,subject_j
        ...

    You might want to ship metadata within your samples (e.g. gender, age,
    annotations, ...). Doing so is simple: just add more columns, as below:

    .. code-block:: text

        PATH,SUBJECT,TYPE_OF_ATTACK,GENDER,AGE
        path_1,subject_1,A,B,C
        path_2,subject_2,A,B,1
        path_i,subject_j,2,3,4
        ...

    The files `my_dataset/my_protocol/eval/for_real.csv` and
    `my_dataset/my_protocol/eval/for_attack.csv` are optional; they are used
    when a protocol contains evaluation data. Likewise, the files
    `my_dataset/my_protocol/train/for_real.csv` and
    `my_dataset/my_protocol/train/for_attack.csv` are used when a protocol
    contains training data.

    Parameters
    ----------

    dataset_protocol_path: str
        Absolute path (or a tarball) of the dataset protocol description.

    protocol_name: str
        The name of the protocol

    csv_to_sample_loader: :any:`bob.bio.base.database.CSVBaseSampleLoader`
        Base class whose objective is to generate :any:`bob.pipelines.Sample`
        and/or :any:`bob.pipelines.SampleSet` from csv rows
    """

    def __init__(
        self,
        dataset_protocol_path,
        protocol_name,
        csv_to_sample_loader=CSVToSampleLoader(
            data_loader=bob.io.base.load,
            metadata_loader=None,
            dataset_original_directory="",
            extension="",
        ),
    ):
        self.dataset_protocol_path = dataset_protocol_path
        self.protocol_name = protocol_name

        def get_paths():
            if not os.path.exists(dataset_protocol_path):
                raise ValueError(f"The path `{dataset_protocol_path}` was not found")

            # Here we are handling the legacy: each list may ship either as
            # an `.lst` or a `.csv` file
            train_real_csv = search_file(
                dataset_protocol_path,
                [
                    os.path.join(protocol_name, "train", "for_real.lst"),
                    os.path.join(protocol_name, "train", "for_real.csv"),
                ],
            )
            train_attack_csv = search_file(
                dataset_protocol_path,
                [
                    os.path.join(protocol_name, "train", "for_attack.lst"),
                    os.path.join(protocol_name, "train", "for_attack.csv"),
                ],
            )
            dev_real_csv = search_file(
                dataset_protocol_path,
                [
                    os.path.join(protocol_name, "dev", "for_real.lst"),
                    os.path.join(protocol_name, "dev", "for_real.csv"),
                ],
            )
            dev_attack_csv = search_file(
                dataset_protocol_path,
                [
                    os.path.join(protocol_name, "dev", "for_attack.lst"),
                    os.path.join(protocol_name, "dev", "for_attack.csv"),
                ],
            )
            eval_real_csv = search_file(
                dataset_protocol_path,
                [
                    os.path.join(protocol_name, "eval", "for_real.lst"),
                    os.path.join(protocol_name, "eval", "for_real.csv"),
                ],
            )
            eval_attack_csv = search_file(
                dataset_protocol_path,
                [
                    os.path.join(protocol_name, "eval", "for_attack.lst"),
                    os.path.join(protocol_name, "eval", "for_attack.csv"),
                ],
            )

            # The minimum required is to have the `dev` lists
            if dev_real_csv is None:
                raise ValueError(
                    "The dev `for_real` list is required and was not found"
                )
            if dev_attack_csv is None:
                raise ValueError(
                    "The dev `for_attack` list is required and was not found"
                )

            return (
                train_real_csv,
                train_attack_csv,
                dev_real_csv,
                dev_attack_csv,
                eval_real_csv,
                eval_attack_csv,
            )

        (
            self.train_real_csv,
            self.train_attack_csv,
            self.dev_real_csv,
            self.dev_attack_csv,
            self.eval_real_csv,
            self.eval_attack_csv,
        ) = get_paths()

        def get_dict_cache():
            cache = dict()
            cache["train_real_csv"] = None
            cache["train_attack_csv"] = None
            cache["dev_real_csv"] = None
            cache["dev_attack_csv"] = None
            cache["eval_real_csv"] = None
            cache["eval_attack_csv"] = None
            return cache

        self.cache = get_dict_cache()
        self.csv_to_sample_loader = csv_to_sample_loader

    def _load_samples(self, cache_key, filepointer, is_bonafide):
        # parse each list only once, then serve it from the cache
        self.cache[cache_key] = (
            self.csv_to_sample_loader(filepointer, is_bonafide)
            if self.cache[cache_key] is None
            else self.cache[cache_key]
        )
        return self.cache[cache_key]

    def fit_samples(self):
        return self._load_samples(
            "train_real_csv", self.train_real_csv, is_bonafide=True
        ) + self._load_samples(
            "train_attack_csv", self.train_attack_csv, is_bonafide=False
        )
def predict_samples(self, group="dev"):
if group == "dev":
return self._load_samples(
"dev_real_csv", self.dev_real_csv, is_bonafide=True
) + self._load_samples(
"dev_attack_csv", self.dev_attack_csv, is_bonafide=False
)
else:
return self._load_samples(
"eval_real_csv", self.eval_real_csv, is_bonafide=True
) + self._load_samples(
"eval_attack_csv", self.eval_attack_csv, is_bonafide=False
)
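
# Usage sketch for the removed interface (illustrative, not part of the
# commit; the dataset path and protocol name are hypothetical):
dataset = CSVPADDataset("/path/to/my_dataset", "my_protocol")
training = dataset.fit_samples()            # train real + attack samples
dev = dataset.predict_samples(group="dev")  # dev real + attack samples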

# ---- added by this commit: CSVToPADSamples and FileListPadDatabase ----
from bob.db.base.utils import check_parameters_for_validity
from bob.pad.base.pipelines.vanilla_pad.abstract_classes import Database
from bob.pipelines.datasets import FileListDatabase, CSVToSamples


class CSVToPADSamples(CSVToSamples):
    """Converts a csv file to a list of PAD samples"""

    def __iter__(self):
        for sample in super().__iter__():
            if not hasattr(sample, "subject"):
                raise RuntimeError(
                    "PAD samples should contain a `subject` attribute which "
                    "identifies the person from whom the sample was created."
                )
            if not hasattr(sample, "attack_type"):
                raise RuntimeError(
                    "PAD samples should contain an `attack_type` attribute "
                    "which should be '' for bona fide samples and something "
                    "like print, replay, mask, etc. for attacks. This "
                    "attribute is considered the PAI type of each attack and "
                    "is used to compute the APCER."
                )
            if sample.attack_type == "":
                sample.attack_type = None
            sample.is_bonafide = sample.attack_type is None
            if not hasattr(sample, "key"):
                sample.key = sample.filename
            yield sample


class FileListPadDatabase(Database, FileListDatabase):
    """A PAD database interface from CSV files."""

    def __init__(
        self,
        dataset_protocols_path,
        protocol,
        transformer=None,
        **kwargs,
    ):
        super().__init__(
            dataset_protocols_path=dataset_protocols_path,
            protocol=protocol,
            reader_cls=CSVToPADSamples,
            transformer=transformer,
            **kwargs,
        )

    def purposes(self):
        return ("real", "attack")

    def samples(self, groups=None, purposes=None):
        results = super().samples(groups=groups)
        purposes = check_parameters_for_validity(
            purposes, "purposes", self.purposes(), self.purposes()
        )

        def _filter(s):
            return (s.is_bonafide and "real" in purposes) or (
                (not s.is_bonafide) and "attack" in purposes
            )

        results = list(filter(_filter, results))
        return results

    def fit_samples(self):
        return self.samples(groups="train")

    def predict_samples(self, group="dev"):
        return self.samples(groups=group)
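
# Usage sketch for the added interface (illustrative, not part of the commit;
# the dataset path and protocol name are hypothetical):
db = FileListPadDatabase("/path/to/my_dataset", "my_protocol")
train = db.fit_samples()  # same as db.samples(groups="train")
dev_real = db.samples(groups="dev", purposes="real")  # bona fide only
for sample in dev_real:
    # CSVToPADSamples guarantees these attributes on every sample
    assert sample.is_bonafide and sample.attack_type is None
eval_all = db.predict_samples(group="eval")  # real + attack of the eval group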

# filelist/__init__.py
from .models import ListReader, Client, FileListFile
from .query import FileListPadDatabase
from .driver import Interface

# gets sphinx autodoc done right - don't remove it
def __appropriate__(*args):
    """Says object was actually declared here, and not in the import module.
    Fixing sphinx warnings of not being able to find classes, when path is shortened.

    Parameters:

        *args: An iterable of objects to modify

    Resolves `Sphinx referencing issues
    <https://github.com/sphinx-doc/sphinx/issues/3048>`
    """

    for obj in args:
        obj.__module__ = __name__
__appropriate__(
    ListReader,
    Client,
    FileListFile,
    FileListPadDatabase,
    Interface,
)

__all__ = [_ for _ in dir() if not _.startswith('_')]

# filelist/driver.py
#!/usr/bin/env python
# vim: set fileencoding=utf-8 :
# Laurent El Shafey <laurent.el-shafey@idiap.ch>
#
# Copyright (C) 2011-2013 Idiap Research Institute, Martigny, Switzerland
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, version 3 of the License.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
"""Commands the PAD Filelists database can respond to.
"""
import os
import sys
from bob.db.base.driver import Interface as BaseInterface
def clients(args):
    """Dumps lists of client IDs based on your criteria"""
    from .query import FileListPadDatabase
    db = FileListPadDatabase(args.list_directory, 'pad_filelist')
    client_ids = db.client_ids(
        protocol=args.protocol,
        groups=args.group,
    )

    output = sys.stdout
    if args.selftest:
        from bob.db.base.utils import null
        output = null()

    for client in client_ids:
        output.write('%s\n' % client)
    return 0
def dumplist(args):
    """Dumps lists of files based on your criteria"""
    from .query import FileListPadDatabase
    db = FileListPadDatabase(args.list_directory, 'pad_filelist')
    file_objects = db.objects(
        purposes=args.purpose,
        groups=args.group,
        protocol=args.protocol,
    )

    output = sys.stdout
    if args.selftest:
        from bob.db.base.utils import null
        output = null()

    for file_obj in file_objects:
        output.write('%s\n' % file_obj.make_path(directory=args.directory, extension=args.extension))
    return 0
def checkfiles(args):
    """Checks existence of files based on your criteria"""
    from .query import FileListPadDatabase
    db = FileListPadDatabase(args.list_directory, 'pad_filelist')
    file_objects = db.objects(protocol=args.protocol)

    # go through all files, check if they are available on the filesystem
    good = []
    bad = []
    for file_obj in file_objects:
        if os.path.exists(file_obj.make_path(args.directory, args.extension)):
            good.append(file_obj)
        else:
            bad.append(file_obj)

    # report
    output = sys.stdout
    if args.selftest:
        from bob.db.base.utils import null
        output = null()

    if bad:
        for file_obj in bad:
            output.write('Cannot find file "%s"\n' % file_obj.make_path(args.directory, args.extension))
        output.write('%d files (out of %d) were not found at "%s"\n' % (len(bad), len(file_objects), args.directory))
    return 0
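
# Usage sketch (illustrative, not part of the commit): the command functions
# above expect the argparse namespace built by `add_commands` below; for a
# quick sanity check they can also be driven directly (paths hypothetical):
from types import SimpleNamespace

args = SimpleNamespace(
    list_directory="/path/to/file/lists",  # -l/--list-directory
    directory="",                          # -d/--directory
    extension="",                          # -e/--extension
    purpose="real",                        # -u/--purpose
    group="dev",                           # -g/--group
    protocol=None,                         # -p/--protocol
    selftest=False,
)
dumplist(args)  # prints one file path per line to stdout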
class Interface(BaseInterface):

    def name(self):
        return 'pad_filelist'

    def version(self):
        import pkg_resources  # part of setuptools
        return pkg_resources.require('bob.pad.base')[0].version

    def files(self):
        return ()

    def type(self):
        return 'text'
    def add_commands(self, parser):
        from . import __doc__ as docs
        subparsers = self.setup_parser(parser,
            "Presentation Attack Detection File Lists database", docs)
        import argparse

        # the "clients" action
        parser = subparsers.add_parser('clients', help=clients.__doc__)
        parser.add_argument('-l', '--list-directory', required=True,
            help="The directory which contains the file lists.")
        parser.add_argument('-g', '--group',
            help="if given, this value will limit the output files to those belonging to a "
                 "particular group.",
            choices=('dev', 'eval', 'train', ''))
        parser.add_argument('-p', '--protocol', default=None,
            help="If set, the protocol is appended to the directory that contains the file lists.")
        parser.add_argument('--self-test', dest="selftest", action='store_true', help=argparse.SUPPRESS)
        parser.set_defaults(func=clients)  # action

        # the "dumplist" action
        parser = subparsers.add_parser('dumplist', help=dumplist.__doc__)
        parser.add_argument('-l', '--list-directory', required=True,
            help="The directory which contains the file lists.")
        parser.add_argument('-d', '--directory', default='',
            help="if given, this path will be prepended to every entry returned.")
        parser.add_argument('-e', '--extension', default='',
            help="if given, this extension will be appended to every entry returned.")
        parser.add_argument('-u', '--purpose',
            help="if given, this value will limit the output files to those designed "
                 "for the given purposes.",
            choices=('real', 'attack', ''))
        parser.add_argument('-g', '--group',
            help="if given, this value will limit the output files to those belonging to a "
                 "particular group.",
            choices=('dev', 'eval', 'train', ''))
        parser.add_argument('-p', '--protocol', default=None,
            help="If set, the protocol is appended to the directory that contains the file lists.")
        parser.add_argument('--self-test', dest="selftest", action='store_true', help=argparse.SUPPRESS)
        parser.set_defaults(func=dumplist)  # action

        # the "checkfiles" action
        parser = subparsers.add_parser('checkfiles', help=checkfiles.__doc__)
        parser.add_argument('-l', '--list-directory', required=True,
            help="The directory which contains the file lists.")
        parser.add_argument('-d', '--directory', dest="directory", default='',
            help="if given, this path will be prepended to every entry returned.")
        parser.add_argument('-e', '--extension', dest="extension", default='',
            help="if given, this extension will be appended to every entry returned.")
        parser.add_argument('-p', '--protocol', default=None,
            help="If set, the protocol is appended to the directory that contains the file lists.")
        parser.add_argument('--self-test', dest="selftest", action='store_true', help=argparse.SUPPRESS)
        parser.set_defaults(func=checkfiles)  # action
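
# Registration sketch (illustrative, not part of the commit; assumes the
# usual bob packaging conventions, where drivers are exposed to
# `bob_dbmanage.py` through the `bob.db` entry-point group in setup.py):
#
#     entry_points={
#         'bob.db': [
#             'pad_filelist = bob.pad.base.database.filelist.driver:Interface',
#         ],
#     },
#
# After installation, the commands defined above become available as e.g.:
#
#     bob_dbmanage.py pad_filelist dumplist -l /path/to/lists -u real -g dev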
#!/usr/bin/env python
# vim: set fileencoding=utf-8 :