diff --git a/bob/db/cuhk_cufs/query.py b/bob/db/cuhk_cufs/query.py index 48e53e5f32e5e4f50af80116f2da3d5309abc7b9..dc5f04ceed3f31e98a19e94b8097f292e266c572 100644 --- a/bob/db/cuhk_cufs/query.py +++ b/bob/db/cuhk_cufs/query.py @@ -125,10 +125,17 @@ class Database(bob.db.verification.utils.SQLiteDatabase, bob.db.verification.uti query = query.filter(bob.db.cuhk_cufs.Protocol_File_Association.purpose.in_(purposes)) if model_ids is not None: - if type(model_ids) is not list and type(model_ids) is not tuple: - model_ids = [model_ids] - - query = query.filter(bob.db.cuhk_cufs.Client.id.in_(model_ids)) + if type(model_ids) is not list and type(model_ids) is not tuple: + model_ids = [model_ids] + + #if you provide a client object as input and not the ids + if type(model_ids[0]) is bob.db.cuhk_cufs.Client: + model_aux = [] + for m in model_ids: + model_aux.append(m.id) + model_ids = model_aux + + query = query.filter(bob.db.cuhk_cufs.Client.id.in_(model_ids)) raw_files = query.all() files = [] @@ -158,7 +165,7 @@ class Database(bob.db.verification.utils.SQLiteDatabase, bob.db.verification.uti query = query.filter(bob.db.cuhk_cufs.Protocol_File_Association.group.in_(groups)) query = query.filter(bob.db.cuhk_cufs.Protocol_File_Association.protocol.in_(protocols)) - return query.all() + return [c.id for c in query.all()] def groups(self, protocol = None, **kwargs):