Skip to content
Snippets Groups Projects

Compare revisions

Changes are shown as if the source revision was being merged into the target revision. Learn more about comparing revisions.

Source

Select target project
No results found

Target

Select target project
  • bob/bob.pad.base
1 result
Show changes
Commits on Source (9)
......@@ -9,7 +9,7 @@ import os
from bob.bio.base import utils
class Algorithm:
class Algorithm(object):
"""This is the base class for all anti-spoofing algorithms.
It defines the minimum requirements for all derived algorithm classes.
......
......@@ -130,24 +130,29 @@ class PadDatabase(BioDatabase):
######### Methods to provide common functionality ###############
#################################################################
def all_files(self, groups=('train', 'dev', 'eval')):
"""all_files(groups=('train', 'dev', 'eval')) -> files
Returns all files of the database, respecting the current protocol.
The files can be limited using the ``all_files_options`` in the constructor.
**Parameters:**
groups : some of ``('train', 'dev', 'eval')`` or ``None``
The groups to get the data for.
**Returns:**
def all_files(self, groups=('train', 'dev', 'eval'), flat=False):
"""Returns all files of the database, respecting the current protocol.
The files can be limited using the ``all_files_options`` in the
constructor.
Parameters
----------
groups : str or tuple or None
The groups to get the data for. it should be some of ``('train',
'dev', 'eval')`` or ``None``
flat : bool
if True, it will merge the real and attack files into one list.
Returns
-------
files : [:py:class:`bob.pad.base.database.PadFile`]
The sorted and unique list of all files of the database.
The sorted and unique list of all files of the database.
"""
realset = self.sort(self.objects(protocol=self.protocol, groups=groups, purposes='real', **self.all_files_options))
attackset = self.sort(self.objects(protocol=self.protocol, groups=groups, purposes='attack', **self.all_files_options))
if flat:
return realset + attackset
return [realset, attackset]
def training_files(self, step=None, arrange_by_client=False):
......
......@@ -104,12 +104,7 @@ class FileListPadDatabase(PadDatabase, FileListBioDatabase):
keep_read_lists_in_memory=True,
**kwargs
):
"""We call PadDatabase.__init__() instead of super() because of we do not want
bob.bio.base.database.FileListBioDatabase.__init__() to be called by super().
bob.bio.base.database.FileListBioDatabase depends on bob.bio.base.database.ZTBioDatabase, which would
throw an exception, since we do not implement here methods for ZT-based metric."""
PadDatabase.__init__(self,
super(FileListPadDatabase, self).__init__(
name=name,
protocol=protocol,
original_directory=original_directory,
......@@ -118,13 +113,15 @@ class FileListPadDatabase(PadDatabase, FileListBioDatabase):
annotation_extension=annotation_extension,
annotation_type=annotation_type,
filelists_directory=filelists_directory,
# extra args for pretty printing
train_sub_directory=train_subdir,
dev_sub_directory=dev_subdir,
eval_sub_directory=eval_subdir,
real_filename=real_filename,
attack_filename=attack_filename,
**kwargs)
# extra args for pretty printing
self._kwargs.update(dict(
train_sub_directory=train_subdir,
real_filename=real_filename,
attack_filename=attack_filename,
))
self.pad_file_class = pad_file_class
self.list_readers = {}
......
......@@ -28,14 +28,16 @@ class TestFileSql (Base, bob.pad.base.database.PadFile):
path = Column(String(100), unique=True)
def __init__(self):
bob.pad.base.database.PadFile.__init__(self, client_id=5, path="test/path")
bob.pad.base.database.PadFile.__init__(
self, client_id=5, path="test/path")
def create_database():
if os.path.exists(dbfile):
os.remove(dbfile)
import bob.db.base.utils
engine = bob.db.base.utils.create_engine_try_nolock('sqlite', dbfile, echo=True)
engine = bob.db.base.utils.create_engine_try_nolock(
'sqlite', dbfile, echo=True)
Base.metadata.create_all(engine)
session = bob.db.base.utils.session('sqlite', dbfile, echo=True)
session.add(TestFileSql())
......@@ -48,9 +50,11 @@ def create_database():
class TestDatabaseSql (bob.pad.base.database.PadDatabase, bob.db.base.SQLiteBaseDatabase):
def __init__(self):
bob.pad.base.database.PadDatabase.__init__(self, 'pad_test',
original_directory="original/directory", original_extension=".orig")
bob.db.base.SQLiteBaseDatabase.__init__(self, dbfile, TestFileSql)
super(TestDatabaseSql, self).__init__(
name='pad_test',
original_directory="original/directory",
original_extension=".orig", sqlite_file=dbfile,
file_class=TestFileSql)
def groups(self, protocol=None):
return ['group']
......@@ -61,4 +65,5 @@ class TestDatabaseSql (bob.pad.base.database.PadDatabase, bob.db.base.SQLiteBase
def annotations(self, file):
return None
database = TestDatabaseSql()
\ No newline at end of file
database = TestDatabaseSql()
......@@ -43,6 +43,10 @@ class DummyDatabaseSqlTest(unittest.TestCase):
check_file(db.training_files(), 2)
check_file(db.files([1]))
check_file(db.reverse(["test/path"]))
# check if flat returns flat files
assert len(db.all_files(flat=True)) == 2, db.all_files(flat=True)
check_file(db.all_files(flat=True)[0:1])
check_file(db.all_files(flat=True)[1:2])
file = db.objects()[0]
assert db.original_file_name(file) == "original/directory/test/path.orig"
......