diff --git a/bob/db/cuhk_cufs/create.py b/bob/db/cuhk_cufs/create.py index 4b0ef1bda45de776e347981825bfd97b6b834a26..0c372a5f2d33d122434bb359cc056c34ac55dd1c 100644 --- a/bob/db/cuhk_cufs/create.py +++ b/bob/db/cuhk_cufs/create.py @@ -221,31 +221,30 @@ def add_protocols(session, verbose, photo2sketch=True): if verbose>=1: print('Creating the protocol ARFACE ...') - #getting the files - world_files = arface.get_files_from_group(group="world") - dev_files = arface.get_files_from_group(group="dev") - eval_files = arface.get_files_from_group(group="eval") + #getting the clients + world_clients = arface.get_clients_from_group(group="world") + insert_protocol_data(session, "arface"+suffix, "world", "train", world_clients, photo2sketch=photo2sketch) - #Inserting in the database - insert_protocol_data(session, "arface"+suffix, "world", "train", world_files, photo2sketch=photo2sketch) - insert_protocol_data(session, "arface"+suffix, "dev", "", dev_files, photo2sketch=photo2sketch) - insert_protocol_data(session, "arface"+suffix, "eval", "", eval_files, photo2sketch=photo2sketch) - + dev_clients = arface.get_clients_from_group(group="dev") + insert_protocol_data(session, "arface"+suffix, "dev", "", dev_clients, photo2sketch=photo2sketch) + + eval_clients = arface.get_clients_from_group(group="eval") + insert_protocol_data(session, "arface"+suffix, "eval", "", eval_clients, photo2sketch=photo2sketch) + session.commit() - - + ############## Protocol XM2VTS if verbose>=1: print('Creating the protocol XM2VTS ...') #getting the files - world_files = xm2vts.get_files_from_group(group="world") - dev_files = xm2vts.get_files_from_group(group="dev") - eval_files = xm2vts.get_files_from_group(group="eval") + world_clients = xm2vts.get_clients_from_group(group="world") + dev_clients = xm2vts.get_clients_from_group(group="dev") + eval_clients = xm2vts.get_clients_from_group(group="eval") #Inserting in the database - insert_protocol_data(session, "xm2vts"+suffix, "world", "train", world_files, photo2sketch=photo2sketch) - insert_protocol_data(session, "xm2vts"+suffix, "dev", "", dev_files, photo2sketch=photo2sketch) - insert_protocol_data(session, "xm2vts"+suffix, "eval", "", eval_files, photo2sketch=photo2sketch) + insert_protocol_data(session, "xm2vts"+suffix, "world", "train", world_clients, photo2sketch=photo2sketch) + insert_protocol_data(session, "xm2vts"+suffix, "dev", "", dev_clients, photo2sketch=photo2sketch) + insert_protocol_data(session, "xm2vts"+suffix, "eval", "", eval_clients, photo2sketch=photo2sketch) session.commit() @@ -256,14 +255,14 @@ def add_protocols(session, verbose, photo2sketch=True): if verbose>=1: print('Creating the protocol CUHK ...') #getting the files - world_files = cuhk.get_files_from_group(group="world") - dev_files = cuhk.get_files_from_group(group="dev") - eval_files = cuhk.get_files_from_group(group="eval") + world_clients = cuhk.get_clients_from_group(group="world") + dev_clients = cuhk.get_clients_from_group(group="dev") + eval_clients = cuhk.get_clients_from_group(group="eval") #Inserting in the database - insert_protocol_data(session, "cuhk"+suffix, "world", "train", world_files, photo2sketch=photo2sketch) - insert_protocol_data(session, "cuhk"+suffix, "dev", "", dev_files, photo2sketch=photo2sketch) - insert_protocol_data(session, "cuhk"+suffix, "eval", "", eval_files, photo2sketch=photo2sketch) + insert_protocol_data(session, "cuhk"+suffix, "world", "train", world_clients, photo2sketch=photo2sketch) + insert_protocol_data(session, "cuhk"+suffix, "dev", "", dev_clients, photo2sketch=photo2sketch) + insert_protocol_data(session, "cuhk"+suffix, "eval", "", eval_clients, photo2sketch=photo2sketch) session.commit() @@ -273,23 +272,23 @@ def add_protocols(session, verbose, photo2sketch=True): if verbose>=1: print('Creating the protocol ALL mixed ...') #getting the files - world_files = arface.get_files_from_group(group="world") +\ - xm2vts.get_files_from_group(group="world") +\ - cuhk.get_files_from_group(group="world") + world_clients = arface.get_clients_from_group(group="world") +\ + xm2vts.get_clients_from_group(group="world") +\ + cuhk.get_clients_from_group(group="world") - dev_files = arface.get_files_from_group(group="dev") +\ - xm2vts.get_files_from_group(group="dev") +\ - cuhk.get_files_from_group(group="dev") + dev_clients = arface.get_clients_from_group(group="dev") +\ + xm2vts.get_clients_from_group(group="dev") +\ + cuhk.get_clients_from_group(group="dev") - eval_files = arface.get_files_from_group(group="eval") +\ - xm2vts.get_files_from_group(group="eval") +\ - cuhk.get_files_from_group(group="eval") + eval_clients = arface.get_clients_from_group(group="eval") +\ + xm2vts.get_clients_from_group(group="eval") +\ + cuhk.get_clients_from_group(group="eval") #Inserting in the database - insert_protocol_data(session, "all-mixed"+suffix, "world", "train", world_files, photo2sketch=photo2sketch) - insert_protocol_data(session, "all-mixed"+suffix, "dev", "", dev_files, photo2sketch=photo2sketch) - insert_protocol_data(session, "all-mixed"+suffix, "eval", "", eval_files, photo2sketch=photo2sketch) + insert_protocol_data(session, "all-mixed"+suffix, "world", "train", world_clients, photo2sketch=photo2sketch) + insert_protocol_data(session, "all-mixed"+suffix, "dev", "", dev_clients, photo2sketch=photo2sketch) + insert_protocol_data(session, "all-mixed"+suffix, "eval", "", eval_clients, photo2sketch=photo2sketch) session.commit() @@ -299,15 +298,15 @@ def add_protocols(session, verbose, photo2sketch=True): if verbose>=1: print('Creating the protocol cuhk-arface-xm2vts ...') #getting the files - world_files = cuhk.get_files_from_group(group="world") - dev_files = arface.get_files_from_group(group="dev") - eval_files = xm2vts.get_files_from_group(group="eval") + world_clients = cuhk.get_clients_from_group(group="world") + dev_clients = arface.get_clients_from_group(group="dev") + eval_clients = xm2vts.get_clients_from_group(group="eval") #Inserting in the database - insert_protocol_data(session, "cuhk-arface-xm2vts"+suffix, "world", "train", world_files, photo2sketch=photo2sketch) - insert_protocol_data(session, "cuhk-arface-xm2vts"+suffix, "dev", "", dev_files, photo2sketch=photo2sketch) - insert_protocol_data(session, "cuhk-arface-xm2vts"+suffix, "eval", "", eval_files, photo2sketch=photo2sketch) + insert_protocol_data(session, "cuhk-arface-xm2vts"+suffix, "world", "train", world_clients, photo2sketch=photo2sketch) + insert_protocol_data(session, "cuhk-arface-xm2vts"+suffix, "dev", "", dev_clients, photo2sketch=photo2sketch) + insert_protocol_data(session, "cuhk-arface-xm2vts"+suffix, "eval", "", eval_clients, photo2sketch=photo2sketch) session.commit() @@ -317,15 +316,15 @@ def add_protocols(session, verbose, photo2sketch=True): if verbose>=1: print('Creating the protocol cuhk-xm2vts-arface ...') #getting the files - world_files = cuhk.get_files_from_group(group="world") - dev_files = xm2vts.get_files_from_group(group="dev") - eval_files = arface.get_files_from_group(group="eval") + world_clients = cuhk.get_clients_from_group(group="world") + dev_clients = xm2vts.get_clients_from_group(group="dev") + eval_clients = arface.get_clients_from_group(group="eval") #Inserting in the database - insert_protocol_data(session, "cuhk-xm2vts-arface"+suffix, "world", "train", world_files, photo2sketch=photo2sketch) - insert_protocol_data(session, "cuhk-xm2vts-arface"+suffix, "dev", "", dev_files, photo2sketch=photo2sketch) - insert_protocol_data(session, "cuhk-xm2vts-arface"+suffix, "eval", "", eval_files, photo2sketch=photo2sketch) + insert_protocol_data(session, "cuhk-xm2vts-arface"+suffix, "world", "train", world_clients, photo2sketch=photo2sketch) + insert_protocol_data(session, "cuhk-xm2vts-arface"+suffix, "dev", "", dev_clients, photo2sketch=photo2sketch) + insert_protocol_data(session, "cuhk-xm2vts-arface"+suffix, "eval", "", eval_clients, photo2sketch=photo2sketch) session.commit() @@ -335,15 +334,15 @@ def add_protocols(session, verbose, photo2sketch=True): if verbose>=1: print('Creating the protocol arface-cuhk-xm2vts ...') #getting the files - world_files = arface.get_files_from_group(group="world") - dev_files = cuhk.get_files_from_group(group="dev") - eval_files = xm2vts.get_files_from_group(group="eval") + world_clients = arface.get_clients_from_group(group="world") + dev_clients = cuhk.get_clients_from_group(group="dev") + eval_clients = xm2vts.get_clients_from_group(group="eval") #Inserting in the database - insert_protocol_data(session, "arface-cuhk-xm2vts"+suffix, "world", "train", world_files, photo2sketch=photo2sketch) - insert_protocol_data(session, "arface-cuhk-xm2vts"+suffix, "dev", "", dev_files, photo2sketch=photo2sketch) - insert_protocol_data(session, "arface-cuhk-xm2vts"+suffix, "eval", "", eval_files, photo2sketch=photo2sketch) + insert_protocol_data(session, "arface-cuhk-xm2vts"+suffix, "world", "train", world_clients, photo2sketch=photo2sketch) + insert_protocol_data(session, "arface-cuhk-xm2vts"+suffix, "dev", "", dev_clients, photo2sketch=photo2sketch) + insert_protocol_data(session, "arface-cuhk-xm2vts"+suffix, "eval", "", eval_clients, photo2sketch=photo2sketch) session.commit() @@ -353,15 +352,15 @@ def add_protocols(session, verbose, photo2sketch=True): if verbose>=1: print('Creating the protocol arface-xm2vts-cuhk ...') #getting the files - world_files = arface.get_files_from_group(group="world") - dev_files = xm2vts.get_files_from_group(group="dev") - eval_files = cuhk.get_files_from_group(group="eval") + world_clients = arface.get_clients_from_group(group="world") + dev_clients = xm2vts.get_clients_from_group(group="dev") + eval_clients = cuhk.get_clients_from_group(group="eval") #Inserting in the database - insert_protocol_data(session, "arface-xm2vts-cuhk"+suffix, "world", "train", world_files, photo2sketch=photo2sketch) - insert_protocol_data(session, "arface-xm2vts-cuhk"+suffix, "dev", "", dev_files, photo2sketch=photo2sketch) - insert_protocol_data(session, "arface-xm2vts-cuhk"+suffix, "eval", "", eval_files, photo2sketch=photo2sketch) + insert_protocol_data(session, "arface-xm2vts-cuhk"+suffix, "world", "train", world_clients, photo2sketch=photo2sketch) + insert_protocol_data(session, "arface-xm2vts-cuhk"+suffix, "dev", "", dev_clients, photo2sketch=photo2sketch) + insert_protocol_data(session, "arface-xm2vts-cuhk"+suffix, "eval", "", eval_clients, photo2sketch=photo2sketch) session.commit() @@ -371,15 +370,15 @@ def add_protocols(session, verbose, photo2sketch=True): if verbose>=1: print('Creating the protocol xm2vts-cuhk-arface ...') #getting the files - world_files = xm2vts.get_files_from_group(group="world") - dev_files = cuhk.get_files_from_group(group="dev") - eval_files = arface.get_files_from_group(group="eval") + world_clients = xm2vts.get_clients_from_group(group="world") + dev_clients = cuhk.get_clients_from_group(group="dev") + eval_clients = arface.get_clients_from_group(group="eval") #Inserting in the database - insert_protocol_data(session, "xm2vts-cuhk-arface"+suffix, "world", "train", world_files, photo2sketch=photo2sketch) - insert_protocol_data(session, "xm2vts-cuhk-arface"+suffix, "dev", "", dev_files, photo2sketch=photo2sketch) - insert_protocol_data(session, "xm2vts-cuhk-arface"+suffix, "eval", "", eval_files, photo2sketch=photo2sketch) + insert_protocol_data(session, "xm2vts-cuhk-arface"+suffix, "world", "train", world_clients, photo2sketch=photo2sketch) + insert_protocol_data(session, "xm2vts-cuhk-arface"+suffix, "dev", "", dev_clients, photo2sketch=photo2sketch) + insert_protocol_data(session, "xm2vts-cuhk-arface"+suffix, "eval", "", eval_clients, photo2sketch=photo2sketch) session.commit() @@ -390,41 +389,100 @@ def add_protocols(session, verbose, photo2sketch=True): if verbose>=1: print('Creating the protocol xm2vts-arface-cuhk ...') #getting the files - world_files = xm2vts.get_files_from_group(group="world") - dev_files = arface.get_files_from_group(group="dev") - eval_files = cuhk.get_files_from_group(group="eval") + world_clients = xm2vts.get_clients_from_group(group="world") + dev_clients = arface.get_clients_from_group(group="dev") + eval_clients = cuhk.get_clients_from_group(group="eval") #Inserting in the database - insert_protocol_data(session, "xm2vts-arface-cuhk"+suffix, "world", "train", world_files, photo2sketch=photo2sketch) - insert_protocol_data(session, "xm2vts-arface-cuhk"+suffix, "dev", "", dev_files, photo2sketch=photo2sketch) - insert_protocol_data(session, "xm2vts-arface-cuhk"+suffix, "eval", "", eval_files, photo2sketch=photo2sketch) + insert_protocol_data(session, "xm2vts-arface-cuhk"+suffix, "world", "train", world_clients, photo2sketch=photo2sketch) + insert_protocol_data(session, "xm2vts-arface-cuhk"+suffix, "dev", "", dev_clients, photo2sketch=photo2sketch) + insert_protocol_data(session, "xm2vts-arface-cuhk"+suffix, "eval", "", eval_clients, photo2sketch=photo2sketch) session.commit() -def insert_protocol_data(session, protocol, group, purpose, file_objects, photo2sketch=True): - - for f in file_objects: - if purpose!="train": - if photo2sketch: +def insert_protocol_data(session, protocol, group, purpose, clients, photo2sketch=True): - if f.modality=="photo": - purpose = "enroll" + for c in clients: + if purpose=="train": + #Adding files for training + for f in c.files: + session.add(bob.db.cuhk_cufs.Protocol_File_Association( + protocol, group, purpose, f.id, c.id)) + + else: + #PRobing + for c1 in clients: + + #With same clients, you need to define which one is enroll or probe + if c1.id==c.id: + for f in c1.files: + if ((photo2sketch and f.modality=="photo") or + (not photo2sketch and f.modality=="sketch")): + + purpose="enroll" + else: + purpose="probe" + + session.add(bob.db.cuhk_cufs.Protocol_File_Association( + protocol, group, purpose, f.id, c.id)) + else: - purpose = "probe" + #With different clients the task is to define which file should be included - else: - if f.modality=="photo": - purpose = "probe" - else: - purpose = "enroll" + for f in c1.files: + if ((photo2sketch and f.modality=="photo") or + (not photo2sketch and f.modality=="sketch")): + continue #Excluding files + else: + purpose="probe" + + session.add(bob.db.cuhk_cufs.Protocol_File_Association( + protocol, group, purpose, f.id, c.id)) + + """ + for f in file_objects: + if purpose=="train": + session.add(bob.db.cuhk_cufs.Protocol_File_Association( + protocol, group, purpose, f.id, f.id)) + else: + + #Excluding some files + if ((not photo2sketch and f.modality=="photo") or + (photo2sketch and f.modality=="sketch") ) + continue + + #probing + for f1 in file_objects: + + #Same client + if f1.client_id == f.client_id: + if photo2sketch: + if f.modality=="photo": + purpose = "enroll" + else: + purpose = "probe" + + else: + if f.modality=="photo": + purpose = "probe" + else: + purpose = "enroll" + + else:#Different client + + if ((photo2sketch and f1.modality=="sketch") or + (not photo2sketch and f1.modality=="photo")): + purpose = "probe" + else: + continue - session.add(bob.db.cuhk_cufs.Protocol_File_Association( - protocol, group, purpose, f.id)) - + session.add(bob.db.cuhk_cufs.Protocol_File_Association( + protocol, group, purpose, f.id, f1.id)) + """ def create_tables(args): diff --git a/bob/db/cuhk_cufs/db.sql3 b/bob/db/cuhk_cufs/db.sql3 index bf1720db86c78e40f3e836ba96fd8ff1bea29675..73c4bd6ced6b95269e7e87bb6741236402faa34c 100644 Binary files a/bob/db/cuhk_cufs/db.sql3 and b/bob/db/cuhk_cufs/db.sql3 differ diff --git a/bob/db/cuhk_cufs/models.py b/bob/db/cuhk_cufs/models.py index e21bbcb6997048b07233a60ff19617dd7b016df1..eeb12d8a866ab6f8c6b038576894e7c57f622cc9 100644 --- a/bob/db/cuhk_cufs/models.py +++ b/bob/db/cuhk_cufs/models.py @@ -58,13 +58,16 @@ class Protocol_File_Association(Base): protocol = Column('protocol', Enum(*PROTOCOLS), primary_key=True) group = Column('group', Enum(*GROUPS), primary_key=True) purpose = Column('purpose', Enum(*PURPOSES), primary_key=True) - file_id = Column('file_id', Integer, ForeignKey('file.id'), primary_key=True) - def __init__(self, protocol, group, purpose, file_id): - self.protocol = protocol - self.group = group - self.purpose = purpose - self.file_id = file_id + file_id = Column('file_id', Integer, ForeignKey('file.id'), primary_key=True) + client_id = Column('client_id', Integer, ForeignKey('client.id'), primary_key=True) + + def __init__(self, protocol, group, purpose, file_id, client_id): + self.protocol = protocol + self.group = group + self.purpose = purpose + self.file_id = file_id + self.client_id = client_id diff --git a/bob/db/cuhk_cufs/test.py b/bob/db/cuhk_cufs/test.py index dfc816071706dfc35db6f24e48edceb7a95cbca5..041e731ad7cb56596ea7f0f803e0f2d0378738ed 100644 --- a/bob/db/cuhk_cufs/test.py +++ b/bob/db/cuhk_cufs/test.py @@ -55,16 +55,16 @@ def test01_protocols_purposes_groups(): def test02_all_files_protocols(): - cuhk = 376 - arface = 246 - xm2vts = 590 - all_mixed = 1212 - cuhk_arface_xm2vts = 408 - cuhk_xm2vts_arface = 404 - arface_cuhk_xm2vts = 378 - arface_xm2vts_cuhk = 378 - xm2vts_cuhk_arface = 426 - xm2vts_arface_cuhk = 430 + cuhk = 6648 + arface = 3288 + xm2vts = 16078 + all_mixed = 68924 + cuhk_arface_xm2vts = 9800 + cuhk_xm2vts_arface = 9542 + arface_cuhk_xm2vts = 11290 + arface_xm2vts_cuhk = 11226 + xm2vts_cuhk_arface = 4988 + xm2vts_arface_cuhk = 5182 assert len(bob.db.cuhk_cufs.Database().objects(protocol="cuhk_p2s")) == cuhk assert len(bob.db.cuhk_cufs.Database().objects(protocol="cuhk_s2p")) == cuhk @@ -93,6 +93,9 @@ def test02_all_files_protocols(): assert len(bob.db.cuhk_cufs.Database().objects(protocol="xm2vts-cuhk-arface_p2s")) == xm2vts_cuhk_arface assert len(bob.db.cuhk_cufs.Database().objects(protocol="xm2vts-cuhk-arface_s2p")) == xm2vts_cuhk_arface + assert len(bob.db.cuhk_cufs.Database().objects(protocol="xm2vts-arface-cuhk_p2s")) == xm2vts_arface_cuhk + assert len(bob.db.cuhk_cufs.Database().objects(protocol="xm2vts-arface-cuhk_s2p")) == xm2vts_arface_cuhk + def test03_world_files_protocols(): @@ -138,10 +141,10 @@ def test03_world_files_protocols(): def test04_dev_files_protocols(): - cuhk = 112 - arface = 80 - xm2vts = 176 - all_mixed = 368 + cuhk = 3192 + arface = 1640 + xm2vts = 7832 + all_mixed = 34040 cuhk_arface_xm2vts = arface cuhk_xm2vts_arface = xm2vts arface_cuhk_xm2vts = cuhk @@ -180,10 +183,10 @@ def test04_dev_files_protocols(): def test05_eval_files_protocols(): - cuhk = 114 - arface = 78 - xm2vts = 178 - all_mixed = 370 + cuhk = 3306 + arface = 1560 + xm2vts = 8010 + all_mixed = 34410 cuhk_arface_xm2vts = xm2vts cuhk_xm2vts_arface = arface arface_cuhk_xm2vts = xm2vts diff --git a/bob/db/cuhk_cufs/utils.py b/bob/db/cuhk_cufs/utils.py index 59fc3167f23affc212564ee88a4bb7f61090b75d..6f1d08625001685de15b4bc4030a7ab182509931 100644 --- a/bob/db/cuhk_cufs/utils.py +++ b/bob/db/cuhk_cufs/utils.py @@ -66,7 +66,7 @@ class ARFACEWrapper(): return 'man' if client_id[0]=='m' else 'woman' - def get_files_from_group(self, group=""): + def get_clients_from_group(self, group=""): """ Get the bob.db.cuhk_cufs.File for a given group (world, dev or eval). @@ -78,18 +78,19 @@ class ARFACEWrapper(): """ arface = bob.db.arface.Database() cuhk = bob.db.cuhk_cufs.Database() + import sqlalchemy #Getting the clients from ARFACE - clients = arface.query(bob.db.arface.Client).filter(bob.db.arface.Client.sgroup==group) + original_clients = arface.query(bob.db.arface.Client).filter(bob.db.arface.Client.sgroup==group) #Getting the correspondent files from bob.db.cuhk_cufs - files = [] - for c in clients: - cuhk_files = cuhk.query(bob.db.cuhk_cufs.File).join(bob.db.cuhk_cufs.Client).filter(bob.db.cuhk_cufs.Client.original_id==c.id) - for f in cuhk_files: - files.append(f) + clients = [] + for o in original_clients: + cuhk_clients = cuhk.query(bob.db.cuhk_cufs.Client).filter(bob.db.cuhk_cufs.Client.original_id==o.id).options(sqlalchemy.orm.subqueryload(bob.db.cuhk_cufs.Client.files)) #forcing to bring the clients + for c in cuhk_clients: + clients.append(c) - return files + return list(clients) def get_annotations(self, annotation_dir, annotation_extension='.dat'): @@ -212,7 +213,7 @@ class XM2VTSWrapper(): def get_gender(self): return 'none' - def get_files_from_group(self, group=""): + def get_clients_from_group(self, group=""): """ This is a hand made protocol since the XM2VTS database is biased. @@ -225,8 +226,9 @@ class XM2VTSWrapper(): indexes = [273, 241, 285, 256, 173, 193, 107, 55, 53, 143, 163, 63, 13, 113, 258, 271, 134, 17, 20, 227, 203, 96, 66, 112, 77, 237, 42, 61, 272, 161, 209, 206, 195, 140, 150, 294, 152, 136, 188, 232, 21, 75, 141, 25, 249, 269, 70, 217, 251, 29, 153, 83, 185, 94, 116, 265, 177, 38, 156, 191, 118, 121, 204, 100, 255, 286, 78, 260, 282, 33, 242, 200, 91, 224, 137, 180, 65, 12, 3, 151, 154, 1, 290, 198, 167, 212, 72, 133, 144, 57, 0, 211, 48, 292, 213, 277, 52, 223, 115, 230, 49, 4, 291, 214, 18, 71, 146, 289, 250, 268, 201, 170, 11, 178, 2, 155, 264, 64, 287, 14, 110, 30, 19, 149, 68, 183, 44, 60, 181, 283, 86, 139, 81, 126, 202, 120, 10, 9, 164, 218, 43, 148, 105, 186, 225, 93, 184, 50, 257, 132, 254, 27, 108, 106, 69, 252, 138, 122, 196, 175, 228, 7, 168, 135, 15, 231, 182, 280, 147, 54, 261, 79, 281, 125, 142, 101, 259, 41, 187, 16, 275, 248, 179, 169, 89, 245, 26, 73, 199, 90, 128, 236, 40, 166, 262, 84, 32, 97, 92, 174, 284, 37, 36, 111, 82, 104, 58, 98, 235, 215, 220, 130, 85, 216, 205, 274, 22, 244, 129, 247, 6, 240, 279, 5, 109, 31, 74, 127, 95, 117, 210, 165, 80, 59, 114, 194, 238, 207, 239, 267, 159, 243, 131, 171, 67, 222, 8, 47, 45, 99, 123, 229, 293, 270, 253, 46, 162, 263, 102, 76, 88, 28, 158, 278, 62, 246, 176, 124, 234, 276, 87, 24, 157, 119, 197, 190, 35, 34, 160, 56, 266, 172, 39, 233, 221, 192, 288, 23, 226, 219, 189, 208, 145, 103, 51] #Fetching the clients + import sqlalchemy cuhk = bob.db.cuhk_cufs.Database() - all_clients = numpy.array(cuhk.query(bob.db.cuhk_cufs.Client).filter(bob.db.cuhk_cufs.Client.original_database=="xm2vts").order_by(bob.db.cuhk_cufs.Client.original_id).all()) + all_clients = numpy.array(cuhk.query(bob.db.cuhk_cufs.Client).filter(bob.db.cuhk_cufs.Client.original_database=="xm2vts").order_by(bob.db.cuhk_cufs.Client.original_id).options(sqlalchemy.orm.subqueryload(bob.db.cuhk_cufs.Client.files)).all()) #forcing to bring the clients. data_training = 118 data_dev = 88 @@ -245,13 +247,13 @@ class XM2VTSWrapper(): #Fetching the correspondent files from bob.db.cuhk_cufs - files = [] - for c in clients: - cuhk_files = cuhk.query(bob.db.cuhk_cufs.File).join(bob.db.cuhk_cufs.Client).filter(bob.db.cuhk_cufs.Client.id==c.id) - for f in cuhk_files: - files.append(f) + #files = [] + #for c in clients: + #cuhk_files = cuhk.query(bob.db.cuhk_cufs.File).join(bob.db.cuhk_cufs.Client).filter(bob.db.cuhk_cufs.Client.id==c.id) + #for f in cuhk_files: + #files.append(f) - return files + return list(clients) def get_files_from_group_biased(self, group=""): @@ -459,7 +461,7 @@ class CUHKWrapper(): return annotations - def get_files_from_group(self, group=""): + def get_clients_from_group(self, group=""): """ This is a hand made protocol since there is no protocol for the CUHK-CUFS database. @@ -467,14 +469,14 @@ class CUHKWrapper(): - 40% for training --> 75 - 30% for developement --> 56 - 30% for testing --> 57 - """ - + """ indexes = [152, 70, 150, 120, 181, 64, 16, 66, 154, 1, 84, 35, 179, 105, 49, 159, 128, 14, 103, 157, 18, 148, 88, 134, 147, 72, 62, 110, 20, 27, 30, 187, 50, 117, 83, 71, 81, 61, 185, 85, 2, 145, 138, 45, 129, 151, 96, 132, 146, 87, 156, 173, 73, 38, 125, 69, 82, 34, 116, 102, 136, 91, 7, 143, 109, 112, 115, 63, 33, 165, 104, 170, 76, 36, 114, 5, 142, 90, 60, 40, 93, 67, 180, 77, 106, 130, 135, 124, 118, 6, 39, 97, 121, 4, 74, 86, 57, 24, 65, 167, 184, 163, 47, 169, 94, 8, 58, 126, 166, 15, 172, 11, 89, 162, 42, 98, 22, 133, 78, 175, 0, 160, 92, 37, 161, 17, 26, 122, 137, 164, 99, 149, 32, 95, 144, 46, 155, 168, 48, 182, 23, 80, 10, 140, 9, 55, 29, 113, 12, 54, 158, 52, 41, 119, 183, 25, 131, 107, 176, 31, 111, 108, 123, 79, 153, 178, 139, 51, 13, 177, 141, 171, 101, 3, 43, 68, 56, 21, 75, 28, 53, 44, 19, 174, 100, 127, 186, 59] #Fetching the clients cuhk = bob.db.cuhk_cufs.Database() - all_clients = numpy.array(cuhk.query(bob.db.cuhk_cufs.Client).filter(bob.db.cuhk_cufs.Client.original_database=="cuhk").order_by(bob.db.cuhk_cufs.Client.id).all()) + import sqlalchemy + all_clients = numpy.array(cuhk.query(bob.db.cuhk_cufs.Client).filter(bob.db.cuhk_cufs.Client.original_database=="cuhk").order_by(bob.db.cuhk_cufs.Client.id).options(sqlalchemy.orm.subqueryload(bob.db.cuhk_cufs.Client.files)).all()) data_training = 75 data_dev = 56 @@ -492,13 +494,13 @@ class CUHKWrapper(): clients = all_clients[indexes[offset:offset+data_eval]] #Fetching the correspondent files from bob.db.cuhk_cufs - files = [] - for c in clients: - cuhk_files = cuhk.query(bob.db.cuhk_cufs.File).join(bob.db.cuhk_cufs.Client).filter(bob.db.cuhk_cufs.Client.id==c.id) - for f in cuhk_files: - files.append(f) + #files = [] + #for c in clients: + #cuhk_files = cuhk.query(bob.db.cuhk_cufs.File).join(bob.db.cuhk_cufs.Client).filter(bob.db.cuhk_cufs.Client.id==c.id) + #for f in cuhk_files: + #files.append(f) - return files + return list(clients)