diff --git a/bob/db/cuhk_cufsf/query.py b/bob/db/cuhk_cufsf/query.py index efaffb4510e443c45db25d22c243b652aebac915..023f60f5685ace7485057268f7bfa8c2ec3c110b 100755 --- a/bob/db/cuhk_cufsf/query.py +++ b/bob/db/cuhk_cufsf/query.py @@ -119,7 +119,7 @@ class Database(bob.db.base.SQLiteDatabase): return file.annotations(annotation_type=annotation_type) - def objects(self, groups = None, protocol = None, purposes = None, model_ids = None, **kwargs): + def objects(self, groups = None, protocol = None, purposes = None, model_ids = None, modality=None, **kwargs): """ This function returns lists of File objects, which fulfill the given restrictions. @@ -129,6 +129,9 @@ class Database(bob.db.base.SQLiteDatabase): groups = self.check_parameters_for_validity(groups, "group", GROUPS) protocols = self.check_parameters_for_validity(protocol, "protocol", PROTOCOLS) purposes = self.check_parameters_for_validity(purposes, "purpose", PURPOSES) + modality = self.check_parameters_for_validity( + modality, "modality", self.modalities) + #You need to select only one protocol if (len(protocols) > 1): @@ -141,6 +144,9 @@ class Database(bob.db.base.SQLiteDatabase): query = query.filter(bob.db.cuhk_cufsf.Protocol_File_Association.group.in_(groups)) query = query.filter(bob.db.cuhk_cufsf.Protocol_File_Association.protocol.in_(protocols)) query = query.filter(bob.db.cuhk_cufsf.Protocol_File_Association.purpose.in_(purposes)) + query = query.filter( + bob.db.cuhk_cufsf.File.modality.in_(modality)) + if model_ids is not None and not 'probe' in purposes: if type(model_ids) is not list and type(model_ids) is not tuple: diff --git a/bob/db/cuhk_cufsf/test.py b/bob/db/cuhk_cufsf/test.py index 793af61260faadc762c929ce33737d19fd501470..0cd79fee3067c6b550c99a1db07d282e75aba12d 100644 --- a/bob/db/cuhk_cufsf/test.py +++ b/bob/db/cuhk_cufsf/test.py @@ -76,6 +76,10 @@ def test02_search_files_protocols(): assert len(bob.db.cuhk_cufsf.Database().objects(protocol=p, groups="eval")) == 0 + # Checking the modalities + assert len(bob.db.cuhk_cufsf.Database().objects(protocol=p, groups="world", modality=["photo"])) == world//2 + assert len(bob.db.cuhk_cufsf.Database().objects(protocol=p, groups="world", modality=["sketch"])) == world//2 + def test03_verification_files_protocols():