diff --git a/bob/db/cuhk/create.py b/bob/db/cuhk/create.py index 1f99c4fe0574d3913c6315beac394382b80ef15d..e8c0c06eac49e41f5c4419dee894c81be76b3e7c 100644 --- a/bob/db/cuhk/create.py +++ b/bob/db/cuhk/create.py @@ -208,15 +208,6 @@ def add_protocols(session, verbose, photo2sketch=True): """ - PROTOCOLS = ('cuhk_p2s', 'arface_p2s', 'xm2vts_p2s', 'all-mixed_p2s', 'cuhk-arface-xm2vts_p2s', 'cuhk-xm2vts-arface_p2s', - 'arface-cuhk-xm2vts_p2s', 'arface-xm2vts-cuhk_p2s', 'xm2vts-cuhk-arface_p2s', 'xm2vts-arface-cuhk_p2s', - 'cuhk_s2p', 'arface_s2p', 'xm2vts_s2p', 'all-mixed_s2p', 'cuhk-arface-xm2vts_s2p', 'cuhk-xm2vts-arface_s2p', - 'arface-cuhk-xm2vts_s2p', 'arface-xm2vts-cuhk_s2p', 'xm2vts-cuhk-arface_s2p', 'xm2vts-arface-cuhk_s2p') - - GROUPS = ('world', 'dev', 'eval') - - PURPOSES = ('train', 'enrol', 'probe') - arface = ARFACEWrapper() xm2vts = XM2VTSWrapper() cuhk = CUHKWrapper() @@ -417,11 +408,20 @@ def insert_protocol_data(session, protocol, group, purpose, file_objects, photo2 for f in file_objects: if purpose!="train": - if photo2sketch and f.modality=="photo": - purpose = "enrol" + if photo2sketch: + + if f.modality=="photo": + purpose = "enroll" + else: + purpose = "probe" + else: - purpose = "probe" - + if f.modality=="photo": + purpose = "probe" + else: + purpose = "enroll" + + session.add(bob.db.cuhk.Protocol_File_Association( protocol, group, purpose, f.id)) diff --git a/bob/db/cuhk/models.py b/bob/db/cuhk/models.py index 0cc0ad0b9e11360c3e3e87f2f88d2dc79cd3c1a8..e21bbcb6997048b07233a60ff19617dd7b016df1 100644 --- a/bob/db/cuhk/models.py +++ b/bob/db/cuhk/models.py @@ -46,7 +46,7 @@ PROTOCOLS = ('cuhk_p2s', 'arface_p2s', 'xm2vts_p2s', 'all-mixed_p2s', 'cuhk-arfa GROUPS = ('world', 'dev', 'eval') -PURPOSES = ('train', 'enrol', 'probe') +PURPOSES = ('train', 'enroll', 'probe') class Protocol_File_Association(Base): @@ -94,6 +94,36 @@ class Client(Base): return "<Client({0}, {1}, {2})>".format(self.id, self.original_database, self.original_id) +class Annotation(Base): + """ + The CUHK-CUFS provides 35 coordinates. + To model this coordinates this table was built. + The columns are the following: + + - Annotation.id + - x + - y + + """ + __tablename__ = 'annotation' + + file_id = Column(Integer, ForeignKey('file.id'), primary_key=True) + x = Column(Integer, primary_key=True) + y = Column(Integer, primary_key=True) + index = Column(Integer) + + def __init__(self, file_id, x,y, index=0): + self.file_id = file_id + self.x = x + self.y = y + self.index = index + + + def __repr__(self): + return "<Annotation(file_id:{0}, index:{1}, y={2}, x={3})>".format(self.file_id, self.index, self.y, self.x) + + + class File(Base, bob.db.verification.utils.File): """ Information about the files of the CUHK-CUFS database. @@ -112,7 +142,7 @@ class File(Base, bob.db.verification.utils.File): # a back-reference from the client class to a list of files client = relationship("Client", backref=backref("files", order_by=id)) - all_annotations = relationship("Annotation", backref=backref("file"), uselist=True) + all_annotations = relationship("Annotation", backref=backref("file"), uselist=True, order_by=Annotation.index) def __init__(self, id, image_name, client_id, modality): # call base class constructor @@ -133,26 +163,5 @@ class File(Base, bob.db.verification.utils.File): -class Annotation(Base): - """ - The CUHK-CUFS provides 35 coordinates. - To model this coordinates this table was built. - The columns are the following: - - - Annotation.id - - x - - y - - """ - __tablename__ = 'annotation' - - file_id = Column(Integer, ForeignKey('file.id'), primary_key=True) - x = Column(Integer, primary_key=True) - y = Column(Integer, primary_key=True) - - def __init__(self, file_id, x,y): - self.file_id = file_id - self.x = x - self.y = y diff --git a/bob/db/cuhk/query.py b/bob/db/cuhk/query.py index 6b2ac236666d4fa5af9fd4bd6d0d630f997b048f..f88a4f733bfdbd43b3493e4a715726f501b3336d 100644 --- a/bob/db/cuhk/query.py +++ b/bob/db/cuhk/query.py @@ -40,6 +40,13 @@ class Database(bob.db.verification.utils.SQLiteDatabase, bob.db.verification.uti bob.db.verification.utils.SQLiteDatabase.__init__(self, SQLITE_FILE, File) bob.db.verification.utils.ZTDatabase.__init__(self, original_directory=original_directory, original_extension=original_extension) + + def protocols(self): + return PROTOCOLS + + def purposes(self): + return PURPOSES + def objects(self, groups = None, protocol = None, purposes = None, model_ids = None, **kwargs): """ @@ -57,7 +64,7 @@ class Database(bob.db.verification.utils.SQLiteDatabase, bob.db.verification.uti raise ValueError("Please, select only one of the following protocols {0}".format(protocols)) #Querying - query = self.query(bob.db.cuhk.File).join(bob.db.cuhk.Protocol_File_Association).join(bob.db.cuhk.Client) + query = self.query(bob.db.cuhk.File, bob.db.cuhk.Protocol_File_Association).join(bob.db.cuhk.Protocol_File_Association).join(bob.db.cuhk.Client) #filtering query = query.filter(bob.db.cuhk.Protocol_File_Association.group.in_(groups)) @@ -70,7 +77,15 @@ class Database(bob.db.verification.utils.SQLiteDatabase, bob.db.verification.uti query = query.filter(bob.db.cuhk.Client.id.in_(model_ids)) - return query.all() + raw_files = query.all() + files = [] + for f in raw_files: + f[0].group = f[1].group + f[0].purpose = f[1].purpose + f[0].protocol = f[1].protocol + files.append(f[0]) + + return files def model_ids(self, protocol=None, groups=None): diff --git a/bob/db/cuhk/test.py b/bob/db/cuhk/test.py index 7f7d1eedb4671addf93c86980d84872aff23326c..1037a10d8d6a6b811b7a60b4eb87cb5eee96cf99 100644 --- a/bob/db/cuhk/test.py +++ b/bob/db/cuhk/test.py @@ -21,14 +21,320 @@ """ import bob.db.cuhk -possible_protocols = ["cuhk"] +#possible_protocols = ["cuhk"] +""" Defining protocols. Yes, they are static """ +PROTOCOLS = ('cuhk_p2s', 'arface_p2s', 'xm2vts_p2s', 'all-mixed_p2s', 'cuhk-arface-xm2vts_p2s', 'cuhk-xm2vts-arface_p2s', + 'arface-cuhk-xm2vts_p2s', 'arface-xm2vts-cuhk_p2s', 'xm2vts-cuhk-arface_p2s', 'xm2vts-arface-cuhk_p2s', + 'cuhk_s2p', 'arface_s2p', 'xm2vts_s2p', 'all-mixed_s2p', 'cuhk-arface-xm2vts_s2p', 'cuhk-xm2vts-arface_s2p', + 'arface-cuhk-xm2vts_s2p', 'arface-xm2vts-cuhk_s2p', 'xm2vts-cuhk-arface_s2p', 'xm2vts-arface-cuhk_s2p') + +GROUPS = ('world', 'dev', 'eval') + +PURPOSES = ('train', 'enroll', 'probe') -def test_protocols(): - import os - db = bob.db.cuhk.Database() - available_protocols = os.listdir(db.get_base_directory()) +def test01_protocols_purposes_groups(): + + #testing protocols + possible_protocols = bob.db.cuhk.Database().protocols() for p in possible_protocols: - assert p in available_protocols + assert p in PROTOCOLS + + #testing purposes + possible_purposes = bob.db.cuhk.Database().purposes() + for p in possible_purposes: + assert p in PURPOSES + + #testing GROUPS + possible_groups = bob.db.cuhk.Database().groups() + for p in possible_groups: + assert p in GROUPS + + +def test02_all_files_protocols(): + + cuhk = 376 + arface = 246 + xm2vts = 590 + all_mixed = 1212 + cuhk_arface_xm2vts = 408 + cuhk_xm2vts_arface = 404 + arface_cuhk_xm2vts = 378 + arface_xm2vts_cuhk = 378 + xm2vts_cuhk_arface = 426 + xm2vts_arface_cuhk = 430 + + assert len(bob.db.cuhk.Database().objects(protocol="cuhk_p2s")) == cuhk + assert len(bob.db.cuhk.Database().objects(protocol="cuhk_s2p")) == cuhk + + assert len(bob.db.cuhk.Database().objects(protocol="arface_p2s")) == arface + assert len(bob.db.cuhk.Database().objects(protocol="arface_s2p")) == arface + + assert len(bob.db.cuhk.Database().objects(protocol="xm2vts_p2s")) == xm2vts + assert len(bob.db.cuhk.Database().objects(protocol="xm2vts_s2p")) == xm2vts + + assert len(bob.db.cuhk.Database().objects(protocol="all-mixed_p2s")) == all_mixed + assert len(bob.db.cuhk.Database().objects(protocol="all-mixed_s2p")) == all_mixed + + assert len(bob.db.cuhk.Database().objects(protocol="cuhk-arface-xm2vts_p2s")) == cuhk_arface_xm2vts + assert len(bob.db.cuhk.Database().objects(protocol="cuhk-arface-xm2vts_s2p")) == cuhk_arface_xm2vts + + assert len(bob.db.cuhk.Database().objects(protocol="cuhk-xm2vts-arface_p2s")) == cuhk_xm2vts_arface + assert len(bob.db.cuhk.Database().objects(protocol="cuhk-xm2vts-arface_s2p")) == cuhk_xm2vts_arface + + assert len(bob.db.cuhk.Database().objects(protocol="arface-xm2vts-cuhk_p2s")) == arface_xm2vts_cuhk + assert len(bob.db.cuhk.Database().objects(protocol="arface-xm2vts-cuhk_s2p")) == arface_xm2vts_cuhk + + assert len(bob.db.cuhk.Database().objects(protocol="arface-cuhk-xm2vts_p2s")) == arface_cuhk_xm2vts + assert len(bob.db.cuhk.Database().objects(protocol="arface-cuhk-xm2vts_s2p")) == arface_cuhk_xm2vts + + assert len(bob.db.cuhk.Database().objects(protocol="xm2vts-cuhk-arface_p2s")) == xm2vts_cuhk_arface + assert len(bob.db.cuhk.Database().objects(protocol="xm2vts-cuhk-arface_s2p")) == xm2vts_cuhk_arface + + +def test03_world_files_protocols(): + + cuhk = 150 + arface = 88 + xm2vts = 236 + all_mixed = 474 + cuhk_arface_xm2vts = cuhk + cuhk_xm2vts_arface = cuhk + arface_cuhk_xm2vts = arface + arface_xm2vts_cuhk = arface + xm2vts_cuhk_arface = xm2vts + xm2vts_arface_cuhk = xm2vts + + assert len(bob.db.cuhk.Database().objects(groups='world', protocol="cuhk_p2s")) == cuhk + assert len(bob.db.cuhk.Database().objects(groups='world', protocol="cuhk_s2p")) == cuhk + + assert len(bob.db.cuhk.Database().objects(groups='world', protocol="arface_p2s")) == arface + assert len(bob.db.cuhk.Database().objects(groups='world', protocol="arface_s2p")) == arface + + assert len(bob.db.cuhk.Database().objects(groups='world', protocol="xm2vts_p2s")) == xm2vts + assert len(bob.db.cuhk.Database().objects(groups='world', protocol="xm2vts_s2p")) == xm2vts + + assert len(bob.db.cuhk.Database().objects(groups='world', protocol="all-mixed_p2s")) == all_mixed + assert len(bob.db.cuhk.Database().objects(groups='world', protocol="all-mixed_s2p")) == all_mixed + + assert len(bob.db.cuhk.Database().objects(groups='world', protocol="cuhk-arface-xm2vts_p2s")) == cuhk_arface_xm2vts + assert len(bob.db.cuhk.Database().objects(groups='world', protocol="cuhk-arface-xm2vts_s2p")) == cuhk_arface_xm2vts + + assert len(bob.db.cuhk.Database().objects(groups='world', protocol="cuhk-xm2vts-arface_p2s")) == cuhk_xm2vts_arface + assert len(bob.db.cuhk.Database().objects(groups='world', protocol="cuhk-xm2vts-arface_s2p")) == cuhk_xm2vts_arface + + assert len(bob.db.cuhk.Database().objects(groups='world', protocol="arface-xm2vts-cuhk_p2s")) == arface_xm2vts_cuhk + assert len(bob.db.cuhk.Database().objects(groups='world', protocol="arface-xm2vts-cuhk_s2p")) == arface_xm2vts_cuhk + + assert len(bob.db.cuhk.Database().objects(groups='world', protocol="arface-cuhk-xm2vts_p2s")) == arface_cuhk_xm2vts + assert len(bob.db.cuhk.Database().objects(groups='world', protocol="arface-cuhk-xm2vts_s2p")) == arface_cuhk_xm2vts + + assert len(bob.db.cuhk.Database().objects(groups='world', protocol="xm2vts-cuhk-arface_p2s")) == xm2vts_cuhk_arface + assert len(bob.db.cuhk.Database().objects(groups='world', protocol="xm2vts-cuhk-arface_s2p")) == xm2vts_cuhk_arface + + + +def test04_dev_files_protocols(): + + cuhk = 112 + arface = 80 + xm2vts = 176 + all_mixed = 368 + cuhk_arface_xm2vts = arface + cuhk_xm2vts_arface = xm2vts + arface_cuhk_xm2vts = cuhk + arface_xm2vts_cuhk = xm2vts + xm2vts_cuhk_arface = cuhk + xm2vts_arface_cuhk = arface + + assert len(bob.db.cuhk.Database().objects(groups='dev', protocol="cuhk_p2s")) == cuhk + assert len(bob.db.cuhk.Database().objects(groups='dev', protocol="cuhk_s2p")) == cuhk + + assert len(bob.db.cuhk.Database().objects(groups='dev', protocol="arface_p2s")) == arface + assert len(bob.db.cuhk.Database().objects(groups='dev', protocol="arface_s2p")) == arface + + assert len(bob.db.cuhk.Database().objects(groups='dev', protocol="xm2vts_p2s")) == xm2vts + assert len(bob.db.cuhk.Database().objects(groups='dev', protocol="xm2vts_s2p")) == xm2vts + + assert len(bob.db.cuhk.Database().objects(groups='dev', protocol="all-mixed_p2s")) == all_mixed + assert len(bob.db.cuhk.Database().objects(groups='dev', protocol="all-mixed_s2p")) == all_mixed + + assert len(bob.db.cuhk.Database().objects(groups='dev', protocol="cuhk-arface-xm2vts_p2s")) == cuhk_arface_xm2vts + assert len(bob.db.cuhk.Database().objects(groups='dev', protocol="cuhk-arface-xm2vts_s2p")) == cuhk_arface_xm2vts + + assert len(bob.db.cuhk.Database().objects(groups='dev', protocol="cuhk-xm2vts-arface_p2s")) == cuhk_xm2vts_arface + assert len(bob.db.cuhk.Database().objects(groups='dev', protocol="cuhk-xm2vts-arface_s2p")) == cuhk_xm2vts_arface + + assert len(bob.db.cuhk.Database().objects(groups='dev', protocol="arface-xm2vts-cuhk_p2s")) == arface_xm2vts_cuhk + assert len(bob.db.cuhk.Database().objects(groups='dev', protocol="arface-xm2vts-cuhk_s2p")) == arface_xm2vts_cuhk + + assert len(bob.db.cuhk.Database().objects(groups='dev', protocol="arface-cuhk-xm2vts_p2s")) == arface_cuhk_xm2vts + assert len(bob.db.cuhk.Database().objects(groups='dev', protocol="arface-cuhk-xm2vts_s2p")) == arface_cuhk_xm2vts + + assert len(bob.db.cuhk.Database().objects(groups='dev', protocol="xm2vts-cuhk-arface_p2s")) == xm2vts_cuhk_arface + assert len(bob.db.cuhk.Database().objects(groups='dev', protocol="xm2vts-cuhk-arface_s2p")) == xm2vts_cuhk_arface + + + +def test05_eval_files_protocols(): + + cuhk = 114 + arface = 78 + xm2vts = 178 + all_mixed = 370 + cuhk_arface_xm2vts = xm2vts + cuhk_xm2vts_arface = arface + arface_cuhk_xm2vts = xm2vts + arface_xm2vts_cuhk = cuhk + xm2vts_cuhk_arface = arface + xm2vts_arface_cuhk = cuhk + + assert len(bob.db.cuhk.Database().objects(groups='eval', protocol="cuhk_p2s")) == cuhk + assert len(bob.db.cuhk.Database().objects(groups='eval', protocol="cuhk_s2p")) == cuhk + + assert len(bob.db.cuhk.Database().objects(groups='eval', protocol="arface_p2s")) == arface + assert len(bob.db.cuhk.Database().objects(groups='eval', protocol="arface_s2p")) == arface + + assert len(bob.db.cuhk.Database().objects(groups='eval', protocol="xm2vts_p2s")) == xm2vts + assert len(bob.db.cuhk.Database().objects(groups='eval', protocol="xm2vts_s2p")) == xm2vts + + assert len(bob.db.cuhk.Database().objects(groups='eval', protocol="all-mixed_p2s")) == all_mixed + assert len(bob.db.cuhk.Database().objects(groups='eval', protocol="all-mixed_s2p")) == all_mixed + + assert len(bob.db.cuhk.Database().objects(groups='eval', protocol="cuhk-arface-xm2vts_p2s")) == cuhk_arface_xm2vts + assert len(bob.db.cuhk.Database().objects(groups='eval', protocol="cuhk-arface-xm2vts_s2p")) == cuhk_arface_xm2vts + + assert len(bob.db.cuhk.Database().objects(groups='eval', protocol="cuhk-xm2vts-arface_p2s")) == cuhk_xm2vts_arface + assert len(bob.db.cuhk.Database().objects(groups='eval', protocol="cuhk-xm2vts-arface_s2p")) == cuhk_xm2vts_arface + + assert len(bob.db.cuhk.Database().objects(groups='eval', protocol="arface-xm2vts-cuhk_p2s")) == arface_xm2vts_cuhk + assert len(bob.db.cuhk.Database().objects(groups='eval', protocol="arface-xm2vts-cuhk_s2p")) == arface_xm2vts_cuhk + + assert len(bob.db.cuhk.Database().objects(groups='eval', protocol="arface-cuhk-xm2vts_p2s")) == arface_cuhk_xm2vts + assert len(bob.db.cuhk.Database().objects(groups='eval', protocol="arface-cuhk-xm2vts_s2p")) == arface_cuhk_xm2vts + + assert len(bob.db.cuhk.Database().objects(groups='eval', protocol="xm2vts-cuhk-arface_p2s")) == xm2vts_cuhk_arface + assert len(bob.db.cuhk.Database().objects(groups='eval', protocol="xm2vts-cuhk-arface_s2p")) == xm2vts_cuhk_arface + + +def test06_dev_enrol_files_protocols(): + + cuhk = 56 + arface = 40 + xm2vts = 88 + all_mixed = 184 + cuhk_arface_xm2vts = arface + cuhk_xm2vts_arface = xm2vts + arface_cuhk_xm2vts = cuhk + arface_xm2vts_cuhk = xm2vts + xm2vts_cuhk_arface = cuhk + xm2vts_arface_cuhk = arface + + assert len(bob.db.cuhk.Database().objects(purposes='enroll', groups='dev', protocol="cuhk_p2s")) == len(bob.db.cuhk.Database().enroll_files(protocol="cuhk_p2s"))== cuhk + assert len(bob.db.cuhk.Database().objects(purposes='enroll', groups='dev', protocol="cuhk_s2p")) == len(bob.db.cuhk.Database().enroll_files(protocol="cuhk_s2p")) == cuhk + + assert len(bob.db.cuhk.Database().objects(purposes='enroll', groups='dev', protocol="arface_p2s")) == len(bob.db.cuhk.Database().enroll_files(protocol="arface_p2s")) == arface + assert len(bob.db.cuhk.Database().objects(purposes='enroll', groups='dev', protocol="arface_s2p")) == len(bob.db.cuhk.Database().enroll_files(protocol="arface_s2p")) == arface + + assert len(bob.db.cuhk.Database().objects(purposes='enroll', groups='dev', protocol="xm2vts_p2s")) == len(bob.db.cuhk.Database().enroll_files(protocol="xm2vts_p2s")) == xm2vts + assert len(bob.db.cuhk.Database().objects(purposes='enroll', groups='dev', protocol="xm2vts_s2p")) == len(bob.db.cuhk.Database().enroll_files(protocol="xm2vts_s2p")) == xm2vts + + assert len(bob.db.cuhk.Database().objects(purposes='enroll', groups='dev', protocol="all-mixed_p2s")) == len(bob.db.cuhk.Database().enroll_files(protocol="all-mixed_p2s")) == all_mixed + assert len(bob.db.cuhk.Database().objects(purposes='enroll', groups='dev', protocol="all-mixed_s2p")) == len(bob.db.cuhk.Database().enroll_files(protocol="all-mixed_s2p")) == all_mixed + + assert len(bob.db.cuhk.Database().objects(purposes='enroll', groups='dev', protocol="cuhk-arface-xm2vts_p2s")) == len(bob.db.cuhk.Database().enroll_files(protocol="cuhk-arface-xm2vts_p2s")) == cuhk_arface_xm2vts + assert len(bob.db.cuhk.Database().objects(purposes='enroll', groups='dev', protocol="cuhk-arface-xm2vts_s2p")) == len(bob.db.cuhk.Database().enroll_files(protocol="cuhk-arface-xm2vts_s2p")) == cuhk_arface_xm2vts + + assert len(bob.db.cuhk.Database().objects(purposes='enroll', groups='dev', protocol="cuhk-xm2vts-arface_p2s")) == len(bob.db.cuhk.Database().enroll_files(protocol="cuhk-xm2vts-arface_p2s")) == cuhk_xm2vts_arface + assert len(bob.db.cuhk.Database().objects(purposes='enroll', groups='dev', protocol="cuhk-xm2vts-arface_s2p")) == len(bob.db.cuhk.Database().enroll_files(protocol="cuhk-xm2vts-arface_s2p")) == cuhk_xm2vts_arface + + assert len(bob.db.cuhk.Database().objects(purposes='enroll', groups='dev', protocol="arface-xm2vts-cuhk_p2s")) == len(bob.db.cuhk.Database().enroll_files(protocol="arface-xm2vts-cuhk_p2s")) == arface_xm2vts_cuhk + assert len(bob.db.cuhk.Database().objects(purposes='enroll', groups='dev', protocol="arface-xm2vts-cuhk_s2p")) == len(bob.db.cuhk.Database().enroll_files(protocol="arface-xm2vts-cuhk_s2p")) == arface_xm2vts_cuhk + + assert len(bob.db.cuhk.Database().objects(purposes='enroll', groups='dev', protocol="arface-cuhk-xm2vts_p2s")) == len(bob.db.cuhk.Database().enroll_files(protocol="arface-cuhk-xm2vts_p2s")) == arface_cuhk_xm2vts + assert len(bob.db.cuhk.Database().objects(purposes='enroll', groups='dev', protocol="arface-cuhk-xm2vts_s2p")) == len(bob.db.cuhk.Database().enroll_files(protocol="arface-cuhk-xm2vts_s2p")) == arface_cuhk_xm2vts + + assert len(bob.db.cuhk.Database().objects(purposes='enroll', groups='dev', protocol="xm2vts-cuhk-arface_p2s")) == len(bob.db.cuhk.Database().enroll_files(protocol="xm2vts-cuhk-arface_p2s")) == xm2vts_cuhk_arface + assert len(bob.db.cuhk.Database().objects(purposes='enroll', groups='dev', protocol="xm2vts-cuhk-arface_s2p")) == len(bob.db.cuhk.Database().enroll_files(protocol="xm2vts-cuhk-arface_s2p")) == xm2vts_cuhk_arface + + + +def test07_eval_enrol_files_protocols(): + + cuhk = 57 + arface = 39 + xm2vts = 89 + all_mixed = 185 + cuhk_arface_xm2vts = xm2vts + cuhk_xm2vts_arface = arface + arface_cuhk_xm2vts = xm2vts + arface_xm2vts_cuhk = cuhk + xm2vts_cuhk_arface = arface + xm2vts_arface_cuhk = cuhk + + assert len(bob.db.cuhk.Database().objects(purposes='enroll', groups='eval', protocol="cuhk_p2s")) == len(bob.db.cuhk.Database().enroll_files(protocol="cuhk_p2s", groups='eval'))== cuhk + assert len(bob.db.cuhk.Database().objects(purposes='enroll', groups='eval', protocol="cuhk_s2p")) == len(bob.db.cuhk.Database().enroll_files(protocol="cuhk_s2p", groups='eval')) == cuhk + + assert len(bob.db.cuhk.Database().objects(purposes='enroll', groups='eval', protocol="arface_p2s")) == len(bob.db.cuhk.Database().enroll_files(protocol="arface_p2s", groups='eval')) == arface + assert len(bob.db.cuhk.Database().objects(purposes='enroll', groups='eval', protocol="arface_s2p")) == len(bob.db.cuhk.Database().enroll_files(protocol="arface_s2p", groups='eval')) == arface + + assert len(bob.db.cuhk.Database().objects(purposes='enroll', groups='eval', protocol="xm2vts_p2s")) == len(bob.db.cuhk.Database().enroll_files(protocol="xm2vts_p2s", groups='eval')) == xm2vts + assert len(bob.db.cuhk.Database().objects(purposes='enroll', groups='eval', protocol="xm2vts_s2p")) == len(bob.db.cuhk.Database().enroll_files(protocol="xm2vts_s2p", groups='eval')) == xm2vts + + assert len(bob.db.cuhk.Database().objects(purposes='enroll', groups='eval', protocol="all-mixed_p2s")) == len(bob.db.cuhk.Database().enroll_files(protocol="all-mixed_p2s", groups='eval')) == all_mixed + assert len(bob.db.cuhk.Database().objects(purposes='enroll', groups='eval', protocol="all-mixed_s2p")) == len(bob.db.cuhk.Database().enroll_files(protocol="all-mixed_s2p", groups='eval')) == all_mixed + + assert len(bob.db.cuhk.Database().objects(purposes='enroll', groups='eval', protocol="cuhk-arface-xm2vts_p2s")) == len(bob.db.cuhk.Database().enroll_files(protocol="cuhk-arface-xm2vts_p2s", groups='eval')) == cuhk_arface_xm2vts + assert len(bob.db.cuhk.Database().objects(purposes='enroll', groups='eval', protocol="cuhk-arface-xm2vts_s2p")) == len(bob.db.cuhk.Database().enroll_files(protocol="cuhk-arface-xm2vts_s2p", groups='eval')) == cuhk_arface_xm2vts + + assert len(bob.db.cuhk.Database().objects(purposes='enroll', groups='eval', protocol="cuhk-xm2vts-arface_p2s")) == len(bob.db.cuhk.Database().enroll_files(protocol="cuhk-xm2vts-arface_p2s", groups='eval')) == cuhk_xm2vts_arface + assert len(bob.db.cuhk.Database().objects(purposes='enroll', groups='eval', protocol="cuhk-xm2vts-arface_s2p")) == len(bob.db.cuhk.Database().enroll_files(protocol="cuhk-xm2vts-arface_s2p", groups='eval')) == cuhk_xm2vts_arface + + assert len(bob.db.cuhk.Database().objects(purposes='enroll', groups='eval', protocol="arface-xm2vts-cuhk_p2s")) == len(bob.db.cuhk.Database().enroll_files(protocol="arface-xm2vts-cuhk_p2s", groups='eval')) == arface_xm2vts_cuhk + assert len(bob.db.cuhk.Database().objects(purposes='enroll', groups='eval', protocol="arface-xm2vts-cuhk_s2p")) == len(bob.db.cuhk.Database().enroll_files(protocol="arface-xm2vts-cuhk_s2p", groups='eval')) == arface_xm2vts_cuhk + + assert len(bob.db.cuhk.Database().objects(purposes='enroll', groups='eval', protocol="arface-cuhk-xm2vts_p2s")) == len(bob.db.cuhk.Database().enroll_files(protocol="arface-cuhk-xm2vts_p2s", groups='eval')) == arface_cuhk_xm2vts + assert len(bob.db.cuhk.Database().objects(purposes='enroll', groups='eval', protocol="arface-cuhk-xm2vts_s2p")) == len(bob.db.cuhk.Database().enroll_files(protocol="arface-cuhk-xm2vts_s2p", groups='eval')) == arface_cuhk_xm2vts + + assert len(bob.db.cuhk.Database().objects(purposes='enroll', groups='eval', protocol="xm2vts-cuhk-arface_p2s")) == len(bob.db.cuhk.Database().enroll_files(protocol="xm2vts-cuhk-arface_p2s", groups='eval')) == xm2vts_cuhk_arface + assert len(bob.db.cuhk.Database().objects(purposes='enroll', groups='eval', protocol="xm2vts-cuhk-arface_s2p")) == len(bob.db.cuhk.Database().enroll_files(protocol="xm2vts-cuhk-arface_s2p", groups='eval')) == xm2vts_cuhk_arface + + +def test08_strings(): + + db = bob.db.cuhk.Database() + + for p in PROTOCOLS: + for g in GROUPS: + for u in PURPOSES: + files = db.objects(purposes=u, groups=g, protocol=p) + + for f in files: + #Checking if the strings are correct + assert f.purpose == u + assert f.protocol == p + assert f.group == g + + +def test09_annotations(): + + db = bob.db.cuhk.Database() + + for p in PROTOCOLS: + for f in db.objects(protocol=p): + + assert len(f.annotations(annotation_type=""))==35 #ALL ANNOTATIONS + + assert f.annotations()["reye"][0] > 0 + assert f.annotations()["reye"][1] > 0 + + assert f.annotations()["leye"][0] > 0 + assert f.annotations()["leye"][1] > 0 + + + + + + diff --git a/bob/db/cuhk/utils.py b/bob/db/cuhk/utils.py index 8144448b1b50f96f49703d59ede6c4d17986342f..f763a714f3eb167d4274ec39c9f6be08e2d436df 100644 --- a/bob/db/cuhk/utils.py +++ b/bob/db/cuhk/utils.py @@ -109,12 +109,15 @@ class ARFACEWrapper(): #Reading the annotation file original_annotations = read_annotations(path) + index = 0 for a in original_annotations: annotations.append(bob.db.cuhk.Annotation(o.id, a[0], - a[1] + a[1], + index = index )) + index += 1 return annotations @@ -292,7 +295,7 @@ class XM2VTSWrapper(): files = [] for c in clients: cuhk_files = cuhk.query(bob.db.cuhk.File).join(bob.db.cuhk.Client).filter(bob.db.cuhk.Client.original_id==c.id) - print "{0} = {1}".format(c.id, cuhk_files.count()) + #print "{0} = {1}".format(c.id, cuhk_files.count()) for f in cuhk_files: files.append(f) @@ -314,18 +317,20 @@ class XM2VTSWrapper(): if(o.modality=="sketch"): path = os.path.join(annotation_dir, o.path) + annotation_extension else: - #import ipdb; ipdb.set_trace(); file_name = o.path.split("/")[2] #THE ORIGINAL XM2VTS RELATIVE PATH IS: XXX\XXX\XXX path = os.path.join(annotation_dir,"xm2vts", "photo", file_name) + "_f02" + annotation_extension #FOR SOME REASON THE AUTHORS SET THIS '_f02 IN THE END OF THE FILE' #Reading the annotation file original_annotations = read_annotations(path) + index = 0 for a in original_annotations: annotations.append(bob.db.cuhk.Annotation(o.id, a[0], - a[1] + a[1], + index = index )) + index += 1 return annotations @@ -442,11 +447,15 @@ class CUHKWrapper(): #Reading the annotation file original_annotations = read_annotations(path) + index = 0 for a in original_annotations: annotations.append(bob.db.cuhk.Annotation(o.id, a[0], - a[1] + a[1], + index = index )) + index += 1 + return annotations