diff --git a/bob/bio/base/test/dummy/database.py b/bob/bio/base/test/dummy/database.py index f8061c6f4c9ec21e5ca2d84bf1a02489e851c04c..32fa5627a37fd13d7d8762998c115d9f828cd19e 100644 --- a/bob/bio/base/test/dummy/database.py +++ b/bob/bio/base/test/dummy/database.py @@ -1,4 +1,4 @@ -from bob.bio.db import ZTBioDatabase, AtntBioDatabase +from bob.bio.db import ZTBioDatabase, BioFile from bob.bio.base.test.utils import atnt_database_directory @@ -14,13 +14,17 @@ class DummyDatabase(ZTBioDatabase): training_depends_on_protocol=False, models_depend_on_protocol=False ) - self.__db = AtntBioDatabase() + import bob.db.atnt + self.__db = bob.db.atnt.Database() + + def _make_bio(self, files): + return [BioFile(client_id=f.client_id, path=f.path, file_id=f.id) for f in files] def model_ids_with_protocol(self, groups=None, protocol=None, **kwargs): - return self.__db.model_ids_with_protocol(groups, protocol) + return self.__db.model_ids(groups, protocol) def objects(self, groups=None, protocol=None, purposes=None, model_ids=None, **kwargs): - return self.__db.objects(groups, protocol, purposes, model_ids, **kwargs) + return self._make_bio(self.__db.objects(model_ids, groups, purposes, protocol, **kwargs)) def tobjects(self, groups=None, protocol=None, model_ids=None, **kwargs): return [] @@ -29,7 +33,7 @@ class DummyDatabase(ZTBioDatabase): return [] def tmodel_ids_with_protocol(self, protocol=None, groups=None, **kwargs): - return self.__db.model_ids_with_protocol(groups, protocol) + return self.__db.model_ids(groups) def t_enroll_files(self, t_model_id, group='dev'): return self.enroll_files(t_model_id, group) diff --git a/bob/bio/base/test/dummy/fileset.py b/bob/bio/base/test/dummy/fileset.py index 570f2eb085e972b6cbeb6c7a742a2f337e52e808..ba0fdf6a4b096ea7233cfc0cf0525ffe373e9b9b 100644 --- a/bob/bio/base/test/dummy/fileset.py +++ b/bob/bio/base/test/dummy/fileset.py @@ -1,7 +1,6 @@ from bob.bio.db import ZTBioDatabase, BioFileSet, BioFile from bob.bio.base.test.utils import atnt_database_directory - class DummyDatabase(ZTBioDatabase): def __init__(self): @@ -20,25 +19,21 @@ class DummyDatabase(ZTBioDatabase): def uses_probe_file_sets(self): return True + def _make_bio(self, files): + return [BioFile(client_id=f.client_id, path=f.path, file_id=f.id) for f in files] + def probe_file_sets(self, model_id=None, group='dev'): """Returns the list of probe File objects (for the given model id, if given).""" - # import ipdb; ipdb.set_trace() files = self.arrange_by_client(self.sort(self.objects(protocol=None, groups=group, purposes='probe'))) # arrange files by clients - file_sets = [] - for client_files in files: - # convert into our File objects (so that they are tested as well) - our_files = [BioFile(f.client_id, f.path, f.id) for f in client_files] - # generate file set for each client - file_set = BioFileSet(our_files[0].client_id, our_files) - file_sets.append(file_set) + file_sets = [BioFileSet(client_files[0].client_id, client_files) for client_files in files] return file_sets def model_ids_with_protocol(self, groups=None, protocol=None, **kwargs): return self.__db.model_ids(groups, protocol) def objects(self, groups=None, protocol=None, purposes=None, model_ids=None, **kwargs): - return self.__db.objects(model_ids, groups, purposes, protocol, **kwargs) + return self._make_bio(self.__db.objects(model_ids, groups, purposes, protocol, **kwargs)) def tobjects(self, groups=None, protocol=None, model_ids=None, **kwargs): return []