Commit c48db237 authored by Tiago de Freitas Pereira's avatar Tiago de Freitas Pereira
Browse files

Finished unit tests

parent ed9a02b1
......@@ -208,15 +208,6 @@ def add_protocols(session, verbose, photo2sketch=True):
"""
PROTOCOLS = ('cuhk_p2s', 'arface_p2s', 'xm2vts_p2s', 'all-mixed_p2s', 'cuhk-arface-xm2vts_p2s', 'cuhk-xm2vts-arface_p2s',
'arface-cuhk-xm2vts_p2s', 'arface-xm2vts-cuhk_p2s', 'xm2vts-cuhk-arface_p2s', 'xm2vts-arface-cuhk_p2s',
'cuhk_s2p', 'arface_s2p', 'xm2vts_s2p', 'all-mixed_s2p', 'cuhk-arface-xm2vts_s2p', 'cuhk-xm2vts-arface_s2p',
'arface-cuhk-xm2vts_s2p', 'arface-xm2vts-cuhk_s2p', 'xm2vts-cuhk-arface_s2p', 'xm2vts-arface-cuhk_s2p')
GROUPS = ('world', 'dev', 'eval')
PURPOSES = ('train', 'enrol', 'probe')
arface = ARFACEWrapper()
xm2vts = XM2VTSWrapper()
cuhk = CUHKWrapper()
......@@ -417,11 +408,20 @@ def insert_protocol_data(session, protocol, group, purpose, file_objects, photo2
for f in file_objects:
if purpose!="train":
if photo2sketch and f.modality=="photo":
purpose = "enrol"
if photo2sketch:
if f.modality=="photo":
purpose = "enroll"
else:
purpose = "probe"
else:
purpose = "probe"
if f.modality=="photo":
purpose = "probe"
else:
purpose = "enroll"
session.add(bob.db.cuhk.Protocol_File_Association(
protocol, group, purpose, f.id))
......
......@@ -46,7 +46,7 @@ PROTOCOLS = ('cuhk_p2s', 'arface_p2s', 'xm2vts_p2s', 'all-mixed_p2s', 'cuhk-arfa
GROUPS = ('world', 'dev', 'eval')
PURPOSES = ('train', 'enrol', 'probe')
PURPOSES = ('train', 'enroll', 'probe')
class Protocol_File_Association(Base):
......@@ -94,6 +94,36 @@ class Client(Base):
return "<Client({0}, {1}, {2})>".format(self.id, self.original_database, self.original_id)
class Annotation(Base):
"""
The CUHK-CUFS provides 35 coordinates.
To model this coordinates this table was built.
The columns are the following:
- Annotation.id
- x
- y
"""
__tablename__ = 'annotation'
file_id = Column(Integer, ForeignKey('file.id'), primary_key=True)
x = Column(Integer, primary_key=True)
y = Column(Integer, primary_key=True)
index = Column(Integer)
def __init__(self, file_id, x,y, index=0):
self.file_id = file_id
self.x = x
self.y = y
self.index = index
def __repr__(self):
return "<Annotation(file_id:{0}, index:{1}, y={2}, x={3})>".format(self.file_id, self.index, self.y, self.x)
class File(Base, bob.db.verification.utils.File):
"""
Information about the files of the CUHK-CUFS database.
......@@ -112,7 +142,7 @@ class File(Base, bob.db.verification.utils.File):
# a back-reference from the client class to a list of files
client = relationship("Client", backref=backref("files", order_by=id))
all_annotations = relationship("Annotation", backref=backref("file"), uselist=True)
all_annotations = relationship("Annotation", backref=backref("file"), uselist=True, order_by=Annotation.index)
def __init__(self, id, image_name, client_id, modality):
# call base class constructor
......@@ -133,26 +163,5 @@ class File(Base, bob.db.verification.utils.File):
class Annotation(Base):
"""
The CUHK-CUFS provides 35 coordinates.
To model this coordinates this table was built.
The columns are the following:
- Annotation.id
- x
- y
"""
__tablename__ = 'annotation'
file_id = Column(Integer, ForeignKey('file.id'), primary_key=True)
x = Column(Integer, primary_key=True)
y = Column(Integer, primary_key=True)
def __init__(self, file_id, x,y):
self.file_id = file_id
self.x = x
self.y = y
......@@ -40,6 +40,13 @@ class Database(bob.db.verification.utils.SQLiteDatabase, bob.db.verification.uti
bob.db.verification.utils.SQLiteDatabase.__init__(self, SQLITE_FILE, File)
bob.db.verification.utils.ZTDatabase.__init__(self, original_directory=original_directory, original_extension=original_extension)
def protocols(self):
return PROTOCOLS
def purposes(self):
return PURPOSES
def objects(self, groups = None, protocol = None, purposes = None, model_ids = None, **kwargs):
"""
......@@ -57,7 +64,7 @@ class Database(bob.db.verification.utils.SQLiteDatabase, bob.db.verification.uti
raise ValueError("Please, select only one of the following protocols {0}".format(protocols))
#Querying
query = self.query(bob.db.cuhk.File).join(bob.db.cuhk.Protocol_File_Association).join(bob.db.cuhk.Client)
query = self.query(bob.db.cuhk.File, bob.db.cuhk.Protocol_File_Association).join(bob.db.cuhk.Protocol_File_Association).join(bob.db.cuhk.Client)
#filtering
query = query.filter(bob.db.cuhk.Protocol_File_Association.group.in_(groups))
......@@ -70,7 +77,15 @@ class Database(bob.db.verification.utils.SQLiteDatabase, bob.db.verification.uti
query = query.filter(bob.db.cuhk.Client.id.in_(model_ids))
return query.all()
raw_files = query.all()
files = []
for f in raw_files:
f[0].group = f[1].group
f[0].purpose = f[1].purpose
f[0].protocol = f[1].protocol
files.append(f[0])
return files
def model_ids(self, protocol=None, groups=None):
......
......@@ -21,14 +21,320 @@
"""
import bob.db.cuhk
possible_protocols = ["cuhk"]
#possible_protocols = ["cuhk"]
""" Defining protocols. Yes, they are static """
PROTOCOLS = ('cuhk_p2s', 'arface_p2s', 'xm2vts_p2s', 'all-mixed_p2s', 'cuhk-arface-xm2vts_p2s', 'cuhk-xm2vts-arface_p2s',
'arface-cuhk-xm2vts_p2s', 'arface-xm2vts-cuhk_p2s', 'xm2vts-cuhk-arface_p2s', 'xm2vts-arface-cuhk_p2s',
'cuhk_s2p', 'arface_s2p', 'xm2vts_s2p', 'all-mixed_s2p', 'cuhk-arface-xm2vts_s2p', 'cuhk-xm2vts-arface_s2p',
'arface-cuhk-xm2vts_s2p', 'arface-xm2vts-cuhk_s2p', 'xm2vts-cuhk-arface_s2p', 'xm2vts-arface-cuhk_s2p')
GROUPS = ('world', 'dev', 'eval')
PURPOSES = ('train', 'enroll', 'probe')
def test_protocols():
import os
db = bob.db.cuhk.Database()
available_protocols = os.listdir(db.get_base_directory())
def test01_protocols_purposes_groups():
#testing protocols
possible_protocols = bob.db.cuhk.Database().protocols()
for p in possible_protocols:
assert p in available_protocols
assert p in PROTOCOLS
#testing purposes
possible_purposes = bob.db.cuhk.Database().purposes()
for p in possible_purposes:
assert p in PURPOSES
#testing GROUPS
possible_groups = bob.db.cuhk.Database().groups()
for p in possible_groups:
assert p in GROUPS
def test02_all_files_protocols():
cuhk = 376
arface = 246
xm2vts = 590
all_mixed = 1212
cuhk_arface_xm2vts = 408
cuhk_xm2vts_arface = 404
arface_cuhk_xm2vts = 378
arface_xm2vts_cuhk = 378
xm2vts_cuhk_arface = 426
xm2vts_arface_cuhk = 430
assert len(bob.db.cuhk.Database().objects(protocol="cuhk_p2s")) == cuhk
assert len(bob.db.cuhk.Database().objects(protocol="cuhk_s2p")) == cuhk
assert len(bob.db.cuhk.Database().objects(protocol="arface_p2s")) == arface
assert len(bob.db.cuhk.Database().objects(protocol="arface_s2p")) == arface
assert len(bob.db.cuhk.Database().objects(protocol="xm2vts_p2s")) == xm2vts
assert len(bob.db.cuhk.Database().objects(protocol="xm2vts_s2p")) == xm2vts
assert len(bob.db.cuhk.Database().objects(protocol="all-mixed_p2s")) == all_mixed
assert len(bob.db.cuhk.Database().objects(protocol="all-mixed_s2p")) == all_mixed
assert len(bob.db.cuhk.Database().objects(protocol="cuhk-arface-xm2vts_p2s")) == cuhk_arface_xm2vts
assert len(bob.db.cuhk.Database().objects(protocol="cuhk-arface-xm2vts_s2p")) == cuhk_arface_xm2vts
assert len(bob.db.cuhk.Database().objects(protocol="cuhk-xm2vts-arface_p2s")) == cuhk_xm2vts_arface
assert len(bob.db.cuhk.Database().objects(protocol="cuhk-xm2vts-arface_s2p")) == cuhk_xm2vts_arface
assert len(bob.db.cuhk.Database().objects(protocol="arface-xm2vts-cuhk_p2s")) == arface_xm2vts_cuhk
assert len(bob.db.cuhk.Database().objects(protocol="arface-xm2vts-cuhk_s2p")) == arface_xm2vts_cuhk
assert len(bob.db.cuhk.Database().objects(protocol="arface-cuhk-xm2vts_p2s")) == arface_cuhk_xm2vts
assert len(bob.db.cuhk.Database().objects(protocol="arface-cuhk-xm2vts_s2p")) == arface_cuhk_xm2vts
assert len(bob.db.cuhk.Database().objects(protocol="xm2vts-cuhk-arface_p2s")) == xm2vts_cuhk_arface
assert len(bob.db.cuhk.Database().objects(protocol="xm2vts-cuhk-arface_s2p")) == xm2vts_cuhk_arface
def test03_world_files_protocols():
cuhk = 150
arface = 88
xm2vts = 236
all_mixed = 474
cuhk_arface_xm2vts = cuhk
cuhk_xm2vts_arface = cuhk
arface_cuhk_xm2vts = arface
arface_xm2vts_cuhk = arface
xm2vts_cuhk_arface = xm2vts
xm2vts_arface_cuhk = xm2vts
assert len(bob.db.cuhk.Database().objects(groups='world', protocol="cuhk_p2s")) == cuhk
assert len(bob.db.cuhk.Database().objects(groups='world', protocol="cuhk_s2p")) == cuhk
assert len(bob.db.cuhk.Database().objects(groups='world', protocol="arface_p2s")) == arface
assert len(bob.db.cuhk.Database().objects(groups='world', protocol="arface_s2p")) == arface
assert len(bob.db.cuhk.Database().objects(groups='world', protocol="xm2vts_p2s")) == xm2vts
assert len(bob.db.cuhk.Database().objects(groups='world', protocol="xm2vts_s2p")) == xm2vts
assert len(bob.db.cuhk.Database().objects(groups='world', protocol="all-mixed_p2s")) == all_mixed
assert len(bob.db.cuhk.Database().objects(groups='world', protocol="all-mixed_s2p")) == all_mixed
assert len(bob.db.cuhk.Database().objects(groups='world', protocol="cuhk-arface-xm2vts_p2s")) == cuhk_arface_xm2vts
assert len(bob.db.cuhk.Database().objects(groups='world', protocol="cuhk-arface-xm2vts_s2p")) == cuhk_arface_xm2vts
assert len(bob.db.cuhk.Database().objects(groups='world', protocol="cuhk-xm2vts-arface_p2s")) == cuhk_xm2vts_arface
assert len(bob.db.cuhk.Database().objects(groups='world', protocol="cuhk-xm2vts-arface_s2p")) == cuhk_xm2vts_arface
assert len(bob.db.cuhk.Database().objects(groups='world', protocol="arface-xm2vts-cuhk_p2s")) == arface_xm2vts_cuhk
assert len(bob.db.cuhk.Database().objects(groups='world', protocol="arface-xm2vts-cuhk_s2p")) == arface_xm2vts_cuhk
assert len(bob.db.cuhk.Database().objects(groups='world', protocol="arface-cuhk-xm2vts_p2s")) == arface_cuhk_xm2vts
assert len(bob.db.cuhk.Database().objects(groups='world', protocol="arface-cuhk-xm2vts_s2p")) == arface_cuhk_xm2vts
assert len(bob.db.cuhk.Database().objects(groups='world', protocol="xm2vts-cuhk-arface_p2s")) == xm2vts_cuhk_arface
assert len(bob.db.cuhk.Database().objects(groups='world', protocol="xm2vts-cuhk-arface_s2p")) == xm2vts_cuhk_arface
def test04_dev_files_protocols():
cuhk = 112
arface = 80
xm2vts = 176
all_mixed = 368
cuhk_arface_xm2vts = arface
cuhk_xm2vts_arface = xm2vts
arface_cuhk_xm2vts = cuhk
arface_xm2vts_cuhk = xm2vts
xm2vts_cuhk_arface = cuhk
xm2vts_arface_cuhk = arface
assert len(bob.db.cuhk.Database().objects(groups='dev', protocol="cuhk_p2s")) == cuhk
assert len(bob.db.cuhk.Database().objects(groups='dev', protocol="cuhk_s2p")) == cuhk
assert len(bob.db.cuhk.Database().objects(groups='dev', protocol="arface_p2s")) == arface
assert len(bob.db.cuhk.Database().objects(groups='dev', protocol="arface_s2p")) == arface
assert len(bob.db.cuhk.Database().objects(groups='dev', protocol="xm2vts_p2s")) == xm2vts
assert len(bob.db.cuhk.Database().objects(groups='dev', protocol="xm2vts_s2p")) == xm2vts
assert len(bob.db.cuhk.Database().objects(groups='dev', protocol="all-mixed_p2s")) == all_mixed
assert len(bob.db.cuhk.Database().objects(groups='dev', protocol="all-mixed_s2p")) == all_mixed
assert len(bob.db.cuhk.Database().objects(groups='dev', protocol="cuhk-arface-xm2vts_p2s")) == cuhk_arface_xm2vts
assert len(bob.db.cuhk.Database().objects(groups='dev', protocol="cuhk-arface-xm2vts_s2p")) == cuhk_arface_xm2vts
assert len(bob.db.cuhk.Database().objects(groups='dev', protocol="cuhk-xm2vts-arface_p2s")) == cuhk_xm2vts_arface
assert len(bob.db.cuhk.Database().objects(groups='dev', protocol="cuhk-xm2vts-arface_s2p")) == cuhk_xm2vts_arface
assert len(bob.db.cuhk.Database().objects(groups='dev', protocol="arface-xm2vts-cuhk_p2s")) == arface_xm2vts_cuhk
assert len(bob.db.cuhk.Database().objects(groups='dev', protocol="arface-xm2vts-cuhk_s2p")) == arface_xm2vts_cuhk
assert len(bob.db.cuhk.Database().objects(groups='dev', protocol="arface-cuhk-xm2vts_p2s")) == arface_cuhk_xm2vts
assert len(bob.db.cuhk.Database().objects(groups='dev', protocol="arface-cuhk-xm2vts_s2p")) == arface_cuhk_xm2vts
assert len(bob.db.cuhk.Database().objects(groups='dev', protocol="xm2vts-cuhk-arface_p2s")) == xm2vts_cuhk_arface
assert len(bob.db.cuhk.Database().objects(groups='dev', protocol="xm2vts-cuhk-arface_s2p")) == xm2vts_cuhk_arface
def test05_eval_files_protocols():
cuhk = 114
arface = 78
xm2vts = 178
all_mixed = 370
cuhk_arface_xm2vts = xm2vts
cuhk_xm2vts_arface = arface
arface_cuhk_xm2vts = xm2vts
arface_xm2vts_cuhk = cuhk
xm2vts_cuhk_arface = arface
xm2vts_arface_cuhk = cuhk
assert len(bob.db.cuhk.Database().objects(groups='eval', protocol="cuhk_p2s")) == cuhk
assert len(bob.db.cuhk.Database().objects(groups='eval', protocol="cuhk_s2p")) == cuhk
assert len(bob.db.cuhk.Database().objects(groups='eval', protocol="arface_p2s")) == arface
assert len(bob.db.cuhk.Database().objects(groups='eval', protocol="arface_s2p")) == arface
assert len(bob.db.cuhk.Database().objects(groups='eval', protocol="xm2vts_p2s")) == xm2vts
assert len(bob.db.cuhk.Database().objects(groups='eval', protocol="xm2vts_s2p")) == xm2vts
assert len(bob.db.cuhk.Database().objects(groups='eval', protocol="all-mixed_p2s")) == all_mixed
assert len(bob.db.cuhk.Database().objects(groups='eval', protocol="all-mixed_s2p")) == all_mixed
assert len(bob.db.cuhk.Database().objects(groups='eval', protocol="cuhk-arface-xm2vts_p2s")) == cuhk_arface_xm2vts
assert len(bob.db.cuhk.Database().objects(groups='eval', protocol="cuhk-arface-xm2vts_s2p")) == cuhk_arface_xm2vts
assert len(bob.db.cuhk.Database().objects(groups='eval', protocol="cuhk-xm2vts-arface_p2s")) == cuhk_xm2vts_arface
assert len(bob.db.cuhk.Database().objects(groups='eval', protocol="cuhk-xm2vts-arface_s2p")) == cuhk_xm2vts_arface
assert len(bob.db.cuhk.Database().objects(groups='eval', protocol="arface-xm2vts-cuhk_p2s")) == arface_xm2vts_cuhk
assert len(bob.db.cuhk.Database().objects(groups='eval', protocol="arface-xm2vts-cuhk_s2p")) == arface_xm2vts_cuhk
assert len(bob.db.cuhk.Database().objects(groups='eval', protocol="arface-cuhk-xm2vts_p2s")) == arface_cuhk_xm2vts
assert len(bob.db.cuhk.Database().objects(groups='eval', protocol="arface-cuhk-xm2vts_s2p")) == arface_cuhk_xm2vts
assert len(bob.db.cuhk.Database().objects(groups='eval', protocol="xm2vts-cuhk-arface_p2s")) == xm2vts_cuhk_arface
assert len(bob.db.cuhk.Database().objects(groups='eval', protocol="xm2vts-cuhk-arface_s2p")) == xm2vts_cuhk_arface
def test06_dev_enrol_files_protocols():
cuhk = 56
arface = 40
xm2vts = 88
all_mixed = 184
cuhk_arface_xm2vts = arface
cuhk_xm2vts_arface = xm2vts
arface_cuhk_xm2vts = cuhk
arface_xm2vts_cuhk = xm2vts
xm2vts_cuhk_arface = cuhk
xm2vts_arface_cuhk = arface
assert len(bob.db.cuhk.Database().objects(purposes='enroll', groups='dev', protocol="cuhk_p2s")) == len(bob.db.cuhk.Database().enroll_files(protocol="cuhk_p2s"))== cuhk
assert len(bob.db.cuhk.Database().objects(purposes='enroll', groups='dev', protocol="cuhk_s2p")) == len(bob.db.cuhk.Database().enroll_files(protocol="cuhk_s2p")) == cuhk
assert len(bob.db.cuhk.Database().objects(purposes='enroll', groups='dev', protocol="arface_p2s")) == len(bob.db.cuhk.Database().enroll_files(protocol="arface_p2s")) == arface
assert len(bob.db.cuhk.Database().objects(purposes='enroll', groups='dev', protocol="arface_s2p")) == len(bob.db.cuhk.Database().enroll_files(protocol="arface_s2p")) == arface
assert len(bob.db.cuhk.Database().objects(purposes='enroll', groups='dev', protocol="xm2vts_p2s")) == len(bob.db.cuhk.Database().enroll_files(protocol="xm2vts_p2s")) == xm2vts
assert len(bob.db.cuhk.Database().objects(purposes='enroll', groups='dev', protocol="xm2vts_s2p")) == len(bob.db.cuhk.Database().enroll_files(protocol="xm2vts_s2p")) == xm2vts
assert len(bob.db.cuhk.Database().objects(purposes='enroll', groups='dev', protocol="all-mixed_p2s")) == len(bob.db.cuhk.Database().enroll_files(protocol="all-mixed_p2s")) == all_mixed
assert len(bob.db.cuhk.Database().objects(purposes='enroll', groups='dev', protocol="all-mixed_s2p")) == len(bob.db.cuhk.Database().enroll_files(protocol="all-mixed_s2p")) == all_mixed
assert len(bob.db.cuhk.Database().objects(purposes='enroll', groups='dev', protocol="cuhk-arface-xm2vts_p2s")) == len(bob.db.cuhk.Database().enroll_files(protocol="cuhk-arface-xm2vts_p2s")) == cuhk_arface_xm2vts
assert len(bob.db.cuhk.Database().objects(purposes='enroll', groups='dev', protocol="cuhk-arface-xm2vts_s2p")) == len(bob.db.cuhk.Database().enroll_files(protocol="cuhk-arface-xm2vts_s2p")) == cuhk_arface_xm2vts
assert len(bob.db.cuhk.Database().objects(purposes='enroll', groups='dev', protocol="cuhk-xm2vts-arface_p2s")) == len(bob.db.cuhk.Database().enroll_files(protocol="cuhk-xm2vts-arface_p2s")) == cuhk_xm2vts_arface
assert len(bob.db.cuhk.Database().objects(purposes='enroll', groups='dev', protocol="cuhk-xm2vts-arface_s2p")) == len(bob.db.cuhk.Database().enroll_files(protocol="cuhk-xm2vts-arface_s2p")) == cuhk_xm2vts_arface
assert len(bob.db.cuhk.Database().objects(purposes='enroll', groups='dev', protocol="arface-xm2vts-cuhk_p2s")) == len(bob.db.cuhk.Database().enroll_files(protocol="arface-xm2vts-cuhk_p2s")) == arface_xm2vts_cuhk
assert len(bob.db.cuhk.Database().objects(purposes='enroll', groups='dev', protocol="arface-xm2vts-cuhk_s2p")) == len(bob.db.cuhk.Database().enroll_files(protocol="arface-xm2vts-cuhk_s2p")) == arface_xm2vts_cuhk
assert len(bob.db.cuhk.Database().objects(purposes='enroll', groups='dev', protocol="arface-cuhk-xm2vts_p2s")) == len(bob.db.cuhk.Database().enroll_files(protocol="arface-cuhk-xm2vts_p2s")) == arface_cuhk_xm2vts
assert len(bob.db.cuhk.Database().objects(purposes='enroll', groups='dev', protocol="arface-cuhk-xm2vts_s2p")) == len(bob.db.cuhk.Database().enroll_files(protocol="arface-cuhk-xm2vts_s2p")) == arface_cuhk_xm2vts
assert len(bob.db.cuhk.Database().objects(purposes='enroll', groups='dev', protocol="xm2vts-cuhk-arface_p2s")) == len(bob.db.cuhk.Database().enroll_files(protocol="xm2vts-cuhk-arface_p2s")) == xm2vts_cuhk_arface
assert len(bob.db.cuhk.Database().objects(purposes='enroll', groups='dev', protocol="xm2vts-cuhk-arface_s2p")) == len(bob.db.cuhk.Database().enroll_files(protocol="xm2vts-cuhk-arface_s2p")) == xm2vts_cuhk_arface
def test07_eval_enrol_files_protocols():
cuhk = 57
arface = 39
xm2vts = 89
all_mixed = 185
cuhk_arface_xm2vts = xm2vts
cuhk_xm2vts_arface = arface
arface_cuhk_xm2vts = xm2vts
arface_xm2vts_cuhk = cuhk
xm2vts_cuhk_arface = arface
xm2vts_arface_cuhk = cuhk
assert len(bob.db.cuhk.Database().objects(purposes='enroll', groups='eval', protocol="cuhk_p2s")) == len(bob.db.cuhk.Database().enroll_files(protocol="cuhk_p2s", groups='eval'))== cuhk
assert len(bob.db.cuhk.Database().objects(purposes='enroll', groups='eval', protocol="cuhk_s2p")) == len(bob.db.cuhk.Database().enroll_files(protocol="cuhk_s2p", groups='eval')) == cuhk
assert len(bob.db.cuhk.Database().objects(purposes='enroll', groups='eval', protocol="arface_p2s")) == len(bob.db.cuhk.Database().enroll_files(protocol="arface_p2s", groups='eval')) == arface
assert len(bob.db.cuhk.Database().objects(purposes='enroll', groups='eval', protocol="arface_s2p")) == len(bob.db.cuhk.Database().enroll_files(protocol="arface_s2p", groups='eval')) == arface
assert len(bob.db.cuhk.Database().objects(purposes='enroll', groups='eval', protocol="xm2vts_p2s")) == len(bob.db.cuhk.Database().enroll_files(protocol="xm2vts_p2s", groups='eval')) == xm2vts
assert len(bob.db.cuhk.Database().objects(purposes='enroll', groups='eval', protocol="xm2vts_s2p")) == len(bob.db.cuhk.Database().enroll_files(protocol="xm2vts_s2p", groups='eval')) == xm2vts
assert len(bob.db.cuhk.Database().objects(purposes='enroll', groups='eval', protocol="all-mixed_p2s")) == len(bob.db.cuhk.Database().enroll_files(protocol="all-mixed_p2s", groups='eval')) == all_mixed
assert len(bob.db.cuhk.Database().objects(purposes='enroll', groups='eval', protocol="all-mixed_s2p")) == len(bob.db.cuhk.Database().enroll_files(protocol="all-mixed_s2p", groups='eval')) == all_mixed
assert len(bob.db.cuhk.Database().objects(purposes='enroll', groups='eval', protocol="cuhk-arface-xm2vts_p2s")) == len(bob.db.cuhk.Database().enroll_files(protocol="cuhk-arface-xm2vts_p2s", groups='eval')) == cuhk_arface_xm2vts
assert len(bob.db.cuhk.Database().objects(purposes='enroll', groups='eval', protocol="cuhk-arface-xm2vts_s2p")) == len(bob.db.cuhk.Database().enroll_files(protocol="cuhk-arface-xm2vts_s2p", groups='eval')) == cuhk_arface_xm2vts
assert len(bob.db.cuhk.Database().objects(purposes='enroll', groups='eval', protocol="cuhk-xm2vts-arface_p2s")) == len(bob.db.cuhk.Database().enroll_files(protocol="cuhk-xm2vts-arface_p2s", groups='eval')) == cuhk_xm2vts_arface
assert len(bob.db.cuhk.Database().objects(purposes='enroll', groups='eval', protocol="cuhk-xm2vts-arface_s2p")) == len(bob.db.cuhk.Database().enroll_files(protocol="cuhk-xm2vts-arface_s2p", groups='eval')) == cuhk_xm2vts_arface
assert len(bob.db.cuhk.Database().objects(purposes='enroll', groups='eval', protocol="arface-xm2vts-cuhk_p2s")) == len(bob.db.cuhk.Database().enroll_files(protocol="arface-xm2vts-cuhk_p2s", groups='eval')) == arface_xm2vts_cuhk
assert len(bob.db.cuhk.Database().objects(purposes='enroll', groups='eval', protocol="arface-xm2vts-cuhk_s2p")) == len(bob.db.cuhk.Database().enroll_files(protocol="arface-xm2vts-cuhk_s2p", groups='eval')) == arface_xm2vts_cuhk
assert len(bob.db.cuhk.Database().objects(purposes='enroll', groups='eval', protocol="arface-cuhk-xm2vts_p2s")) == len(bob.db.cuhk.Database().enroll_files(protocol="arface-cuhk-xm2vts_p2s", groups='eval')) == arface_cuhk_xm2vts
assert len(bob.db.cuhk.Database().objects(purposes='enroll', groups='eval', protocol="arface-cuhk-xm2vts_s2p")) == len(bob.db.cuhk.Database().enroll_files(protocol="arface-cuhk-xm2vts_s2p", groups='eval')) == arface_cuhk_xm2vts
assert len(bob.db.cuhk.Database().objects(purposes='enroll', groups='eval', protocol="xm2vts-cuhk-arface_p2s")) == len(bob.db.cuhk.Database().enroll_files(protocol="xm2vts-cuhk-arface_p2s", groups='eval')) == xm2vts_cuhk_arface
assert len(bob.db.cuhk.Database().objects(purposes='enroll', groups='eval', protocol="xm2vts-cuhk-arface_s2p")) == len(bob.db.cuhk.Database().enroll_files(protocol="xm2vts-cuhk-arface_s2p", groups='eval')) == xm2vts_cuhk_arface
def test08_strings():
db = bob.db.cuhk.Database()
for p in PROTOCOLS:
for g in GROUPS:
for u in PURPOSES:
files = db.objects(purposes=u, groups=g, protocol=p)
for f in files:
#Checking if the strings are correct
assert f.purpose == u
assert f.protocol == p
assert f.group == g
def test09_annotations():
db = bob.db.cuhk.Database()
for p in PROTOCOLS:
for f in db.objects(protocol=p):
assert len(f.annotations(annotation_type=""))==35 #ALL ANNOTATIONS
assert f.annotations()["reye"][0] > 0
assert f.annotations()["reye"][1] > 0
assert f.annotations()["leye"][0] > 0
assert f.annotations()["leye"][1] > 0
......@@ -109,12 +109,15 @@ class ARFACEWrapper():
#Reading the annotation file
original_annotations = read_annotations(path)
index = 0
for a in original_annotations:
annotations.append(bob.db.cuhk.Annotation(o.id,
a[0],
a[1]
a[1],
index = index
))
index += 1
return annotations
......@@ -292,7 +295,7 @@ class XM2VTSWrapper():
files = []
for c in clients:
cuhk_files = cuhk.query(bob.db.cuhk.File).join(bob.db.cuhk.Client).filter(bob.db.cuhk.Client.original_id==c.id)
print "{0} = {1}".format(c.id, cuhk_files.count())
#print "{0} = {1}".format(c.id, cuhk_files.count())
for f in cuhk_files:
files.append(f)
......@@ -314,18 +317,20 @@ class XM2VTSWrapper():
if(o.modality=="sketch"):
path = os.path.join(annotation_dir, o.path) + annotation_extension
else:
#import ipdb; ipdb.set_trace();
file_name = o.path.split("/")[2] #THE ORIGINAL XM2VTS RELATIVE PATH IS: XXX\XXX\XXX
path = os.path.join(annotation_dir,"xm2vts", "photo", file_name) + "_f02" + annotation_extension #FOR SOME REASON THE AUTHORS SET THIS '_f02 IN THE END OF THE FILE'
#Reading the annotation file
original_annotations = read_annotations(path)
index = 0
for a in original_annotations:
annotations.append(bob.db.cuhk.Annotation(o.id,
a[0],
a[1]
a[1],
index = index
))
index += 1
return annotations
......@@ -442,11 +447,15 @@ class CUHKWrapper():
#Reading the annotation file
original_annotations = read_annotations(path)
index = 0
for a in original_annotations:
annotations.append(bob.db.cuhk.Annotation(o.id,
a[0],
a[1]
a[1],
index = index
))
index += 1
return annotations
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment