Commit a634e545 authored by Pavel KORSHUNOV

moved pad DB interface here, tests are passing

parent a215ab22
#from .utils import File, FileSet
# from bob.bio.base.database.Database import Database
from .DatabaseBobSpoof import DatabaseBobSpoof
from .database import PadDatabase
from .file import PadFile
# Export every name currently defined in this module that does not start with an underscore.
# gets sphinx autodoc done right - don't remove it
# NOTE: the comprehension variable ``_`` itself starts with an underscore, so it can never pass the filter.
__all__ = [_ for _ in dir() if not _.startswith('_')]
#!/usr/bin/env python
# vim: set fileencoding=utf-8 :
# @author: Manuel Guenther <Manuel.Guenther@idiap.ch>
# @author: Pavel Korshunov <pavel.korshunov@idiap.ch>
# @date: Tue May 17 12:09:22 CET 2016
#
import abc
import bob.bio.base.database
class PadDatabase(bob.bio.base.database.BioDatabase):
    """Base class of database interfaces for presentation attack detection (PAD) experiments.

    File lists are handled as a pair ``[real, attack]``; derived classes must implement :py:meth:`objects`.
    """

    def __init__(
            self,
            name,
            all_files_options=None,  # additional options for the database query that can be used to extract all files
            check_original_files_for_existence=False,
            original_directory=None,
            original_extension=None,
            protocol='Default',
            **kwargs  # The rest of the default parameters of the base class
    ):
        """This class represents the basic API for database access.
        Please use this class as a base class for your database access classes.
        Do not forget to call the constructor of this base class in your derived class.

        **Parameters:**

        name : str
          A unique name for the database.

        all_files_options : dict or ``None``
          Dictionary of options passed to the second-level database query when retrieving all data.
          ``None`` (the default) is treated as an empty dictionary.

        check_original_files_for_existence : bool
          Enables to test for the original data files when querying the database.

        original_directory : str
          The directory where the original data of the database are stored.

        original_extension : str
          The file name extension of the original data.

        protocol : str or ``None``
          The name of the protocol that defines the default experimental setup for this database.

        kwargs : ``key=value`` pairs
          The arguments of the :py:class:`bob.bio.base.BioDatabase` base class constructor.
        """
        # A literal ``{}`` default would be created once and shared by every instance
        # constructed without this argument; build a fresh dictionary per instance instead.
        if all_files_options is None:
            all_files_options = {}

        super(PadDatabase, self).__init__(
            name=name,
            all_files_options=all_files_options,
            check_original_files_for_existence=check_original_files_for_existence,
            original_directory=original_directory,
            original_extension=original_extension,
            protocol=protocol,
            **kwargs
        )

    def original_file_names(self, files):
        """original_file_names(files) -> paths

        Returns the full paths of the real and attack data of the given PadFile objects.

        **Parameters:**

        files : [[:py:class:`bob.pad.db.PadFile`], [:py:class:`bob.pad.db.PadFile`]]
          The list of lists ([real, attack]) of file object to retrieve the original data file names for.

        **Returns:**

        paths : [str]
          The paths extracted for the concatenated real+attack files, in the preserved order.
        """
        # Both settings are required to build a path; they default to None in the constructor.
        assert self.original_directory is not None
        assert self.original_extension is not None

        realfiles = files[0]
        attackfiles = files[1]

        realpaths = [
            pad_file.make_path(directory=self.original_directory, extension=self.original_extension)
            for pad_file in realfiles
        ]
        attackpaths = [
            pad_file.make_path(directory=self.original_directory, extension=self.original_extension)
            for pad_file in attackfiles
        ]
        return realpaths + attackpaths

    def model_ids_with_protocol(self, groups=None, protocol=None, **kwargs):
        """model_ids_with_protocol(groups = None, protocol = None, **kwargs) -> ids

        Client-based PAD is not implemented, so this always returns an empty list.
        """
        return []

    def annotations(self, file):
        """
        Annotations are not supported by PAD interface; always returns ``None``.
        """
        return None

    @abc.abstractmethod
    def objects(self, groups=None, protocol=None, purposes=None, model_ids=None, **kwargs):
        """This function returns lists of File objects, which fulfill the given restrictions.

        Keyword parameters:

        groups : str or [str]
          The groups of which the clients should be returned.
          Usually, groups are one or more elements of ('train', 'dev', 'eval')

        protocol
          The protocol for which the clients should be retrieved.
          The protocol is dependent on your database.
          If you do not have protocols defined, just ignore this field.

        purposes : str or [str]
          The purposes for which File objects should be retrieved.
          Usually it is either 'real' or 'attack'.

        model_ids : [various type]
          This parameter is not supported in PAD databases yet
        """
        raise NotImplementedError("This function must be implemented in your derived class.")

    #################################################################
    ######### Methods to provide common functionality ###############
    #################################################################

    def all_files(self, groups=('train', 'dev', 'eval')):
        """all_files(groups=('train', 'dev', 'eval')) -> files

        Returns all files of the database, respecting the current protocol.
        The files can be limited using the ``all_files_options`` in the constructor.

        **Parameters:**

        groups : some of ``('train', 'dev', 'eval')`` or ``None``
          The groups to get the data for.

        **Returns:**

        files : [[:py:class:`File`], [:py:class:`File`]]
          A two-element list ``[real, attack]``; each element is the result of ``self.sort``
          applied to the files queried with purpose ``'real'`` and ``'attack'`` respectively.
        """
        realset = self.sort(
            self.objects(protocol=self.protocol, groups=groups, purposes='real', **self.all_files_options))
        attackset = self.sort(
            self.objects(protocol=self.protocol, groups=groups, purposes='attack', **self.all_files_options))
        return [realset, attackset]

    def training_files(self, step=None, arrange_by_client=False):
        """training_files(step = None, arrange_by_client = False) -> files

        Returns all training File objects, i.e. :py:meth:`all_files` restricted to the ``'train'`` group.

        **Parameters:**

        The parameters are not applicable in this version of anti-spoofing experiments
        (``step`` and ``arrange_by_client`` are accepted but ignored).

        **Returns:**

        files : [[:py:class:`File`], [:py:class:`File`]]
          The ``[real, attack]`` lists of files used for the training.
        """
        return self.all_files(groups=('train',))
#!/usr/bin/env python
# vim: set fileencoding=utf-8 :
# @author: Pavel Korshunov <pavel.korshunov@idiap.ch>
# @date: Wed May 18 10:09:22 CET 2016
#
import bob.bio.base.database
class PadFile(bob.bio.base.database.BioFile):
    """A simple base class that defines basic properties of File object for the use in PAD experiments"""

    def __init__(self, client_id, path, attack_type=None, file_id=None):
        """**Constructor Documentation**

        Set up the file object from the minimum required data.

        Parameters:

        attack_type : a string type
          For spoofed data, names the kind of attack the sample represents.
          The default ``None`` means the PadFile is a genuine (real) sample.

        For client_id, path and file_id, please refer to :py:class:`bob.bio.base.BioFile` constructor
        """
        super(PadFile, self).__init__(client_id, path, file_id)

        # A genuine sample carries no attack type; any other value has to be a string.
        assert attack_type is None or isinstance(attack_type, str)

        # just copy the information
        self.attack_type = attack_type
        """The attack type of the sample, None if it is a genuine sample."""
#!/usr/bin/env python
# vim: set fileencoding=utf-8 :
# @author: Manuel Guenther <Manuel.Guenther@idiap.ch>
# @author: Pavel Korshunov <pavel.korshunov@idiap.ch>
# @date: Tue May 17 12:09:22 CET 2016
#
import os
import bob.io.base
import bob.io.base.test_utils
import bob.pad.base.database
import bob.db.base
from sqlalchemy import Column, Integer, String
from sqlalchemy.ext.declarative import declarative_base
# Flag for rebuilding the SQLite fixture; not read anywhere in the visible part of this module
# (NOTE(review): presumably toggled by hand before calling create_database() — confirm).
regenerate_database = False
# Location of the SQLite test database file "test_db.sql3", looked up via the bob.pad.base.test package.
dbfile = bob.io.base.test_utils.datafile("test_db.sql3", "bob.pad.base.test")
# SQLAlchemy declarative base; the table classes of this module derive from it.
Base = declarative_base()
class TestFileSql(Base, bob.pad.base.database.PadFile):
    """Dummy PAD file mapped to the ``file`` table; used for test purposes only."""

    __tablename__ = "file"

    # Table layout: integer primary key, plus unique client id and unique path.
    id = Column(Integer, primary_key=True)
    client_id = Column(Integer, unique=True)
    path = Column(String(100), unique=True)

    def __init__(self):
        """Create the single fixed test entry (client 5, path ``test/path``)."""
        # PadFile's constructor is invoked explicitly with hard-coded test values.
        bob.pad.base.database.PadFile.__init__(
            self,
            path="test/path",
            client_id=5,
        )
def create_database():
    """(Re)create the SQLite test database stored at ``dbfile``.

    Any existing file at ``dbfile`` is removed first. Then the tables declared on
    ``Base`` are created and a single :py:class:`TestFileSql` row is inserted and committed.
    """
    if os.path.exists(dbfile):
        os.remove(dbfile)

    import bob.db.base.utils
    engine = bob.db.base.utils.create_engine_try_nolock('sqlite', dbfile, echo=True)
    Base.metadata.create_all(engine)

    session = bob.db.base.utils.session('sqlite', dbfile, echo=True)
    try:
        session.add(TestFileSql())
        session.commit()
    finally:
        # Close the session even when add/commit fails, so it is never left open.
        session.close()

    # Explicitly drop the local references, as the original code did
    # (NOTE(review): presumably to release the SQLite file promptly — confirm).
    del session
    del engine
class TestDatabaseSql(bob.pad.base.database.PadDatabase, bob.db.base.SQLiteDatabase):
    """Dummy PAD database backed by the SQLite file ``dbfile``; used for test purposes only."""

    def __init__(self):
        """Initialise both base classes explicitly, the PAD interface first."""
        bob.pad.base.database.PadDatabase.__init__(
            self,
            'pad_test',
            original_extension=".orig",
            original_directory="original/directory",
        )
        bob.db.base.SQLiteDatabase.__init__(self, dbfile, TestFileSql)

    def groups(self, protocol=None):
        """Return the single hard-coded group name, whatever ``protocol`` is."""
        return ['group']

    def objects(self, groups=None, protocol=None, purposes=None, model_ids=None, **kwargs):
        """Return every :py:class:`TestFileSql` entry; all filter arguments are ignored."""
        stored_files = self.query(TestFileSql)
        return list(stored_files)
# Module-level instance of the dummy SQL database, created at import time.
database = TestDatabaseSql()
\ No newline at end of file
#!/usr/bin/env python
# vim: set fileencoding=utf-8 :
# @author: Manuel Guenther <Manuel.Guenther@idiap.ch>
# @author: Pavel Korshunov <pavel.korshunov@idiap.ch>
# @date: Tue May 17 12:09:22 CET 2016
#
import os
import shutil
import bob.io.base
import bob.io.base.test_utils
import bob.bio.base.database
import bob.pad.base.database
import bob.db.base
import tempfile
from sqlalchemy import Column, Integer, String
from sqlalchemy.ext.declarative import declarative_base
# Flag for rebuilding the SQLite fixture; in the visible code it is only referenced by the commented-out test below.
regenerate_database = False
# Location of the SQLite test database file "test_db.sql3", looked up via the bob.pad.base.test package.
dbfile = bob.io.base.test_utils.datafile("test_db.sql3", "bob.pad.base.test")
# SQLAlchemy declarative base; the table classes of this module derive from it.
Base = declarative_base()
class TestFile(Base, bob.pad.base.database.PadFile):
    """Dummy PAD file mapped to the ``file`` table; used for test purposes only."""

    __tablename__ = "file"

    # Table layout: integer primary key, plus unique client id and unique path.
    id = Column(Integer, primary_key=True)
    client_id = Column(Integer, unique=True)
    path = Column(String(100), unique=True)

    def __init__(self):
        """Create the single fixed test entry (client 5, path ``test/path``)."""
        # PadFile's constructor is invoked explicitly with hard-coded test values.
        bob.pad.base.database.PadFile.__init__(
            self,
            path="test/path",
            client_id=5,
        )
def create_database():
    """(Re)create the SQLite test database stored at ``dbfile``.

    Any existing file at ``dbfile`` is removed first. Then the tables declared on
    ``Base`` are created and a single :py:class:`TestFile` row is inserted and committed.
    """
    if os.path.exists(dbfile):
        os.remove(dbfile)

    import bob.db.base.utils
    engine = bob.db.base.utils.create_engine_try_nolock('sqlite', dbfile, echo=True)
    Base.metadata.create_all(engine)

    session = bob.db.base.utils.session('sqlite', dbfile, echo=True)
    try:
        session.add(TestFile())
        session.commit()
    finally:
        # Close the session even when add/commit fails, so it is never left open.
        session.close()

    # Explicitly drop the local references, as the original code did
    # (NOTE(review): presumably to release the SQLite file promptly — confirm).
    del session
    del engine
class TestDatabase(bob.pad.base.database.PadDatabase, bob.db.base.SQLiteDatabase):
    """Dummy PAD database backed by the SQLite file ``dbfile``; used for test purposes only."""

    def __init__(self):
        """Initialise both base classes explicitly, the PAD interface first."""
        bob.pad.base.database.PadDatabase.__init__(
            self,
            'pad_test',
            original_extension=".orig",
            original_directory="original/directory",
        )
        bob.db.base.SQLiteDatabase.__init__(self, dbfile, TestFile)

    def groups(self, protocol=None):
        """Return the single hard-coded group name, whatever ``protocol`` is."""
        return ['group']

    def objects(self, groups=None, protocol=None, purposes=None, model_ids=None, **kwargs):
        """Return every :py:class:`TestFile` entry; all filter arguments are ignored."""
        stored_files = self.query(TestFile)
        return list(stored_files)
# def test01_database():
# # check that the database API works
# if regenerate_database:
# create_database()
#
# db = TestDatabase()
#
# def check_file(fs, l=1):
# assert len(fs) == l
# if l == 1:
# f = fs[0]
# else:
# f = fs[0][0]
# assert isinstance(f, TestFile)
# assert f.id == 1
# assert f.client_id == 5
# assert f.path == "test/path"
#
# check_file(db.objects())
# check_file(db.all_files(), 2)
# check_file(db.training_files(), 2)
# check_file(db.files([1]))
# check_file(db.reverse(["test/path"]))
#
# file = db.objects()[0]
# assert db.original_file_name(file) == "original/directory/test/path.orig"
# assert db.file_names([file], "another/directory", ".other")[0] == "another/directory/test/path.other"
# assert db.paths([1], "another/directory", ".other")[0] == "another/directory/test/path.other"
#
# # try file save
# temp_dir = tempfile.mkdtemp(prefix="bob_db_test_")
# data = [1., 2., 3.]
# file.save(data, temp_dir)
# assert os.path.exists(file.make_path(temp_dir, ".hdf5"))
# read_data = bob.io.base.load(file.make_path(temp_dir, ".hdf5"))
# for i in range(3):
# assert data[i] == read_data[i]
# shutil.rmtree(temp_dir)
......@@ -21,11 +21,58 @@
import os
import unittest
import bob.pad.base
from bob.pad.base.test.dummy.database_sql import create_database
import pkg_resources
import tempfile
import shutil
# Filesystem location of the 'test/dummy' resource inside the bob.pad.base package.
dummy_dir = pkg_resources.resource_filename('bob.pad.base', 'test/dummy')
# When set to True, test01_database rebuilds the SQLite fixture via create_database() before running.
regenerate_database = False
class DummyDatabaseSqlTest(unittest.TestCase):
    """Checks the PAD database API against the dummy SQLite-backed test database."""

    def test01_database(self):
        """Query the dummy database through the API and round-trip one file to disk."""
        # check that the database API works
        if regenerate_database:
            create_database()

        db = bob.pad.base.test.dummy.database_sql.TestDatabaseSql()

        def check_file(fs, expected_length=1):
            """Assert that ``fs`` has the expected length and starts with the known test entry."""
            assert len(fs) == expected_length
            # A flat list holds the file directly; otherwise the result is treated as
            # nested lists (e.g. [real, attack]) and the first file of the first list is taken.
            f = fs[0] if expected_length == 1 else fs[0][0]
            assert isinstance(f, bob.pad.base.test.dummy.database_sql.TestFileSql)
            assert f.id == 1
            assert f.client_id == 5
            assert f.path == "test/path"

        check_file(db.objects())
        check_file(db.all_files(), 2)
        check_file(db.training_files(), 2)
        check_file(db.files([1]))
        check_file(db.reverse(["test/path"]))

        sample = db.objects()[0]
        assert db.original_file_name(sample) == "original/directory/test/path.orig"
        assert db.file_names([sample], "another/directory", ".other")[0] == "another/directory/test/path.other"
        assert db.paths([1], "another/directory", ".other")[0] == "another/directory/test/path.other"

        # try file save
        temp_dir = tempfile.mkdtemp(prefix="bob_db_test_")
        try:
            data = [1., 2., 3.]
            sample.save(data, temp_dir)
            saved_path = sample.make_path(temp_dir, ".hdf5")
            assert os.path.exists(saved_path)
            read_data = bob.io.base.load(saved_path)
            for index, expected in enumerate(data):
                assert expected == read_data[index]
        finally:
            # Remove the temporary directory even when one of the assertions above fails.
            shutil.rmtree(temp_dir)
class DummyDatabaseTest(unittest.TestCase):
"""Performs various tests on the AVspoof attack database."""
......
......@@ -72,6 +72,7 @@ def _detect(parameters, cur_test_dir, sub_dir, score_types=('dev-real',), scores
assert numpy.allclose(data2check[0][:, 3].astype(float), data2check[1][:, 3].astype(float), 1e-5)
finally:
# print ("empty")
shutil.rmtree(cur_test_dir)
......
; vim: set fileencoding=utf-8 :
; Pavel Korshunov <Pavel.Korshunov@idiap.ch>
; Wed 19 Aug 13:43:22 2015
[buildout]
parts = scripts
eggs = bob.pad.base
gridtk
extensions = bob.buildout
mr.developer
auto-checkout = *
develop = src/bob.db.base
src/bob.bio.base
.
; options for bob.buildout
debug = true
verbose = true
newest = false
[sources]
bob.db.base = git git@gitlab.idiap.ch:bob/bob.db.base.git
bob.bio.base = git git@gitlab.idiap.ch:bob/bob.bio.base.git
[scripts]
recipe = bob.buildout:scripts
dependent-scripts = true
......@@ -113,6 +113,7 @@ setup(
'bob.pad.database': [
'dummy = bob.pad.base.test.dummy.database:database', # for test purposes only
'dummysql = bob.pad.base.test.dummy.database_sql:database', # for test purposes only
],
'bob.pad.preprocessor': [
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment