Commit dd3e282b authored by Tiago de Freitas Pereira's avatar Tiago de Freitas Pereira
Browse files

Fixed several bugs

parent be1360b9
......@@ -221,31 +221,30 @@ def add_protocols(session, verbose, photo2sketch=True):
if verbose>=1: print('Creating the protocol ARFACE ...')
#getting the files
world_files = arface.get_files_from_group(group="world")
dev_files = arface.get_files_from_group(group="dev")
eval_files = arface.get_files_from_group(group="eval")
#getting the clients
world_clients = arface.get_clients_from_group(group="world")
insert_protocol_data(session, "arface"+suffix, "world", "train", world_clients, photo2sketch=photo2sketch)
#Inserting in the database
insert_protocol_data(session, "arface"+suffix, "world", "train", world_files, photo2sketch=photo2sketch)
insert_protocol_data(session, "arface"+suffix, "dev", "", dev_files, photo2sketch=photo2sketch)
insert_protocol_data(session, "arface"+suffix, "eval", "", eval_files, photo2sketch=photo2sketch)
dev_clients = arface.get_clients_from_group(group="dev")
insert_protocol_data(session, "arface"+suffix, "dev", "", dev_clients, photo2sketch=photo2sketch)
eval_clients = arface.get_clients_from_group(group="eval")
insert_protocol_data(session, "arface"+suffix, "eval", "", eval_clients, photo2sketch=photo2sketch)
session.commit()
############## Protocol XM2VTS
if verbose>=1: print('Creating the protocol XM2VTS ...')
#getting the files
world_files = xm2vts.get_files_from_group(group="world")
dev_files = xm2vts.get_files_from_group(group="dev")
eval_files = xm2vts.get_files_from_group(group="eval")
world_clients = xm2vts.get_clients_from_group(group="world")
dev_clients = xm2vts.get_clients_from_group(group="dev")
eval_clients = xm2vts.get_clients_from_group(group="eval")
#Inserting in the database
insert_protocol_data(session, "xm2vts"+suffix, "world", "train", world_files, photo2sketch=photo2sketch)
insert_protocol_data(session, "xm2vts"+suffix, "dev", "", dev_files, photo2sketch=photo2sketch)
insert_protocol_data(session, "xm2vts"+suffix, "eval", "", eval_files, photo2sketch=photo2sketch)
insert_protocol_data(session, "xm2vts"+suffix, "world", "train", world_clients, photo2sketch=photo2sketch)
insert_protocol_data(session, "xm2vts"+suffix, "dev", "", dev_clients, photo2sketch=photo2sketch)
insert_protocol_data(session, "xm2vts"+suffix, "eval", "", eval_clients, photo2sketch=photo2sketch)
session.commit()
......@@ -256,14 +255,14 @@ def add_protocols(session, verbose, photo2sketch=True):
if verbose>=1: print('Creating the protocol CUHK ...')
#getting the files
world_files = cuhk.get_files_from_group(group="world")
dev_files = cuhk.get_files_from_group(group="dev")
eval_files = cuhk.get_files_from_group(group="eval")
world_clients = cuhk.get_clients_from_group(group="world")
dev_clients = cuhk.get_clients_from_group(group="dev")
eval_clients = cuhk.get_clients_from_group(group="eval")
#Inserting in the database
insert_protocol_data(session, "cuhk"+suffix, "world", "train", world_files, photo2sketch=photo2sketch)
insert_protocol_data(session, "cuhk"+suffix, "dev", "", dev_files, photo2sketch=photo2sketch)
insert_protocol_data(session, "cuhk"+suffix, "eval", "", eval_files, photo2sketch=photo2sketch)
insert_protocol_data(session, "cuhk"+suffix, "world", "train", world_clients, photo2sketch=photo2sketch)
insert_protocol_data(session, "cuhk"+suffix, "dev", "", dev_clients, photo2sketch=photo2sketch)
insert_protocol_data(session, "cuhk"+suffix, "eval", "", eval_clients, photo2sketch=photo2sketch)
session.commit()
......@@ -273,23 +272,23 @@ def add_protocols(session, verbose, photo2sketch=True):
if verbose>=1: print('Creating the protocol ALL mixed ...')
#getting the files
world_files = arface.get_files_from_group(group="world") +\
xm2vts.get_files_from_group(group="world") +\
cuhk.get_files_from_group(group="world")
world_clients = arface.get_clients_from_group(group="world") +\
xm2vts.get_clients_from_group(group="world") +\
cuhk.get_clients_from_group(group="world")
dev_files = arface.get_files_from_group(group="dev") +\
xm2vts.get_files_from_group(group="dev") +\
cuhk.get_files_from_group(group="dev")
dev_clients = arface.get_clients_from_group(group="dev") +\
xm2vts.get_clients_from_group(group="dev") +\
cuhk.get_clients_from_group(group="dev")
eval_files = arface.get_files_from_group(group="eval") +\
xm2vts.get_files_from_group(group="eval") +\
cuhk.get_files_from_group(group="eval")
eval_clients = arface.get_clients_from_group(group="eval") +\
xm2vts.get_clients_from_group(group="eval") +\
cuhk.get_clients_from_group(group="eval")
#Inserting in the database
insert_protocol_data(session, "all-mixed"+suffix, "world", "train", world_files, photo2sketch=photo2sketch)
insert_protocol_data(session, "all-mixed"+suffix, "dev", "", dev_files, photo2sketch=photo2sketch)
insert_protocol_data(session, "all-mixed"+suffix, "eval", "", eval_files, photo2sketch=photo2sketch)
insert_protocol_data(session, "all-mixed"+suffix, "world", "train", world_clients, photo2sketch=photo2sketch)
insert_protocol_data(session, "all-mixed"+suffix, "dev", "", dev_clients, photo2sketch=photo2sketch)
insert_protocol_data(session, "all-mixed"+suffix, "eval", "", eval_clients, photo2sketch=photo2sketch)
session.commit()
......@@ -299,15 +298,15 @@ def add_protocols(session, verbose, photo2sketch=True):
if verbose>=1: print('Creating the protocol cuhk-arface-xm2vts ...')
#getting the files
world_files = cuhk.get_files_from_group(group="world")
dev_files = arface.get_files_from_group(group="dev")
eval_files = xm2vts.get_files_from_group(group="eval")
world_clients = cuhk.get_clients_from_group(group="world")
dev_clients = arface.get_clients_from_group(group="dev")
eval_clients = xm2vts.get_clients_from_group(group="eval")
#Inserting in the database
insert_protocol_data(session, "cuhk-arface-xm2vts"+suffix, "world", "train", world_files, photo2sketch=photo2sketch)
insert_protocol_data(session, "cuhk-arface-xm2vts"+suffix, "dev", "", dev_files, photo2sketch=photo2sketch)
insert_protocol_data(session, "cuhk-arface-xm2vts"+suffix, "eval", "", eval_files, photo2sketch=photo2sketch)
insert_protocol_data(session, "cuhk-arface-xm2vts"+suffix, "world", "train", world_clients, photo2sketch=photo2sketch)
insert_protocol_data(session, "cuhk-arface-xm2vts"+suffix, "dev", "", dev_clients, photo2sketch=photo2sketch)
insert_protocol_data(session, "cuhk-arface-xm2vts"+suffix, "eval", "", eval_clients, photo2sketch=photo2sketch)
session.commit()
......@@ -317,15 +316,15 @@ def add_protocols(session, verbose, photo2sketch=True):
if verbose>=1: print('Creating the protocol cuhk-xm2vts-arface ...')
#getting the files
world_files = cuhk.get_files_from_group(group="world")
dev_files = xm2vts.get_files_from_group(group="dev")
eval_files = arface.get_files_from_group(group="eval")
world_clients = cuhk.get_clients_from_group(group="world")
dev_clients = xm2vts.get_clients_from_group(group="dev")
eval_clients = arface.get_clients_from_group(group="eval")
#Inserting in the database
insert_protocol_data(session, "cuhk-xm2vts-arface"+suffix, "world", "train", world_files, photo2sketch=photo2sketch)
insert_protocol_data(session, "cuhk-xm2vts-arface"+suffix, "dev", "", dev_files, photo2sketch=photo2sketch)
insert_protocol_data(session, "cuhk-xm2vts-arface"+suffix, "eval", "", eval_files, photo2sketch=photo2sketch)
insert_protocol_data(session, "cuhk-xm2vts-arface"+suffix, "world", "train", world_clients, photo2sketch=photo2sketch)
insert_protocol_data(session, "cuhk-xm2vts-arface"+suffix, "dev", "", dev_clients, photo2sketch=photo2sketch)
insert_protocol_data(session, "cuhk-xm2vts-arface"+suffix, "eval", "", eval_clients, photo2sketch=photo2sketch)
session.commit()
......@@ -335,15 +334,15 @@ def add_protocols(session, verbose, photo2sketch=True):
if verbose>=1: print('Creating the protocol arface-cuhk-xm2vts ...')
#getting the files
world_files = arface.get_files_from_group(group="world")
dev_files = cuhk.get_files_from_group(group="dev")
eval_files = xm2vts.get_files_from_group(group="eval")
world_clients = arface.get_clients_from_group(group="world")
dev_clients = cuhk.get_clients_from_group(group="dev")
eval_clients = xm2vts.get_clients_from_group(group="eval")
#Inserting in the database
insert_protocol_data(session, "arface-cuhk-xm2vts"+suffix, "world", "train", world_files, photo2sketch=photo2sketch)
insert_protocol_data(session, "arface-cuhk-xm2vts"+suffix, "dev", "", dev_files, photo2sketch=photo2sketch)
insert_protocol_data(session, "arface-cuhk-xm2vts"+suffix, "eval", "", eval_files, photo2sketch=photo2sketch)
insert_protocol_data(session, "arface-cuhk-xm2vts"+suffix, "world", "train", world_clients, photo2sketch=photo2sketch)
insert_protocol_data(session, "arface-cuhk-xm2vts"+suffix, "dev", "", dev_clients, photo2sketch=photo2sketch)
insert_protocol_data(session, "arface-cuhk-xm2vts"+suffix, "eval", "", eval_clients, photo2sketch=photo2sketch)
session.commit()
......@@ -353,15 +352,15 @@ def add_protocols(session, verbose, photo2sketch=True):
if verbose>=1: print('Creating the protocol arface-xm2vts-cuhk ...')
#getting the files
world_files = arface.get_files_from_group(group="world")
dev_files = xm2vts.get_files_from_group(group="dev")
eval_files = cuhk.get_files_from_group(group="eval")
world_clients = arface.get_clients_from_group(group="world")
dev_clients = xm2vts.get_clients_from_group(group="dev")
eval_clients = cuhk.get_clients_from_group(group="eval")
#Inserting in the database
insert_protocol_data(session, "arface-xm2vts-cuhk"+suffix, "world", "train", world_files, photo2sketch=photo2sketch)
insert_protocol_data(session, "arface-xm2vts-cuhk"+suffix, "dev", "", dev_files, photo2sketch=photo2sketch)
insert_protocol_data(session, "arface-xm2vts-cuhk"+suffix, "eval", "", eval_files, photo2sketch=photo2sketch)
insert_protocol_data(session, "arface-xm2vts-cuhk"+suffix, "world", "train", world_clients, photo2sketch=photo2sketch)
insert_protocol_data(session, "arface-xm2vts-cuhk"+suffix, "dev", "", dev_clients, photo2sketch=photo2sketch)
insert_protocol_data(session, "arface-xm2vts-cuhk"+suffix, "eval", "", eval_clients, photo2sketch=photo2sketch)
session.commit()
......@@ -371,15 +370,15 @@ def add_protocols(session, verbose, photo2sketch=True):
if verbose>=1: print('Creating the protocol xm2vts-cuhk-arface ...')
#getting the files
world_files = xm2vts.get_files_from_group(group="world")
dev_files = cuhk.get_files_from_group(group="dev")
eval_files = arface.get_files_from_group(group="eval")
world_clients = xm2vts.get_clients_from_group(group="world")
dev_clients = cuhk.get_clients_from_group(group="dev")
eval_clients = arface.get_clients_from_group(group="eval")
#Inserting in the database
insert_protocol_data(session, "xm2vts-cuhk-arface"+suffix, "world", "train", world_files, photo2sketch=photo2sketch)
insert_protocol_data(session, "xm2vts-cuhk-arface"+suffix, "dev", "", dev_files, photo2sketch=photo2sketch)
insert_protocol_data(session, "xm2vts-cuhk-arface"+suffix, "eval", "", eval_files, photo2sketch=photo2sketch)
insert_protocol_data(session, "xm2vts-cuhk-arface"+suffix, "world", "train", world_clients, photo2sketch=photo2sketch)
insert_protocol_data(session, "xm2vts-cuhk-arface"+suffix, "dev", "", dev_clients, photo2sketch=photo2sketch)
insert_protocol_data(session, "xm2vts-cuhk-arface"+suffix, "eval", "", eval_clients, photo2sketch=photo2sketch)
session.commit()
......@@ -390,41 +389,100 @@ def add_protocols(session, verbose, photo2sketch=True):
if verbose>=1: print('Creating the protocol xm2vts-arface-cuhk ...')
#getting the files
world_files = xm2vts.get_files_from_group(group="world")
dev_files = arface.get_files_from_group(group="dev")
eval_files = cuhk.get_files_from_group(group="eval")
world_clients = xm2vts.get_clients_from_group(group="world")
dev_clients = arface.get_clients_from_group(group="dev")
eval_clients = cuhk.get_clients_from_group(group="eval")
#Inserting in the database
insert_protocol_data(session, "xm2vts-arface-cuhk"+suffix, "world", "train", world_files, photo2sketch=photo2sketch)
insert_protocol_data(session, "xm2vts-arface-cuhk"+suffix, "dev", "", dev_files, photo2sketch=photo2sketch)
insert_protocol_data(session, "xm2vts-arface-cuhk"+suffix, "eval", "", eval_files, photo2sketch=photo2sketch)
insert_protocol_data(session, "xm2vts-arface-cuhk"+suffix, "world", "train", world_clients, photo2sketch=photo2sketch)
insert_protocol_data(session, "xm2vts-arface-cuhk"+suffix, "dev", "", dev_clients, photo2sketch=photo2sketch)
insert_protocol_data(session, "xm2vts-arface-cuhk"+suffix, "eval", "", eval_clients, photo2sketch=photo2sketch)
session.commit()
def insert_protocol_data(session, protocol, group, purpose, file_objects, photo2sketch=True):
for f in file_objects:
if purpose!="train":
if photo2sketch:
def insert_protocol_data(session, protocol, group, purpose, clients, photo2sketch=True):
if f.modality=="photo":
purpose = "enroll"
for c in clients:
if purpose=="train":
#Adding files for training
for f in c.files:
session.add(bob.db.cuhk_cufs.Protocol_File_Association(
protocol, group, purpose, f.id, c.id))
else:
#PRobing
for c1 in clients:
#With same clients, you need to define which one is enroll or probe
if c1.id==c.id:
for f in c1.files:
if ((photo2sketch and f.modality=="photo") or
(not photo2sketch and f.modality=="sketch")):
purpose="enroll"
else:
purpose="probe"
session.add(bob.db.cuhk_cufs.Protocol_File_Association(
protocol, group, purpose, f.id, c.id))
else:
purpose = "probe"
#With different clients the task is to define which file should be included
else:
if f.modality=="photo":
purpose = "probe"
else:
purpose = "enroll"
for f in c1.files:
if ((photo2sketch and f.modality=="photo") or
(not photo2sketch and f.modality=="sketch")):
continue #Excluding files
else:
purpose="probe"
session.add(bob.db.cuhk_cufs.Protocol_File_Association(
protocol, group, purpose, f.id, c.id))
"""
for f in file_objects:
if purpose=="train":
session.add(bob.db.cuhk_cufs.Protocol_File_Association(
protocol, group, purpose, f.id, f.id))
else:
#Excluding some files
if ((not photo2sketch and f.modality=="photo") or
(photo2sketch and f.modality=="sketch") )
continue
#probing
for f1 in file_objects:
#Same client
if f1.client_id == f.client_id:
if photo2sketch:
if f.modality=="photo":
purpose = "enroll"
else:
purpose = "probe"
else:
if f.modality=="photo":
purpose = "probe"
else:
purpose = "enroll"
else:#Different client
if ((photo2sketch and f1.modality=="sketch") or
(not photo2sketch and f1.modality=="photo")):
purpose = "probe"
else:
continue
session.add(bob.db.cuhk_cufs.Protocol_File_Association(
protocol, group, purpose, f.id))
session.add(bob.db.cuhk_cufs.Protocol_File_Association(
protocol, group, purpose, f.id, f1.id))
"""
def create_tables(args):
......
No preview for this file type
......@@ -58,13 +58,16 @@ class Protocol_File_Association(Base):
protocol = Column('protocol', Enum(*PROTOCOLS), primary_key=True)
group = Column('group', Enum(*GROUPS), primary_key=True)
purpose = Column('purpose', Enum(*PURPOSES), primary_key=True)
file_id = Column('file_id', Integer, ForeignKey('file.id'), primary_key=True)
def __init__(self, protocol, group, purpose, file_id):
self.protocol = protocol
self.group = group
self.purpose = purpose
self.file_id = file_id
file_id = Column('file_id', Integer, ForeignKey('file.id'), primary_key=True)
client_id = Column('client_id', Integer, ForeignKey('client.id'), primary_key=True)
def __init__(self, protocol, group, purpose, file_id, client_id):
self.protocol = protocol
self.group = group
self.purpose = purpose
self.file_id = file_id
self.client_id = client_id
......
......@@ -55,16 +55,16 @@ def test01_protocols_purposes_groups():
def test02_all_files_protocols():
cuhk = 376
arface = 246
xm2vts = 590
all_mixed = 1212
cuhk_arface_xm2vts = 408
cuhk_xm2vts_arface = 404
arface_cuhk_xm2vts = 378
arface_xm2vts_cuhk = 378
xm2vts_cuhk_arface = 426
xm2vts_arface_cuhk = 430
cuhk = 6648
arface = 3288
xm2vts = 16078
all_mixed = 68924
cuhk_arface_xm2vts = 9800
cuhk_xm2vts_arface = 9542
arface_cuhk_xm2vts = 11290
arface_xm2vts_cuhk = 11226
xm2vts_cuhk_arface = 4988
xm2vts_arface_cuhk = 5182
assert len(bob.db.cuhk_cufs.Database().objects(protocol="cuhk_p2s")) == cuhk
assert len(bob.db.cuhk_cufs.Database().objects(protocol="cuhk_s2p")) == cuhk
......@@ -93,6 +93,9 @@ def test02_all_files_protocols():
assert len(bob.db.cuhk_cufs.Database().objects(protocol="xm2vts-cuhk-arface_p2s")) == xm2vts_cuhk_arface
assert len(bob.db.cuhk_cufs.Database().objects(protocol="xm2vts-cuhk-arface_s2p")) == xm2vts_cuhk_arface
assert len(bob.db.cuhk_cufs.Database().objects(protocol="xm2vts-arface-cuhk_p2s")) == xm2vts_arface_cuhk
assert len(bob.db.cuhk_cufs.Database().objects(protocol="xm2vts-arface-cuhk_s2p")) == xm2vts_arface_cuhk
def test03_world_files_protocols():
......@@ -138,10 +141,10 @@ def test03_world_files_protocols():
def test04_dev_files_protocols():
cuhk = 112
arface = 80
xm2vts = 176
all_mixed = 368
cuhk = 3192
arface = 1640
xm2vts = 7832
all_mixed = 34040
cuhk_arface_xm2vts = arface
cuhk_xm2vts_arface = xm2vts
arface_cuhk_xm2vts = cuhk
......@@ -180,10 +183,10 @@ def test04_dev_files_protocols():
def test05_eval_files_protocols():
cuhk = 114
arface = 78
xm2vts = 178
all_mixed = 370
cuhk = 3306
arface = 1560
xm2vts = 8010
all_mixed = 34410
cuhk_arface_xm2vts = xm2vts
cuhk_xm2vts_arface = arface
arface_cuhk_xm2vts = xm2vts
......
......@@ -66,7 +66,7 @@ class ARFACEWrapper():
return 'man' if client_id[0]=='m' else 'woman'
def get_files_from_group(self, group=""):
def get_clients_from_group(self, group=""):
"""
Get the bob.db.cuhk_cufs.File for a given group (world, dev or eval).
......@@ -78,18 +78,19 @@ class ARFACEWrapper():
"""
arface = bob.db.arface.Database()
cuhk = bob.db.cuhk_cufs.Database()
import sqlalchemy
#Getting the clients from ARFACE
clients = arface.query(bob.db.arface.Client).filter(bob.db.arface.Client.sgroup==group)
original_clients = arface.query(bob.db.arface.Client).filter(bob.db.arface.Client.sgroup==group)
#Getting the correspondent files from bob.db.cuhk_cufs
files = []
for c in clients:
cuhk_files = cuhk.query(bob.db.cuhk_cufs.File).join(bob.db.cuhk_cufs.Client).filter(bob.db.cuhk_cufs.Client.original_id==c.id)
for f in cuhk_files:
files.append(f)
clients = []
for o in original_clients:
cuhk_clients = cuhk.query(bob.db.cuhk_cufs.Client).filter(bob.db.cuhk_cufs.Client.original_id==o.id).options(sqlalchemy.orm.subqueryload(bob.db.cuhk_cufs.Client.files)) #forcing to bring the clients
for c in cuhk_clients:
clients.append(c)
return files
return list(clients)
def get_annotations(self, annotation_dir, annotation_extension='.dat'):
......@@ -212,7 +213,7 @@ class XM2VTSWrapper():
def get_gender(self):
return 'none'
def get_files_from_group(self, group=""):
def get_clients_from_group(self, group=""):
"""
This is a hand made protocol since the XM2VTS database is biased.
......@@ -225,8 +226,9 @@ class XM2VTSWrapper():
indexes = [273, 241, 285, 256, 173, 193, 107, 55, 53, 143, 163, 63, 13, 113, 258, 271, 134, 17, 20, 227, 203, 96, 66, 112, 77, 237, 42, 61, 272, 161, 209, 206, 195, 140, 150, 294, 152, 136, 188, 232, 21, 75, 141, 25, 249, 269, 70, 217, 251, 29, 153, 83, 185, 94, 116, 265, 177, 38, 156, 191, 118, 121, 204, 100, 255, 286, 78, 260, 282, 33, 242, 200, 91, 224, 137, 180, 65, 12, 3, 151, 154, 1, 290, 198, 167, 212, 72, 133, 144, 57, 0, 211, 48, 292, 213, 277, 52, 223, 115, 230, 49, 4, 291, 214, 18, 71, 146, 289, 250, 268, 201, 170, 11, 178, 2, 155, 264, 64, 287, 14, 110, 30, 19, 149, 68, 183, 44, 60, 181, 283, 86, 139, 81, 126, 202, 120, 10, 9, 164, 218, 43, 148, 105, 186, 225, 93, 184, 50, 257, 132, 254, 27, 108, 106, 69, 252, 138, 122, 196, 175, 228, 7, 168, 135, 15, 231, 182, 280, 147, 54, 261, 79, 281, 125, 142, 101, 259, 41, 187, 16, 275, 248, 179, 169, 89, 245, 26, 73, 199, 90, 128, 236, 40, 166, 262, 84, 32, 97, 92, 174, 284, 37, 36, 111, 82, 104, 58, 98, 235, 215, 220, 130, 85, 216, 205, 274, 22, 244, 129, 247, 6, 240, 279, 5, 109, 31, 74, 127, 95, 117, 210, 165, 80, 59, 114, 194, 238, 207, 239, 267, 159, 243, 131, 171, 67, 222, 8, 47, 45, 99, 123, 229, 293, 270, 253, 46, 162, 263, 102, 76, 88, 28, 158, 278, 62, 246, 176, 124, 234, 276, 87, 24, 157, 119, 197, 190, 35, 34, 160, 56, 266, 172, 39, 233, 221, 192, 288, 23, 226, 219, 189, 208, 145, 103, 51]
#Fetching the clients
import sqlalchemy
cuhk = bob.db.cuhk_cufs.Database()
all_clients = numpy.array(cuhk.query(bob.db.cuhk_cufs.Client).filter(bob.db.cuhk_cufs.Client.original_database=="xm2vts").order_by(bob.db.cuhk_cufs.Client.original_id).all())
all_clients = numpy.array(cuhk.query(bob.db.cuhk_cufs.Client).filter(bob.db.cuhk_cufs.Client.original_database=="xm2vts").order_by(bob.db.cuhk_cufs.Client.original_id).options(sqlalchemy.orm.subqueryload(bob.db.cuhk_cufs.Client.files)).all()) #forcing to bring the clients.
data_training = 118
data_dev = 88
......@@ -245,13 +247,13 @@ class XM2VTSWrapper():
#Fetching the correspondent files from bob.db.cuhk_cufs
files = []
for c in clients:
cuhk_files = cuhk.query(bob.db.cuhk_cufs.File).join(bob.db.cuhk_cufs.Client).filter(bob.db.cuhk_cufs.Client.id==c.id)
for f in cuhk_files:
files.append(f)
#files = []
#for c in clients:
#cuhk_files = cuhk.query(bob.db.cuhk_cufs.File).join(bob.db.cuhk_cufs.Client).filter(bob.db.cuhk_cufs.Client.id==c.id)
#for f in cuhk_files:
#files.append(f)
return files
return list(clients)
def get_files_from_group_biased(self, group=""):
......@@ -459,7 +461,7 @@ class CUHKWrapper():
return annotations
def get_files_from_group(self, group=""):
def get_clients_from_group(self, group=""):
"""
This is a hand made protocol since there is no protocol for the CUHK-CUFS database.
......@@ -467,14 +469,14 @@ class CUHKWrapper():
- 40% for training --> 75
- 30% for developement --> 56
- 30% for testing --> 57
"""
"""
indexes = [152, 70, 150, 120, 181, 64, 16, 66, 154, 1, 84, 35, 179, 105, 49, 159, 128, 14, 103, 157, 18, 148, 88, 134, 147, 72, 62, 110, 20, 27, 30, 187, 50, 117, 83, 71, 81, 61, 185, 85, 2, 145, 138, 45, 129, 151, 96, 132, 146, 87, 156, 173, 73, 38, 125, 69, 82, 34, 116, 102, 136, 91, 7, 143, 109, 112, 115, 63, 33, 165, 104, 170, 76, 36, 114, 5, 142, 90, 60, 40, 93, 67, 180, 77, 106, 130, 135, 124, 118, 6, 39, 97, 121, 4, 74, 86, 57, 24, 65, 167, 184, 163, 47, 169, 94, 8, 58, 126, 166, 15, 172, 11, 89, 162, 42, 98, 22, 133, 78, 175, 0, 160, 92, 37, 161, 17, 26, 122, 137, 164, 99, 149, 32, 95, 144, 46, 155, 168, 48, 182, 23, 80, 10, 140, 9, 55, 29, 113, 12, 54, 158, 52, 41, 119, 183, 25, 131, 107, 176, 31, 111, 108, 123, 79, 153, 178, 139, 51, 13, 177, 141, 171, 101, 3, 43, 68, 56, 21, 75, 28, 53, 44, 19, 174, 100, 127, 186, 59]
#Fetching the clients
cuhk = bob.db.cuhk_cufs.Database()
all_clients = numpy.array(cuhk.query(bob.db.cuhk_cufs.Client).filter(bob.db.cuhk_cufs.Client.original_database=="cuhk").order_by(bob.db.cuhk_cufs.Client.id).all())
import sqlalchemy
all_clients = numpy.array(cuhk.query(bob.db.cuhk_cufs.Client).filter(bob.db.cuhk_cufs.Client.original_database=="cuhk").order_by(bob.db.cuhk_cufs.Client.id).options(sqlalchemy.orm.subqueryload(bob.db.cuhk_cufs.Client.files)).all())
data_training = 75
data_dev = 56
......@@ -492,13 +494,13 @@ class CUHKWrapper():
clients = all_clients[indexes[offset:offset+data_eval]]
#Fetching the correspondent files from bob.db.cuhk_cufs
files = []
for c in clients:
cuhk_files = cuhk.query(bob.db.cuhk_cufs.File).join(bob.db.cuhk_cufs.Client).filter(bob.db.cuhk_cufs.Client.id==c.id)
for f in cuhk_files:
files.append(f)
#files = []
#for c in clients:
#cuhk_files = cuhk.query(bob.db.cuhk_cufs.File).join(bob.db.cuhk_cufs.Client).filter(bob.db.cuhk_cufs.Client.id==c.id)
#for f in cuhk_files:
#files.append(f)
return files
return list(clients)
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment