Fixed several bugs

parent be1360b9
This diff is collapsed.
No preview for this file type
......@@ -58,13 +58,16 @@ class Protocol_File_Association(Base):
protocol = Column('protocol', Enum(*PROTOCOLS), primary_key=True)
group = Column('group', Enum(*GROUPS), primary_key=True)
purpose = Column('purpose', Enum(*PURPOSES), primary_key=True)
file_id = Column('file_id', Integer, ForeignKey('file.id'), primary_key=True)
def __init__(self, protocol, group, purpose, file_id):
self.protocol = protocol
self.group = group
self.purpose = purpose
self.file_id = file_id
file_id = Column('file_id', Integer, ForeignKey('file.id'), primary_key=True)
client_id = Column('client_id', Integer, ForeignKey('client.id'), primary_key=True)
def __init__(self, protocol, group, purpose, file_id, client_id):
self.protocol = protocol
self.group = group
self.purpose = purpose
self.file_id = file_id
self.client_id = client_id
......
......@@ -55,16 +55,16 @@ def test01_protocols_purposes_groups():
def test02_all_files_protocols():
cuhk = 376
arface = 246
xm2vts = 590
all_mixed = 1212
cuhk_arface_xm2vts = 408
cuhk_xm2vts_arface = 404
arface_cuhk_xm2vts = 378
arface_xm2vts_cuhk = 378
xm2vts_cuhk_arface = 426
xm2vts_arface_cuhk = 430
cuhk = 6648
arface = 3288
xm2vts = 16078
all_mixed = 68924
cuhk_arface_xm2vts = 9800
cuhk_xm2vts_arface = 9542
arface_cuhk_xm2vts = 11290
arface_xm2vts_cuhk = 11226
xm2vts_cuhk_arface = 4988
xm2vts_arface_cuhk = 5182
assert len(bob.db.cuhk_cufs.Database().objects(protocol="cuhk_p2s")) == cuhk
assert len(bob.db.cuhk_cufs.Database().objects(protocol="cuhk_s2p")) == cuhk
......@@ -93,6 +93,9 @@ def test02_all_files_protocols():
assert len(bob.db.cuhk_cufs.Database().objects(protocol="xm2vts-cuhk-arface_p2s")) == xm2vts_cuhk_arface
assert len(bob.db.cuhk_cufs.Database().objects(protocol="xm2vts-cuhk-arface_s2p")) == xm2vts_cuhk_arface
assert len(bob.db.cuhk_cufs.Database().objects(protocol="xm2vts-arface-cuhk_p2s")) == xm2vts_arface_cuhk
assert len(bob.db.cuhk_cufs.Database().objects(protocol="xm2vts-arface-cuhk_s2p")) == xm2vts_arface_cuhk
def test03_world_files_protocols():
......@@ -138,10 +141,10 @@ def test03_world_files_protocols():
def test04_dev_files_protocols():
cuhk = 112
arface = 80
xm2vts = 176
all_mixed = 368
cuhk = 3192
arface = 1640
xm2vts = 7832
all_mixed = 34040
cuhk_arface_xm2vts = arface
cuhk_xm2vts_arface = xm2vts
arface_cuhk_xm2vts = cuhk
......@@ -180,10 +183,10 @@ def test04_dev_files_protocols():
def test05_eval_files_protocols():
cuhk = 114
arface = 78
xm2vts = 178
all_mixed = 370
cuhk = 3306
arface = 1560
xm2vts = 8010
all_mixed = 34410
cuhk_arface_xm2vts = xm2vts
cuhk_xm2vts_arface = arface
arface_cuhk_xm2vts = xm2vts
......
......@@ -66,7 +66,7 @@ class ARFACEWrapper():
return 'man' if client_id[0]=='m' else 'woman'
def get_files_from_group(self, group=""):
def get_clients_from_group(self, group=""):
"""
Get the bob.db.cuhk_cufs.File for a given group (world, dev or eval).
......@@ -78,18 +78,19 @@ class ARFACEWrapper():
"""
arface = bob.db.arface.Database()
cuhk = bob.db.cuhk_cufs.Database()
import sqlalchemy
#Getting the clients from ARFACE
clients = arface.query(bob.db.arface.Client).filter(bob.db.arface.Client.sgroup==group)
original_clients = arface.query(bob.db.arface.Client).filter(bob.db.arface.Client.sgroup==group)
#Getting the correspondent files from bob.db.cuhk_cufs
files = []
for c in clients:
cuhk_files = cuhk.query(bob.db.cuhk_cufs.File).join(bob.db.cuhk_cufs.Client).filter(bob.db.cuhk_cufs.Client.original_id==c.id)
for f in cuhk_files:
files.append(f)
clients = []
for o in original_clients:
cuhk_clients = cuhk.query(bob.db.cuhk_cufs.Client).filter(bob.db.cuhk_cufs.Client.original_id==o.id).options(sqlalchemy.orm.subqueryload(bob.db.cuhk_cufs.Client.files)) #forcing to bring the clients
for c in cuhk_clients:
clients.append(c)
return files
return list(clients)
def get_annotations(self, annotation_dir, annotation_extension='.dat'):
......@@ -212,7 +213,7 @@ class XM2VTSWrapper():
def get_gender(self):
return 'none'
def get_files_from_group(self, group=""):
def get_clients_from_group(self, group=""):
"""
This is a hand made protocol since the XM2VTS database is biased.
......@@ -225,8 +226,9 @@ class XM2VTSWrapper():
indexes = [273, 241, 285, 256, 173, 193, 107, 55, 53, 143, 163, 63, 13, 113, 258, 271, 134, 17, 20, 227, 203, 96, 66, 112, 77, 237, 42, 61, 272, 161, 209, 206, 195, 140, 150, 294, 152, 136, 188, 232, 21, 75, 141, 25, 249, 269, 70, 217, 251, 29, 153, 83, 185, 94, 116, 265, 177, 38, 156, 191, 118, 121, 204, 100, 255, 286, 78, 260, 282, 33, 242, 200, 91, 224, 137, 180, 65, 12, 3, 151, 154, 1, 290, 198, 167, 212, 72, 133, 144, 57, 0, 211, 48, 292, 213, 277, 52, 223, 115, 230, 49, 4, 291, 214, 18, 71, 146, 289, 250, 268, 201, 170, 11, 178, 2, 155, 264, 64, 287, 14, 110, 30, 19, 149, 68, 183, 44, 60, 181, 283, 86, 139, 81, 126, 202, 120, 10, 9, 164, 218, 43, 148, 105, 186, 225, 93, 184, 50, 257, 132, 254, 27, 108, 106, 69, 252, 138, 122, 196, 175, 228, 7, 168, 135, 15, 231, 182, 280, 147, 54, 261, 79, 281, 125, 142, 101, 259, 41, 187, 16, 275, 248, 179, 169, 89, 245, 26, 73, 199, 90, 128, 236, 40, 166, 262, 84, 32, 97, 92, 174, 284, 37, 36, 111, 82, 104, 58, 98, 235, 215, 220, 130, 85, 216, 205, 274, 22, 244, 129, 247, 6, 240, 279, 5, 109, 31, 74, 127, 95, 117, 210, 165, 80, 59, 114, 194, 238, 207, 239, 267, 159, 243, 131, 171, 67, 222, 8, 47, 45, 99, 123, 229, 293, 270, 253, 46, 162, 263, 102, 76, 88, 28, 158, 278, 62, 246, 176, 124, 234, 276, 87, 24, 157, 119, 197, 190, 35, 34, 160, 56, 266, 172, 39, 233, 221, 192, 288, 23, 226, 219, 189, 208, 145, 103, 51]
#Fetching the clients
import sqlalchemy
cuhk = bob.db.cuhk_cufs.Database()
all_clients = numpy.array(cuhk.query(bob.db.cuhk_cufs.Client).filter(bob.db.cuhk_cufs.Client.original_database=="xm2vts").order_by(bob.db.cuhk_cufs.Client.original_id).all())
all_clients = numpy.array(cuhk.query(bob.db.cuhk_cufs.Client).filter(bob.db.cuhk_cufs.Client.original_database=="xm2vts").order_by(bob.db.cuhk_cufs.Client.original_id).options(sqlalchemy.orm.subqueryload(bob.db.cuhk_cufs.Client.files)).all()) #forcing to bring the clients.
data_training = 118
data_dev = 88
......@@ -245,13 +247,13 @@ class XM2VTSWrapper():
#Fetching the correspondent files from bob.db.cuhk_cufs
files = []
for c in clients:
cuhk_files = cuhk.query(bob.db.cuhk_cufs.File).join(bob.db.cuhk_cufs.Client).filter(bob.db.cuhk_cufs.Client.id==c.id)
for f in cuhk_files:
files.append(f)
#files = []
#for c in clients:
#cuhk_files = cuhk.query(bob.db.cuhk_cufs.File).join(bob.db.cuhk_cufs.Client).filter(bob.db.cuhk_cufs.Client.id==c.id)
#for f in cuhk_files:
#files.append(f)
return files
return list(clients)
def get_files_from_group_biased(self, group=""):
......@@ -459,7 +461,7 @@ class CUHKWrapper():
return annotations
def get_files_from_group(self, group=""):
def get_clients_from_group(self, group=""):
"""
This is a hand made protocol since there is no protocol for the CUHK-CUFS database.
......@@ -467,14 +469,14 @@ class CUHKWrapper():
- 40% for training --> 75
- 30% for developement --> 56
- 30% for testing --> 57
"""
"""
indexes = [152, 70, 150, 120, 181, 64, 16, 66, 154, 1, 84, 35, 179, 105, 49, 159, 128, 14, 103, 157, 18, 148, 88, 134, 147, 72, 62, 110, 20, 27, 30, 187, 50, 117, 83, 71, 81, 61, 185, 85, 2, 145, 138, 45, 129, 151, 96, 132, 146, 87, 156, 173, 73, 38, 125, 69, 82, 34, 116, 102, 136, 91, 7, 143, 109, 112, 115, 63, 33, 165, 104, 170, 76, 36, 114, 5, 142, 90, 60, 40, 93, 67, 180, 77, 106, 130, 135, 124, 118, 6, 39, 97, 121, 4, 74, 86, 57, 24, 65, 167, 184, 163, 47, 169, 94, 8, 58, 126, 166, 15, 172, 11, 89, 162, 42, 98, 22, 133, 78, 175, 0, 160, 92, 37, 161, 17, 26, 122, 137, 164, 99, 149, 32, 95, 144, 46, 155, 168, 48, 182, 23, 80, 10, 140, 9, 55, 29, 113, 12, 54, 158, 52, 41, 119, 183, 25, 131, 107, 176, 31, 111, 108, 123, 79, 153, 178, 139, 51, 13, 177, 141, 171, 101, 3, 43, 68, 56, 21, 75, 28, 53, 44, 19, 174, 100, 127, 186, 59]
#Fetching the clients
cuhk = bob.db.cuhk_cufs.Database()
all_clients = numpy.array(cuhk.query(bob.db.cuhk_cufs.Client).filter(bob.db.cuhk_cufs.Client.original_database=="cuhk").order_by(bob.db.cuhk_cufs.Client.id).all())
import sqlalchemy
all_clients = numpy.array(cuhk.query(bob.db.cuhk_cufs.Client).filter(bob.db.cuhk_cufs.Client.original_database=="cuhk").order_by(bob.db.cuhk_cufs.Client.id).options(sqlalchemy.orm.subqueryload(bob.db.cuhk_cufs.Client.files)).all())
data_training = 75
data_dev = 56
......@@ -492,13 +494,13 @@ class CUHKWrapper():
clients = all_clients[indexes[offset:offset+data_eval]]
#Fetching the correspondent files from bob.db.cuhk_cufs
files = []
for c in clients:
cuhk_files = cuhk.query(bob.db.cuhk_cufs.File).join(bob.db.cuhk_cufs.Client).filter(bob.db.cuhk_cufs.Client.id==c.id)
for f in cuhk_files:
files.append(f)
#files = []
#for c in clients:
#cuhk_files = cuhk.query(bob.db.cuhk_cufs.File).join(bob.db.cuhk_cufs.Client).filter(bob.db.cuhk_cufs.Client.id==c.id)
#for f in cuhk_files:
#files.append(f)
return files
return list(clients)
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment