Commit c1227609 authored by Amir MOHAMMADI

Implement a better file csv interface

parent 932cdf39
from .file import PadFile
from .database import PadDatabase
from .filelist.query import FileListPadDatabase
from .filelist.models import Client
from .csv_dataset import FileListPadDatabase
from .PadBioFileDB import HighBioDatabase, HighPadDatabase
from .csv_dataset import CSVPADDataset, CSVToSampleLoader, LSTToSampleLoader
# gets sphinx autodoc done right - don't remove it
def __appropriate__(*args):
@@ -25,11 +23,7 @@ __appropriate__(
PadFile,
PadDatabase,
FileListPadDatabase,
Client,
HighBioDatabase,
HighPadDatabase,
CSVPADDataset,
CSVToSampleLoader,
LSTToSampleLoader,
)
__all__ = [_ for _ in dir() if not _.startswith("_")]
@@ -2,324 +2,73 @@
# vim: set fileencoding=utf-8 :
from bob.db.base.utils import check_parameters_for_validity
from bob.pad.base.pipelines.vanilla_pad.abstract_classes import Database
import csv
from bob.pipelines.datasets import FileListDatabase, CSVToSamples
from bob.pipelines.datasets.sample_loaders import CSVBaseSampleLoader
from bob.extension.download import search_file
class CSVToPADSamples(CSVToSamples):
"""Converts a csv file to a list of PAD samples"""
from bob.pipelines import DelayedSample
import bob.io.base
import os
import functools
class CSVToSampleLoader(CSVBaseSampleLoader):
"""
Simple mechanism that converts the lines of a CSV file to
:any:`bob.pipelines.DelayedSample` or :any:`bob.pipelines.SampleSet`
Each CSV line of a PAD dataset should have at least a PATH and a SUBJECT id, like
in the example below:
.. code-block:: text
PATH,SUBJECT
path_1,subject_1
"""
def check_header(self, header):
"""
A header should have at least the "PATH" and "SUBJECT" fields.
"""
header = [h.lower() for h in header]
if not "subject" in header:
raise ValueError("The field `subject` is not available in your dataset.")
if not "path" in header:
raise ValueError("The field `path` is not available in your dataset.")
def __call__(self, f, is_bonafide=True):
f.seek(0)
reader = csv.reader(f)
header = next(reader)
self.check_header(header)
return [
self.convert_row_to_sample(row, header, is_bonafide=is_bonafide)
for row in reader
]
def convert_row_to_sample(self, row, header=None, is_bonafide=True):
path = str(row[0])
subject = str(row[1])
kwargs = {str(h).lower(): r for h, r in zip(header[2:], row[2:])}
if self.metadata_loader is not None:
metadata = self.metadata_loader(row, header=header)
kwargs.update(metadata)
return DelayedSample(
functools.partial(
self.data_loader,
os.path.join(self.dataset_original_directory, path + self.extension),
),
key=path,
subject=subject,
is_bonafide=is_bonafide,
**kwargs,
)
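# Illustrative usage sketch (not part of this commit): it assumes a hypothetical
# `for_real.csv` protocol file with the PATH,SUBJECT header documented above, and
# a hypothetical raw-data directory and extension.
def _example_csv_loader_usage():
    loader = CSVToSampleLoader(
        data_loader=bob.io.base.load,
        metadata_loader=None,
        dataset_original_directory="/path/to/raw/data",  # hypothetical
        extension=".png",  # hypothetical
    )
    with open("for_real.csv") as f:  # hypothetical protocol file
        real_samples = loader(f, is_bonafide=True)
    # Each entry is a DelayedSample; raw data is only read when `.data` is accessed.
    return [(s.key, s.subject, s.is_bonafide) for s in real_samples]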
class LSTToSampleLoader(CSVBaseSampleLoader):
"""
Simple mechanism that converts the lines of a LST file to
:any:`bob.pipelines.DelayedSample` or :any:`bob.pipelines.SampleSet`
"""
def __call__(self, f, is_bonafide=True):
f.seek(0)
reader = csv.reader(f, delimiter=" ")
samples = []
for row in reader:
if row[0][0] == "#":
continue
samples.append(self.convert_row_to_sample(row, is_bonafide=is_bonafide))
return samples
def convert_row_to_sample(self, row, header=None, is_bonafide=True):
path = str(row[0])
subject = str(row[1])
attack_type = None
if len(row) == 3:
attack_type = str(row[2])
kwargs = dict()
if self.metadata_loader is not None:
metadata = self.metadata_loader(row, header=header)
kwargs.update(metadata)
return DelayedSample(
functools.partial(
self.data_loader,
os.path.join(self.dataset_original_directory, path + self.extension),
),
key=path,
subject=subject,
is_bonafide=is_bonafide,
attack_type=attack_type,
**kwargs,
)
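# Illustrative usage sketch (not part of this commit): legacy `.lst` protocol
# files are space-delimited (`<path> <subject> [<attack_type>]`) and lines whose
# first field starts with `#` are skipped. File names and paths are hypothetical.
def _example_lst_loader_usage():
    loader = LSTToSampleLoader(
        data_loader=bob.io.base.load,
        metadata_loader=None,
        dataset_original_directory="/path/to/raw/data",  # hypothetical
        extension=".png",  # hypothetical
    )
    with open("for_attack.lst") as f:  # hypothetical protocol file
        attack_samples = loader(f, is_bonafide=False)
    # The optional third column is exposed as `attack_type` on each sample.
    return [(s.key, s.subject, s.attack_type) for s in attack_samples]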
class CSVPADDataset(Database):
"""
Generic filelist dataset for PAD experiments.
To create a new dataset, you need to provide a directory structure similar to the one below:
.. code-block:: text
my_dataset/
my_dataset/my_protocol/train/for_real.csv
my_dataset/my_protocol/train/for_attack.csv
my_dataset/my_protocol/dev/for_real.csv
my_dataset/my_protocol/dev/for_attack.csv
my_dataset/my_protocol/eval/for_real.csv
my_dataset/my_protocol/eval/for_attack.csv
These CSV files should contain in each row: (i) the path to the raw data and
(ii) an identifier of the subject in the image (subject).
The structure of each CSV file should be as below:
.. code-block:: text
PATH,SUBJECT
path_1,subject_1
path_2,subject_2
path_i,subject_j
...
You might want to ship metadata within your samples (e.g., gender, age, annotations, ...).
Doing so is simple; just add extra columns as in the example below:
.. code-block:: text
PATH,SUBJECT,TYPE_OF_ATTACK,GENDER,AGE
path_1,subject_1,A,B,C
path_2,subject_2,A,B,1
path_i,subject_j,2,3,4
...
The files `my_dataset/my_protocol/eval/for_real.csv` and `my_dataset/my_protocol/eval/for_attack.csv`
are optional and are used in case a protocol contains data for evaluation.
Finally, the files `my_dataset/my_protocol/train/for_real.csv` and `my_dataset/my_protocol/train/for_attack.csv` are used in case a protocol
contains data for training.
Parameters
----------
dataset_protocol_path: str
Absolute path to the dataset protocol description (a directory or a tarball).
protocol_name: str
The name of the protocol
def __iter__(self):
for sample in super().__iter__():
if not hasattr(sample, "subject"):
raise RuntimeError(
"PAD samples should contain a `subject` attribute which "
"reveals the identifies the person from whom the sample is created."
)
if not hasattr(sample, "attack_type"):
raise RuntimeError(
"PAD samples should contain a `attack_type` attribute which "
"should be '' for bona fide samples and something like "
"print, replay, mask, etc. for attacks. This attribute is "
"considered the PAI type of each attack is used to compute APCER."
)
if sample.attack_type == "":
sample.attack_type = None
sample.is_bonafide = sample.attack_type is None
if not hasattr(sample, "key"):
sample.key = sample.filename
yield sample
csv_to_sample_loader: :any:`bob.bio.base.database.CSVBaseSampleLoader`
Base class whose objective is to generate :any:`bob.pipelines.Sample`
and/or :any:`bob.pipelines.SampleSet` from csv rows
"""
class FileListPadDatabase(Database, FileListDatabase):
"""A PAD database interface from CSV files."""
def __init__(
self,
dataset_protocol_path,
protocol_name,
csv_to_sample_loader=CSVToSampleLoader(
data_loader=bob.io.base.load,
metadata_loader=None,
dataset_original_directory="",
extension="",
),
dataset_protocols_path,
protocol,
transformer=None,
**kwargs,
):
self.dataset_protocol_path = dataset_protocol_path
self.protocol_name = protocol_name
def get_paths():
if not os.path.exists(dataset_protocol_path):
raise ValueError(f"The path `{dataset_protocol_path}` was not found")
# Here we handle both the legacy .lst lists and the new .csv lists
train_real_csv = search_file(
dataset_protocol_path,
[
os.path.join(protocol_name, "train", "for_real.lst"),
os.path.join(protocol_name, "train", "for_real.csv"),
],
)
train_attack_csv = search_file(
dataset_protocol_path,
[
os.path.join(protocol_name, "train", "for_attack.lst"),
os.path.join(protocol_name, "train", "for_attack.csv"),
],
)
dev_real_csv = search_file(
dataset_protocol_path,
[
os.path.join(protocol_name, "dev", "for_real.lst"),
os.path.join(protocol_name, "dev", "for_real.csv"),
],
)
dev_attack_csv = search_file(
dataset_protocol_path,
[
os.path.join(protocol_name, "dev", "for_attack.lst"),
os.path.join(protocol_name, "dev", "for_attack.csv"),
],
)
eval_real_csv = search_file(
dataset_protocol_path,
[
os.path.join(protocol_name, "eval", "for_real.lst"),
os.path.join(protocol_name, "eval", "for_real.csv"),
],
)
eval_attack_csv = search_file(
dataset_protocol_path,
[
os.path.join(protocol_name, "eval", "for_attack.lst"),
os.path.join(protocol_name, "eval", "for_attack.csv"),
],
)
super().__init__(
dataset_protocols_path=dataset_protocols_path,
protocol=protocol,
reader_cls=CSVToPADSamples,
transformer=transformer,
**kwargs,
)
# The minimum required is to have `dev_real_csv` and `dev_attack_csv`
def purposes(self):
return ("real", "attack")
# Dev
if dev_real_csv is None:
raise ValueError(
f"The file `{dev_real_csv}` is required and it was not found"
)
if dev_attack_csv is None:
raise ValueError(
f"The file `{dev_attack_csv}` is required and it was not found"
)
def samples(self, groups=None, purposes=None):
results = super().samples(groups=groups)
purposes = check_parameters_for_validity(
purposes, "purposes", self.purposes(), self.purposes()
)
return (
train_real_csv,
train_attack_csv,
dev_real_csv,
dev_attack_csv,
eval_real_csv,
eval_attack_csv,
def _filter(s):
return (s.is_bonafide and "real" in purposes) or (
(not s.is_bonafide) and "attack" in purposes
)
(
self.train_real_csv,
self.train_attack_csv,
self.dev_real_csv,
self.dev_attack_csv,
self.eval_real_csv,
self.eval_attack_csv,
) = get_paths()
def get_dict_cache():
cache = dict()
cache["train_real_csv"] = None
cache["train_attack_csv"] = None
cache["dev_real_csv"] = None
cache["dev_attack_csv"] = None
cache["eval_real_csv"] = None
cache["eval_attack_csv"] = None
return cache
self.cache = get_dict_cache()
self.csv_to_sample_loader = csv_to_sample_loader
def _load_samples(self, cache_key, filepointer, is_bonafide):
self.cache[cache_key] = (
self.csv_to_sample_loader(filepointer, is_bonafide)
if self.cache[cache_key] is None
else self.cache[cache_key]
)
return self.cache[cache_key]
results = list(filter(_filter, results))
return results
def fit_samples(self):
return self._load_samples(
"train_real_csv", self.train_real_csv, is_bonafide=True
) + self._load_samples(
"train_attack_csv", self.train_attack_csv, is_bonafide=False
)
return self.samples(groups="train")
def predict_samples(self, group="dev"):
if group == "dev":
return self._load_samples(
"dev_real_csv", self.dev_real_csv, is_bonafide=True
) + self._load_samples(
"dev_attack_csv", self.dev_attack_csv, is_bonafide=False
)
else:
return self._load_samples(
"eval_real_csv", self.eval_real_csv, is_bonafide=True
) + self._load_samples(
"eval_attack_csv", self.eval_attack_csv, is_bonafide=False
)
return self.samples(groups=group)
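# Illustrative usage sketch (not part of this commit): the new CSV-based
# FileListPadDatabase exposes group/purpose filtering as well as the fit/predict
# split expected by vanilla-pad pipelines. The protocol path is hypothetical.
def _example_filelist_pad_database_usage():
    db = FileListPadDatabase(
        dataset_protocols_path="/path/to/my_dataset",  # hypothetical
        protocol="my_protocol",  # hypothetical
    )
    dev_real = db.samples(groups="dev", purposes="real")  # bona fide dev samples only
    train = db.fit_samples()  # train/for_real + train/for_attack lists
    eval_all = db.predict_samples(group="eval")  # bona fide + attacks for evaluation
    return dev_real, train, eval_all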
from .models import ListReader, Client, FileListFile
from .query import FileListPadDatabase
from .driver import Interface
# gets sphinx autodoc done right - don't remove it
def __appropriate__(*args):
"""Says object was actually declared here, and not in the import module.
Fixing sphinx warnings of not being able to find classes, when path is shortened.
Parameters:
*args: An iterable of objects to modify
Resolves `Sphinx referencing issues
<https://github.com/sphinx-doc/sphinx/issues/3048>`__
"""
for obj in args:
obj.__module__ = __name__
__appropriate__(
ListReader,
Client,
FileListFile,
FileListPadDatabase,
Interface,
)
__all__ = [_ for _ in dir() if not _.startswith('_')]
#!/usr/bin/env python
# vim: set fileencoding=utf-8 :
# Laurent El Shafey <laurent.el-shafey@idiap.ch>
#
# Copyright (C) 2011-2013 Idiap Research Institute, Martigny, Switzerland
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, version 3 of the License.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
"""Commands the PAD Filelists database can respond to.
"""
import os
import sys
from bob.db.base.driver import Interface as BaseInterface
def clients(args):
"""Dumps lists of client IDs based on your criteria"""
from .query import FileListPadDatabase
db = FileListPadDatabase(args.list_directory, 'pad_filelist')
client_ids = db.client_ids(
protocol=args.protocol,
groups=args.group
)
output = sys.stdout
if args.selftest:
from bob.db.base.utils import null
output = null()
for client in client_ids:
output.write('%s\n' % client)
return 0
def dumplist(args):
"""Dumps lists of files based on your criteria"""
from .query import FileListPadDatabase
db = FileListPadDatabase(args.list_directory, 'pad_filelist')
file_objects = db.objects(
purposes=args.purpose,
groups=args.group,
protocol=args.protocol
)
output = sys.stdout
if args.selftest:
from bob.db.base.utils import null
output = null()
for file_obj in file_objects:
output.write('%s\n' % file_obj.make_path(directory=args.directory, extension=args.extension))
return 0
def checkfiles(args):
"""Checks existence of files based on your criteria"""
from .query import FileListPadDatabase
db = FileListPadDatabase(args.list_directory, 'pad_filelist')
file_objects = db.objects(protocol=args.protocol)
# go through all files, check if they are available on the filesystem
good = []
bad = []
for file_obj in file_objects:
if os.path.exists(file_obj.make_path(args.directory, args.extension)):
good.append(file_obj)
else:
bad.append(file_obj)
# report
output = sys.stdout
if args.selftest:
from bob.db.base.utils import null
output = null()
if bad:
for file_obj in bad:
output.write('Cannot find file "%s"\n' % file_obj.make_path(args.directory, args.extension))
output.write('%d files (out of %d) were not found at "%s"\n' % (len(bad), len(file_objects), args.directory))
return 0
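# Illustrative sketch (not part of this commit): the commands above expect an
# argparse-style namespace such as the one built by `add_commands` below; the
# list directory is hypothetical and `selftest=True` silences the output.
def _example_dumplist_call():
    import argparse
    args = argparse.Namespace(
        list_directory="/path/to/file/lists",  # hypothetical
        directory="",
        extension="",
        purpose=None,
        group=None,
        protocol=None,
        selftest=True,  # write to a null stream instead of stdout
    )
    return dumplist(args)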
class Interface(BaseInterface):
def name(self):
return 'pad_filelist'
def version(self):
import pkg_resources # part of setuptools
return pkg_resources.require('bob.pad.base')[0].version
def files(self):
return ()
def type(self):
return 'text'
def add_commands(self, parser):
from . import __doc__ as docs
subparsers = self.setup_parser(parser,
"Presentation Attack Detection File Lists database", docs)
import argparse
# the "clients" action
parser = subparsers.add_parser('clients', help=clients.__doc__)
parser.add_argument('-l', '--list-directory', required=True,
help="The directory which contains the file lists.")
parser.add_argument('-g', '--group',
help="if given, this value will limit the output files to those belonging to a "
"particular group.",
choices=('dev', 'eval', 'train', ''))
parser.add_argument('-p', '--protocol', default=None,
help="If set, the protocol is appended to the directory that contains the file lists.")
parser.add_argument('--self-test', dest="selftest", action='store_true', help=argparse.SUPPRESS)
parser.set_defaults(func=clients) # action
# the "dumplist" action
parser = subparsers.add_parser('dumplist', help=dumplist.__doc__)
parser.add_argument('-l', '--list-directory', required=True,
help="The directory which contains the file lists.")
parser.add_argument('-d', '--directory', default='',
help="if given, this path will be prepended to every entry returned.")
parser.add_argument('-e', '--extension', default='',
help="if given, this extension will be appended to every entry returned.")
parser.add_argument('-u', '--purpose',
help="if given, this value will limit the output files to those designed "
"for the given purposes.",
choices=('real', 'attack', ''))
parser.add_argument('-g', '--group',
help="if given, this value will limit the output files to those belonging to a "
"particular group.",
choices=('dev', 'eval', 'train', ''))
parser.add_argument('-p', '--protocol', default=None,
help="If set, the protocol is appended to the directory that contains the file lists.")
parser.add_argument('--self-test', dest="selftest", action='store_true', help=argparse.SUPPRESS)
parser.set_defaults(func=dumplist) # action
# the "checkfiles" action
parser = subparsers.add_parser('checkfiles', help=checkfiles.__doc__)
parser.add_argument('-l', '--list-directory', required=True,
help="The directory which contains the file lists.")
parser.add_argument('-d', '--directory', dest="directory", default='',
help="if given, this path will be prepended to every entry returned.")
parser.add_argument('-e', '--extension', dest="extension", default='',
help="if given, this extension will be appended to every entry returned.")
parser.add_argument('-p', '--protocol', default=None,
help="If set, the protocol is appended to the directory that contains the file lists.")
parser.add_argument('--self-test', dest="selftest", action='store_true', help=argparse.SUPPRESS)
parser.set_defaults(func=checkfiles) # action
#!/usr/bin/env python
# vim: set fileencoding=utf-8 :
# @author: Pavel Korshunov <pavel.korshunov@idiap.ch>
# @date: Thu Nov 17 16:09:22 CET 2016
#
# Copyright (C) 2011-2013 Idiap Research Institute, Martigny, Switzerland
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, version 3 of the License.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
"""
This file defines simple Client and File interfaces that are comparable with other bob.db databases.
"""
import os
import fileinput
import re
from bob.pad.base.database import PadFile
class Client(object):
"""
The clients of this database contain ONLY client ids. Nothing special.
"""
def __init__(self, client_id):
self.id = client_id
"""The ID of the client, which is stored as a :py:class:`str` object."""
class FileListFile(PadFile):
"""
Initialize the File object with the minimum required data.
**Parameters**
path : str
The path of this file, relative to the basic directory.
Please do not specify any file extensions.
This path will be used as an underlying file_id, as it is assumed to be unique.
client_id : various type
The id of the client this file belongs to.
Its type depends on your implementation.
If you use an SQL database, this should be an SQL type like Integer or String.
"""
def __init__(self, file_name, client_id, attack_type=None):
super(FileListFile, self).__init__(client_id=client_id, path=file_name, attack_type=attack_type, file_id=file_name)
#############################################################################
### internal access functions for the file lists; do not export!
#############################################################################
class ListReader(object):
def __init__(self, store_lists):
self.m_read_lists = {}
self.m_store_lists = store_lists
def _read_multi_column_list(self, list_file):
rows = []
if not os.path.isfile(list_file):
raise RuntimeError('File %s does not exist.' % (list_file,))
try:
for line in fileinput.input(list_file):
parsed_line = re.findall(r'[\w/(-.)]+', line)
if len(parsed_line):
# perform some sanity checks
if len(parsed_line) not in (2, 3):
raise IOError("The read line '%s' from file '%s' could not be parsed successfully!" %
(line.rstrip(), list_file))
if len(rows) and len(rows[0]) != len(parsed_line):
raise IOError("The parsed line '%s' from file '%s' has a different number of elements "
"than the first parsed line '%s'!" % (parsed_line, list_file, rows[0]))
# append the read line
rows.append(parsed_line)
fileinput.close()
except IOError as e:
raise RuntimeError("Error reading the file '%s' : '%s'." % (list_file, e))
# return the read list as a list of rows
return rows
def _read_column_list(self, list_file, column_count):
# read the list
rows = self._read_multi_column_list(list_file)
# build the File objects from the parsed columns
file_list = []
for row in rows:
if column_count == 2:
assert len(row) == 2
# we expect: filename client_id
file_list.append(FileListFile(file_name=row[0], client_id=row[1]))
elif column_count == 3:
assert len(row) == 3