Commit fbc33ead authored by Pavel KORSHUNOV's avatar Pavel KORSHUNOV
Browse files

Merge branch 'flat' into 'master'

Add an option to return a flat list

See merge request !23
parents 45fed636 5409920b
Pipeline #14450 passed with stages
in 7 minutes and 56 seconds
...@@ -130,24 +130,29 @@ class PadDatabase(BioDatabase): ...@@ -130,24 +130,29 @@ class PadDatabase(BioDatabase):
######### Methods to provide common functionality ############### ######### Methods to provide common functionality ###############
################################################################# #################################################################
def all_files(self, groups=('train', 'dev', 'eval')): def all_files(self, groups=('train', 'dev', 'eval'), flat=False):
"""all_files(groups=('train', 'dev', 'eval')) -> files """Returns all files of the database, respecting the current protocol.
The files can be limited using the ``all_files_options`` in the
Returns all files of the database, respecting the current protocol. constructor.
The files can be limited using the ``all_files_options`` in the constructor.
Parameters
**Parameters:** ----------
groups : str or tuple or None
groups : some of ``('train', 'dev', 'eval')`` or ``None`` The groups to get the data for. it should be some of ``('train',
The groups to get the data for. 'dev', 'eval')`` or ``None``
**Returns:** flat : bool
if True, it will merge the real and attack files into one list.
Returns
-------
files : [:py:class:`bob.pad.base.database.PadFile`] files : [:py:class:`bob.pad.base.database.PadFile`]
The sorted and unique list of all files of the database. The sorted and unique list of all files of the database.
""" """
realset = self.sort(self.objects(protocol=self.protocol, groups=groups, purposes='real', **self.all_files_options)) realset = self.sort(self.objects(protocol=self.protocol, groups=groups, purposes='real', **self.all_files_options))
attackset = self.sort(self.objects(protocol=self.protocol, groups=groups, purposes='attack', **self.all_files_options)) attackset = self.sort(self.objects(protocol=self.protocol, groups=groups, purposes='attack', **self.all_files_options))
if flat:
return realset + attackset
return [realset, attackset] return [realset, attackset]
def training_files(self, step=None, arrange_by_client=False): def training_files(self, step=None, arrange_by_client=False):
......
...@@ -43,6 +43,10 @@ class DummyDatabaseSqlTest(unittest.TestCase): ...@@ -43,6 +43,10 @@ class DummyDatabaseSqlTest(unittest.TestCase):
check_file(db.training_files(), 2) check_file(db.training_files(), 2)
check_file(db.files([1])) check_file(db.files([1]))
check_file(db.reverse(["test/path"])) check_file(db.reverse(["test/path"]))
# check if flat returns flat files
assert len(db.all_files(flat=True)) == 2, db.all_files(flat=True)
check_file(db.all_files(flat=True)[0:1])
check_file(db.all_files(flat=True)[1:2])
file = db.objects()[0] file = db.objects()[0]
assert db.original_file_name(file) == "original/directory/test/path.orig" assert db.original_file_name(file) == "original/directory/test/path.orig"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment