Commit c49ad044 authored by Amir MOHAMMADI's avatar Amir MOHAMMADI

Fix issue bob.db.base#16

parent 4b9f2304
Pipeline #9220 passed with stages
in 21 minutes and 52 seconds
......@@ -29,11 +29,13 @@ dummy_data = {'train_real': 1.0, 'train_attack': 2.0,
class TestFile(PadFile):
def __init__(self, path, id):
attack_type = None
if "attack" in path:
attack_type = "attack"
PadFile.__init__(self, client_id=1, path=path, file_id=id, attack_type=attack_type)
PadFile.__init__(self, client_id=1, path=path,
file_id=id, attack_type=attack_type)
def load(self, directory=None, extension='.hdf5'):
"""Loads the data at the specified location and using the given extension.
......@@ -57,6 +59,7 @@ class TestFile(PadFile):
path = self.make_path(directory or '', extension or '')
return dummy_data[os.path.basename(path)]
def dumplist(args):
"""Dumps lists of files based on your criteria"""
......@@ -76,6 +79,7 @@ def dumplist(args):
class Interface(BaseInterface):
def name(self):
return dummy_name
......@@ -96,11 +100,11 @@ class Interface(BaseInterface):
dumpparser = subparsers.add_parser('dumplist', help="")
dumpparser.add_argument('-d', '--directory', dest="directory", default='',
help="if given, this path will be prepended to every entry returned (defaults to '%(default)s')")
help="if given, this path will be prepended to every entry returned (defaults to '%(default)s')")
dumpparser.add_argument('-e', '--extension', dest="extension", default='',
help="if given, this extension will be appended to every entry returned (defaults to '%(default)s')")
help="if given, this extension will be appended to every entry returned (defaults to '%(default)s')")
dumpparser.add_argument('--self-test', dest="selftest", default=False,
action='store_true', help=SUPPRESS)
action='store_true', help=SUPPRESS)
dumpparser.set_defaults(func=dumplist) # action
......@@ -110,9 +114,9 @@ class TestDatabase(PadDatabase):
def __init__(self, protocol='Default', original_directory=data_dir, original_extension='', **kwargs):
# call base class constructors to open a session to the database
PadDatabase.__init__(self, name='testspoof', protocol=protocol,
original_directory=original_directory,
original_extension=original_extension, **kwargs)
super(TestDatabase, self).__init__(name='testspoof', protocol=protocol,
original_directory=original_directory,
original_extension=original_extension, **kwargs)
################################################
# Low level support methods for the database #
......@@ -168,6 +172,7 @@ class TestDatabase(PadDatabase):
# does not implement the given access protocol
return False
def get_all_data(self):
return self.all_files()
......
......@@ -45,10 +45,10 @@ def create_database():
del engine
class TestDatabaseSql (bob.pad.base.database.PadDatabase, bob.db.base.SQLiteDatabase):
class TestDatabaseSql (bob.pad.base.database.PadDatabase, bob.db.base.SQLiteBaseDatabase):
def __init__(self):
bob.pad.base.database.PadDatabase.__init__(self, 'pad_test', original_directory="original/directory", original_extension=".orig")
bob.db.base.SQLiteDatabase.__init__(self, dbfile, TestFileSql)
bob.db.base.SQLiteBaseDatabase.__init__(self, dbfile, TestFileSql)
def groups(self, protocol=None):
return ['group']
......@@ -59,4 +59,4 @@ class TestDatabaseSql (bob.pad.base.database.PadDatabase, bob.db.base.SQLiteData
def annotations(self, file):
return None
database = TestDatabaseSql()
\ No newline at end of file
database = TestDatabaseSql()
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment