Commit fbc33ead authored by Pavel KORSHUNOV's avatar Pavel KORSHUNOV

Merge branch 'flat' into 'master'

Add an option to return a flat list

See merge request !23
parents 45fed636 5409920b
Pipeline #14450 passed with stages
in 7 minutes and 56 seconds
......@@ -130,24 +130,29 @@ class PadDatabase(BioDatabase):
######### Methods to provide common functionality ###############
#################################################################
def all_files(self, groups=('train', 'dev', 'eval')):
"""all_files(groups=('train', 'dev', 'eval')) -> files
Returns all files of the database, respecting the current protocol.
The files can be limited using the ``all_files_options`` in the constructor.
**Parameters:**
groups : some of ``('train', 'dev', 'eval')`` or ``None``
The groups to get the data for.
**Returns:**
def all_files(self, groups=('train', 'dev', 'eval'), flat=False):
"""Returns all files of the database, respecting the current protocol.
The files can be limited using the ``all_files_options`` in the
constructor.
Parameters
----------
groups : str or tuple or None
The groups to get the data for. it should be some of ``('train',
'dev', 'eval')`` or ``None``
flat : bool
if True, it will merge the real and attack files into one list.
Returns
-------
files : [:py:class:`bob.pad.base.database.PadFile`]
The sorted and unique list of all files of the database.
The sorted and unique list of all files of the database.
"""
realset = self.sort(self.objects(protocol=self.protocol, groups=groups, purposes='real', **self.all_files_options))
attackset = self.sort(self.objects(protocol=self.protocol, groups=groups, purposes='attack', **self.all_files_options))
if flat:
return realset + attackset
return [realset, attackset]
def training_files(self, step=None, arrange_by_client=False):
......
......@@ -43,6 +43,10 @@ class DummyDatabaseSqlTest(unittest.TestCase):
check_file(db.training_files(), 2)
check_file(db.files([1]))
check_file(db.reverse(["test/path"]))
# check if flat returns flat files
assert len(db.all_files(flat=True)) == 2, db.all_files(flat=True)
check_file(db.all_files(flat=True)[0:1])
check_file(db.all_files(flat=True)[1:2])
file = db.objects()[0]
assert db.original_file_name(file) == "original/directory/test/path.orig"
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment