diff --git a/bob/pad/base/database/database.py b/bob/pad/base/database/database.py index d05d3e40822a296538a99b37001ffc580e436d1b..807376baea6f0a64acc5241de5750fa0a859821b 100644 --- a/bob/pad/base/database/database.py +++ b/bob/pad/base/database/database.py @@ -130,24 +130,29 @@ class PadDatabase(BioDatabase): ######### Methods to provide common functionality ############### ################################################################# - def all_files(self, groups=('train', 'dev', 'eval')): - """all_files(groups=('train', 'dev', 'eval')) -> files - - Returns all files of the database, respecting the current protocol. - The files can be limited using the ``all_files_options`` in the constructor. - - **Parameters:** - - groups : some of ``('train', 'dev', 'eval')`` or ``None`` - The groups to get the data for. - - **Returns:** - + def all_files(self, groups=('train', 'dev', 'eval'), flat=False): + """Returns all files of the database, respecting the current protocol. + The files can be limited using the ``all_files_options`` in the + constructor. + + Parameters + ---------- + groups : str or tuple or None + The groups to get the data for. it should be some of ``('train', + 'dev', 'eval')`` or ``None`` + + flat : bool + if True, it will merge the real and attack files into one list. + + Returns + ------- files : [:py:class:`bob.pad.base.database.PadFile`] - The sorted and unique list of all files of the database. + The sorted and unique list of all files of the database. """ realset = self.sort(self.objects(protocol=self.protocol, groups=groups, purposes='real', **self.all_files_options)) attackset = self.sort(self.objects(protocol=self.protocol, groups=groups, purposes='attack', **self.all_files_options)) + if flat: + return realset + attackset return [realset, attackset] def training_files(self, step=None, arrange_by_client=False): diff --git a/bob/pad/base/test/test_databases.py b/bob/pad/base/test/test_databases.py index 3653eb5a9ccdfbf58670cc9641de32e250ae662d..703c98703909b9f8ba2c3b35e9b319b1bbfb25f6 100644 --- a/bob/pad/base/test/test_databases.py +++ b/bob/pad/base/test/test_databases.py @@ -43,6 +43,10 @@ class DummyDatabaseSqlTest(unittest.TestCase): check_file(db.training_files(), 2) check_file(db.files([1])) check_file(db.reverse(["test/path"])) + # check if flat returns flat files + assert len(db.all_files(flat=True)) == 2, db.all_files(flat=True) + check_file(db.all_files(flat=True)[0:1]) + check_file(db.all_files(flat=True)[1:2]) file = db.objects()[0] assert db.original_file_name(file) == "original/directory/test/path.orig"