Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
bob
bob.db.cuhk_cufs
Commits
c48db237
Commit
c48db237
authored
Aug 17, 2015
by
Tiago de Freitas Pereira
Browse files
Finished unit tests
parent
ed9a02b1
Changes
5
Hide whitespace changes
Inline
Side-by-side
bob/db/cuhk/create.py
View file @
c48db237
...
...
@@ -208,15 +208,6 @@ def add_protocols(session, verbose, photo2sketch=True):
"""
PROTOCOLS
=
(
'cuhk_p2s'
,
'arface_p2s'
,
'xm2vts_p2s'
,
'all-mixed_p2s'
,
'cuhk-arface-xm2vts_p2s'
,
'cuhk-xm2vts-arface_p2s'
,
'arface-cuhk-xm2vts_p2s'
,
'arface-xm2vts-cuhk_p2s'
,
'xm2vts-cuhk-arface_p2s'
,
'xm2vts-arface-cuhk_p2s'
,
'cuhk_s2p'
,
'arface_s2p'
,
'xm2vts_s2p'
,
'all-mixed_s2p'
,
'cuhk-arface-xm2vts_s2p'
,
'cuhk-xm2vts-arface_s2p'
,
'arface-cuhk-xm2vts_s2p'
,
'arface-xm2vts-cuhk_s2p'
,
'xm2vts-cuhk-arface_s2p'
,
'xm2vts-arface-cuhk_s2p'
)
GROUPS
=
(
'world'
,
'dev'
,
'eval'
)
PURPOSES
=
(
'train'
,
'enrol'
,
'probe'
)
arface
=
ARFACEWrapper
()
xm2vts
=
XM2VTSWrapper
()
cuhk
=
CUHKWrapper
()
...
...
@@ -417,11 +408,20 @@ def insert_protocol_data(session, protocol, group, purpose, file_objects, photo2
for
f
in
file_objects
:
if
purpose
!=
"train"
:
if
photo2sketch
and
f
.
modality
==
"photo"
:
purpose
=
"enrol"
if
photo2sketch
:
if
f
.
modality
==
"photo"
:
purpose
=
"enroll"
else
:
purpose
=
"probe"
else
:
purpose
=
"probe"
if
f
.
modality
==
"photo"
:
purpose
=
"probe"
else
:
purpose
=
"enroll"
session
.
add
(
bob
.
db
.
cuhk
.
Protocol_File_Association
(
protocol
,
group
,
purpose
,
f
.
id
))
...
...
bob/db/cuhk/models.py
View file @
c48db237
...
...
@@ -46,7 +46,7 @@ PROTOCOLS = ('cuhk_p2s', 'arface_p2s', 'xm2vts_p2s', 'all-mixed_p2s', 'cuhk-arfa
GROUPS
=
(
'world'
,
'dev'
,
'eval'
)
PURPOSES
=
(
'train'
,
'enrol'
,
'probe'
)
PURPOSES
=
(
'train'
,
'enrol
l
'
,
'probe'
)
class
Protocol_File_Association
(
Base
):
...
...
@@ -94,6 +94,36 @@ class Client(Base):
return
"<Client({0}, {1}, {2})>"
.
format
(
self
.
id
,
self
.
original_database
,
self
.
original_id
)
class
Annotation
(
Base
):
"""
The CUHK-CUFS provides 35 coordinates.
To model this coordinates this table was built.
The columns are the following:
- Annotation.id
- x
- y
"""
__tablename__
=
'annotation'
file_id
=
Column
(
Integer
,
ForeignKey
(
'file.id'
),
primary_key
=
True
)
x
=
Column
(
Integer
,
primary_key
=
True
)
y
=
Column
(
Integer
,
primary_key
=
True
)
index
=
Column
(
Integer
)
def
__init__
(
self
,
file_id
,
x
,
y
,
index
=
0
):
self
.
file_id
=
file_id
self
.
x
=
x
self
.
y
=
y
self
.
index
=
index
def
__repr__
(
self
):
return
"<Annotation(file_id:{0}, index:{1}, y={2}, x={3})>"
.
format
(
self
.
file_id
,
self
.
index
,
self
.
y
,
self
.
x
)
class
File
(
Base
,
bob
.
db
.
verification
.
utils
.
File
):
"""
Information about the files of the CUHK-CUFS database.
...
...
@@ -112,7 +142,7 @@ class File(Base, bob.db.verification.utils.File):
# a back-reference from the client class to a list of files
client
=
relationship
(
"Client"
,
backref
=
backref
(
"files"
,
order_by
=
id
))
all_annotations
=
relationship
(
"Annotation"
,
backref
=
backref
(
"file"
),
uselist
=
True
)
all_annotations
=
relationship
(
"Annotation"
,
backref
=
backref
(
"file"
),
uselist
=
True
,
order_by
=
Annotation
.
index
)
def
__init__
(
self
,
id
,
image_name
,
client_id
,
modality
):
# call base class constructor
...
...
@@ -133,26 +163,5 @@ class File(Base, bob.db.verification.utils.File):
class
Annotation
(
Base
):
"""
The CUHK-CUFS provides 35 coordinates.
To model this coordinates this table was built.
The columns are the following:
- Annotation.id
- x
- y
"""
__tablename__
=
'annotation'
file_id
=
Column
(
Integer
,
ForeignKey
(
'file.id'
),
primary_key
=
True
)
x
=
Column
(
Integer
,
primary_key
=
True
)
y
=
Column
(
Integer
,
primary_key
=
True
)
def
__init__
(
self
,
file_id
,
x
,
y
):
self
.
file_id
=
file_id
self
.
x
=
x
self
.
y
=
y
bob/db/cuhk/query.py
View file @
c48db237
...
...
@@ -40,6 +40,13 @@ class Database(bob.db.verification.utils.SQLiteDatabase, bob.db.verification.uti
bob
.
db
.
verification
.
utils
.
SQLiteDatabase
.
__init__
(
self
,
SQLITE_FILE
,
File
)
bob
.
db
.
verification
.
utils
.
ZTDatabase
.
__init__
(
self
,
original_directory
=
original_directory
,
original_extension
=
original_extension
)
def
protocols
(
self
):
return
PROTOCOLS
def
purposes
(
self
):
return
PURPOSES
def
objects
(
self
,
groups
=
None
,
protocol
=
None
,
purposes
=
None
,
model_ids
=
None
,
**
kwargs
):
"""
...
...
@@ -57,7 +64,7 @@ class Database(bob.db.verification.utils.SQLiteDatabase, bob.db.verification.uti
raise
ValueError
(
"Please, select only one of the following protocols {0}"
.
format
(
protocols
))
#Querying
query
=
self
.
query
(
bob
.
db
.
cuhk
.
File
).
join
(
bob
.
db
.
cuhk
.
Protocol_File_Association
).
join
(
bob
.
db
.
cuhk
.
Client
)
query
=
self
.
query
(
bob
.
db
.
cuhk
.
File
,
bob
.
db
.
cuhk
.
Protocol_File_Association
).
join
(
bob
.
db
.
cuhk
.
Protocol_File_Association
).
join
(
bob
.
db
.
cuhk
.
Client
)
#filtering
query
=
query
.
filter
(
bob
.
db
.
cuhk
.
Protocol_File_Association
.
group
.
in_
(
groups
))
...
...
@@ -70,7 +77,15 @@ class Database(bob.db.verification.utils.SQLiteDatabase, bob.db.verification.uti
query
=
query
.
filter
(
bob
.
db
.
cuhk
.
Client
.
id
.
in_
(
model_ids
))
return
query
.
all
()
raw_files
=
query
.
all
()
files
=
[]
for
f
in
raw_files
:
f
[
0
].
group
=
f
[
1
].
group
f
[
0
].
purpose
=
f
[
1
].
purpose
f
[
0
].
protocol
=
f
[
1
].
protocol
files
.
append
(
f
[
0
])
return
files
def
model_ids
(
self
,
protocol
=
None
,
groups
=
None
):
...
...
bob/db/cuhk/test.py
View file @
c48db237
...
...
@@ -21,14 +21,320 @@
"""
import
bob.db.cuhk
possible_protocols
=
[
"cuhk"
]
#
possible_protocols = ["cuhk"]
""" Defining protocols. Yes, they are static """
PROTOCOLS
=
(
'cuhk_p2s'
,
'arface_p2s'
,
'xm2vts_p2s'
,
'all-mixed_p2s'
,
'cuhk-arface-xm2vts_p2s'
,
'cuhk-xm2vts-arface_p2s'
,
'arface-cuhk-xm2vts_p2s'
,
'arface-xm2vts-cuhk_p2s'
,
'xm2vts-cuhk-arface_p2s'
,
'xm2vts-arface-cuhk_p2s'
,
'cuhk_s2p'
,
'arface_s2p'
,
'xm2vts_s2p'
,
'all-mixed_s2p'
,
'cuhk-arface-xm2vts_s2p'
,
'cuhk-xm2vts-arface_s2p'
,
'arface-cuhk-xm2vts_s2p'
,
'arface-xm2vts-cuhk_s2p'
,
'xm2vts-cuhk-arface_s2p'
,
'xm2vts-arface-cuhk_s2p'
)
GROUPS
=
(
'world'
,
'dev'
,
'eval'
)
PURPOSES
=
(
'train'
,
'enroll'
,
'probe'
)
def
test_protocols
():
import
os
db
=
bob
.
db
.
cuhk
.
Database
()
available_protocols
=
os
.
listdir
(
db
.
get_base_directory
())
def
test01_protocols_purposes_groups
():
#testing protocols
possible_protocols
=
bob
.
db
.
cuhk
.
Database
().
protocols
()
for
p
in
possible_protocols
:
assert
p
in
available_protocols
assert
p
in
PROTOCOLS
#testing purposes
possible_purposes
=
bob
.
db
.
cuhk
.
Database
().
purposes
()
for
p
in
possible_purposes
:
assert
p
in
PURPOSES
#testing GROUPS
possible_groups
=
bob
.
db
.
cuhk
.
Database
().
groups
()
for
p
in
possible_groups
:
assert
p
in
GROUPS
def
test02_all_files_protocols
():
cuhk
=
376
arface
=
246
xm2vts
=
590
all_mixed
=
1212
cuhk_arface_xm2vts
=
408
cuhk_xm2vts_arface
=
404
arface_cuhk_xm2vts
=
378
arface_xm2vts_cuhk
=
378
xm2vts_cuhk_arface
=
426
xm2vts_arface_cuhk
=
430
assert
len
(
bob
.
db
.
cuhk
.
Database
().
objects
(
protocol
=
"cuhk_p2s"
))
==
cuhk
assert
len
(
bob
.
db
.
cuhk
.
Database
().
objects
(
protocol
=
"cuhk_s2p"
))
==
cuhk
assert
len
(
bob
.
db
.
cuhk
.
Database
().
objects
(
protocol
=
"arface_p2s"
))
==
arface
assert
len
(
bob
.
db
.
cuhk
.
Database
().
objects
(
protocol
=
"arface_s2p"
))
==
arface
assert
len
(
bob
.
db
.
cuhk
.
Database
().
objects
(
protocol
=
"xm2vts_p2s"
))
==
xm2vts
assert
len
(
bob
.
db
.
cuhk
.
Database
().
objects
(
protocol
=
"xm2vts_s2p"
))
==
xm2vts
assert
len
(
bob
.
db
.
cuhk
.
Database
().
objects
(
protocol
=
"all-mixed_p2s"
))
==
all_mixed
assert
len
(
bob
.
db
.
cuhk
.
Database
().
objects
(
protocol
=
"all-mixed_s2p"
))
==
all_mixed
assert
len
(
bob
.
db
.
cuhk
.
Database
().
objects
(
protocol
=
"cuhk-arface-xm2vts_p2s"
))
==
cuhk_arface_xm2vts
assert
len
(
bob
.
db
.
cuhk
.
Database
().
objects
(
protocol
=
"cuhk-arface-xm2vts_s2p"
))
==
cuhk_arface_xm2vts
assert
len
(
bob
.
db
.
cuhk
.
Database
().
objects
(
protocol
=
"cuhk-xm2vts-arface_p2s"
))
==
cuhk_xm2vts_arface
assert
len
(
bob
.
db
.
cuhk
.
Database
().
objects
(
protocol
=
"cuhk-xm2vts-arface_s2p"
))
==
cuhk_xm2vts_arface
assert
len
(
bob
.
db
.
cuhk
.
Database
().
objects
(
protocol
=
"arface-xm2vts-cuhk_p2s"
))
==
arface_xm2vts_cuhk
assert
len
(
bob
.
db
.
cuhk
.
Database
().
objects
(
protocol
=
"arface-xm2vts-cuhk_s2p"
))
==
arface_xm2vts_cuhk
assert
len
(
bob
.
db
.
cuhk
.
Database
().
objects
(
protocol
=
"arface-cuhk-xm2vts_p2s"
))
==
arface_cuhk_xm2vts
assert
len
(
bob
.
db
.
cuhk
.
Database
().
objects
(
protocol
=
"arface-cuhk-xm2vts_s2p"
))
==
arface_cuhk_xm2vts
assert
len
(
bob
.
db
.
cuhk
.
Database
().
objects
(
protocol
=
"xm2vts-cuhk-arface_p2s"
))
==
xm2vts_cuhk_arface
assert
len
(
bob
.
db
.
cuhk
.
Database
().
objects
(
protocol
=
"xm2vts-cuhk-arface_s2p"
))
==
xm2vts_cuhk_arface
def
test03_world_files_protocols
():
cuhk
=
150
arface
=
88
xm2vts
=
236
all_mixed
=
474
cuhk_arface_xm2vts
=
cuhk
cuhk_xm2vts_arface
=
cuhk
arface_cuhk_xm2vts
=
arface
arface_xm2vts_cuhk
=
arface
xm2vts_cuhk_arface
=
xm2vts
xm2vts_arface_cuhk
=
xm2vts
assert
len
(
bob
.
db
.
cuhk
.
Database
().
objects
(
groups
=
'world'
,
protocol
=
"cuhk_p2s"
))
==
cuhk
assert
len
(
bob
.
db
.
cuhk
.
Database
().
objects
(
groups
=
'world'
,
protocol
=
"cuhk_s2p"
))
==
cuhk
assert
len
(
bob
.
db
.
cuhk
.
Database
().
objects
(
groups
=
'world'
,
protocol
=
"arface_p2s"
))
==
arface
assert
len
(
bob
.
db
.
cuhk
.
Database
().
objects
(
groups
=
'world'
,
protocol
=
"arface_s2p"
))
==
arface
assert
len
(
bob
.
db
.
cuhk
.
Database
().
objects
(
groups
=
'world'
,
protocol
=
"xm2vts_p2s"
))
==
xm2vts
assert
len
(
bob
.
db
.
cuhk
.
Database
().
objects
(
groups
=
'world'
,
protocol
=
"xm2vts_s2p"
))
==
xm2vts
assert
len
(
bob
.
db
.
cuhk
.
Database
().
objects
(
groups
=
'world'
,
protocol
=
"all-mixed_p2s"
))
==
all_mixed
assert
len
(
bob
.
db
.
cuhk
.
Database
().
objects
(
groups
=
'world'
,
protocol
=
"all-mixed_s2p"
))
==
all_mixed
assert
len
(
bob
.
db
.
cuhk
.
Database
().
objects
(
groups
=
'world'
,
protocol
=
"cuhk-arface-xm2vts_p2s"
))
==
cuhk_arface_xm2vts
assert
len
(
bob
.
db
.
cuhk
.
Database
().
objects
(
groups
=
'world'
,
protocol
=
"cuhk-arface-xm2vts_s2p"
))
==
cuhk_arface_xm2vts
assert
len
(
bob
.
db
.
cuhk
.
Database
().
objects
(
groups
=
'world'
,
protocol
=
"cuhk-xm2vts-arface_p2s"
))
==
cuhk_xm2vts_arface
assert
len
(
bob
.
db
.
cuhk
.
Database
().
objects
(
groups
=
'world'
,
protocol
=
"cuhk-xm2vts-arface_s2p"
))
==
cuhk_xm2vts_arface
assert
len
(
bob
.
db
.
cuhk
.
Database
().
objects
(
groups
=
'world'
,
protocol
=
"arface-xm2vts-cuhk_p2s"
))
==
arface_xm2vts_cuhk
assert
len
(
bob
.
db
.
cuhk
.
Database
().
objects
(
groups
=
'world'
,
protocol
=
"arface-xm2vts-cuhk_s2p"
))
==
arface_xm2vts_cuhk
assert
len
(
bob
.
db
.
cuhk
.
Database
().
objects
(
groups
=
'world'
,
protocol
=
"arface-cuhk-xm2vts_p2s"
))
==
arface_cuhk_xm2vts
assert
len
(
bob
.
db
.
cuhk
.
Database
().
objects
(
groups
=
'world'
,
protocol
=
"arface-cuhk-xm2vts_s2p"
))
==
arface_cuhk_xm2vts
assert
len
(
bob
.
db
.
cuhk
.
Database
().
objects
(
groups
=
'world'
,
protocol
=
"xm2vts-cuhk-arface_p2s"
))
==
xm2vts_cuhk_arface
assert
len
(
bob
.
db
.
cuhk
.
Database
().
objects
(
groups
=
'world'
,
protocol
=
"xm2vts-cuhk-arface_s2p"
))
==
xm2vts_cuhk_arface
def
test04_dev_files_protocols
():
cuhk
=
112
arface
=
80
xm2vts
=
176
all_mixed
=
368
cuhk_arface_xm2vts
=
arface
cuhk_xm2vts_arface
=
xm2vts
arface_cuhk_xm2vts
=
cuhk
arface_xm2vts_cuhk
=
xm2vts
xm2vts_cuhk_arface
=
cuhk
xm2vts_arface_cuhk
=
arface
assert
len
(
bob
.
db
.
cuhk
.
Database
().
objects
(
groups
=
'dev'
,
protocol
=
"cuhk_p2s"
))
==
cuhk
assert
len
(
bob
.
db
.
cuhk
.
Database
().
objects
(
groups
=
'dev'
,
protocol
=
"cuhk_s2p"
))
==
cuhk
assert
len
(
bob
.
db
.
cuhk
.
Database
().
objects
(
groups
=
'dev'
,
protocol
=
"arface_p2s"
))
==
arface
assert
len
(
bob
.
db
.
cuhk
.
Database
().
objects
(
groups
=
'dev'
,
protocol
=
"arface_s2p"
))
==
arface
assert
len
(
bob
.
db
.
cuhk
.
Database
().
objects
(
groups
=
'dev'
,
protocol
=
"xm2vts_p2s"
))
==
xm2vts
assert
len
(
bob
.
db
.
cuhk
.
Database
().
objects
(
groups
=
'dev'
,
protocol
=
"xm2vts_s2p"
))
==
xm2vts
assert
len
(
bob
.
db
.
cuhk
.
Database
().
objects
(
groups
=
'dev'
,
protocol
=
"all-mixed_p2s"
))
==
all_mixed
assert
len
(
bob
.
db
.
cuhk
.
Database
().
objects
(
groups
=
'dev'
,
protocol
=
"all-mixed_s2p"
))
==
all_mixed
assert
len
(
bob
.
db
.
cuhk
.
Database
().
objects
(
groups
=
'dev'
,
protocol
=
"cuhk-arface-xm2vts_p2s"
))
==
cuhk_arface_xm2vts
assert
len
(
bob
.
db
.
cuhk
.
Database
().
objects
(
groups
=
'dev'
,
protocol
=
"cuhk-arface-xm2vts_s2p"
))
==
cuhk_arface_xm2vts
assert
len
(
bob
.
db
.
cuhk
.
Database
().
objects
(
groups
=
'dev'
,
protocol
=
"cuhk-xm2vts-arface_p2s"
))
==
cuhk_xm2vts_arface
assert
len
(
bob
.
db
.
cuhk
.
Database
().
objects
(
groups
=
'dev'
,
protocol
=
"cuhk-xm2vts-arface_s2p"
))
==
cuhk_xm2vts_arface
assert
len
(
bob
.
db
.
cuhk
.
Database
().
objects
(
groups
=
'dev'
,
protocol
=
"arface-xm2vts-cuhk_p2s"
))
==
arface_xm2vts_cuhk
assert
len
(
bob
.
db
.
cuhk
.
Database
().
objects
(
groups
=
'dev'
,
protocol
=
"arface-xm2vts-cuhk_s2p"
))
==
arface_xm2vts_cuhk
assert
len
(
bob
.
db
.
cuhk
.
Database
().
objects
(
groups
=
'dev'
,
protocol
=
"arface-cuhk-xm2vts_p2s"
))
==
arface_cuhk_xm2vts
assert
len
(
bob
.
db
.
cuhk
.
Database
().
objects
(
groups
=
'dev'
,
protocol
=
"arface-cuhk-xm2vts_s2p"
))
==
arface_cuhk_xm2vts
assert
len
(
bob
.
db
.
cuhk
.
Database
().
objects
(
groups
=
'dev'
,
protocol
=
"xm2vts-cuhk-arface_p2s"
))
==
xm2vts_cuhk_arface
assert
len
(
bob
.
db
.
cuhk
.
Database
().
objects
(
groups
=
'dev'
,
protocol
=
"xm2vts-cuhk-arface_s2p"
))
==
xm2vts_cuhk_arface
def
test05_eval_files_protocols
():
cuhk
=
114
arface
=
78
xm2vts
=
178
all_mixed
=
370
cuhk_arface_xm2vts
=
xm2vts
cuhk_xm2vts_arface
=
arface
arface_cuhk_xm2vts
=
xm2vts
arface_xm2vts_cuhk
=
cuhk
xm2vts_cuhk_arface
=
arface
xm2vts_arface_cuhk
=
cuhk
assert
len
(
bob
.
db
.
cuhk
.
Database
().
objects
(
groups
=
'eval'
,
protocol
=
"cuhk_p2s"
))
==
cuhk
assert
len
(
bob
.
db
.
cuhk
.
Database
().
objects
(
groups
=
'eval'
,
protocol
=
"cuhk_s2p"
))
==
cuhk
assert
len
(
bob
.
db
.
cuhk
.
Database
().
objects
(
groups
=
'eval'
,
protocol
=
"arface_p2s"
))
==
arface
assert
len
(
bob
.
db
.
cuhk
.
Database
().
objects
(
groups
=
'eval'
,
protocol
=
"arface_s2p"
))
==
arface
assert
len
(
bob
.
db
.
cuhk
.
Database
().
objects
(
groups
=
'eval'
,
protocol
=
"xm2vts_p2s"
))
==
xm2vts
assert
len
(
bob
.
db
.
cuhk
.
Database
().
objects
(
groups
=
'eval'
,
protocol
=
"xm2vts_s2p"
))
==
xm2vts
assert
len
(
bob
.
db
.
cuhk
.
Database
().
objects
(
groups
=
'eval'
,
protocol
=
"all-mixed_p2s"
))
==
all_mixed
assert
len
(
bob
.
db
.
cuhk
.
Database
().
objects
(
groups
=
'eval'
,
protocol
=
"all-mixed_s2p"
))
==
all_mixed
assert
len
(
bob
.
db
.
cuhk
.
Database
().
objects
(
groups
=
'eval'
,
protocol
=
"cuhk-arface-xm2vts_p2s"
))
==
cuhk_arface_xm2vts
assert
len
(
bob
.
db
.
cuhk
.
Database
().
objects
(
groups
=
'eval'
,
protocol
=
"cuhk-arface-xm2vts_s2p"
))
==
cuhk_arface_xm2vts
assert
len
(
bob
.
db
.
cuhk
.
Database
().
objects
(
groups
=
'eval'
,
protocol
=
"cuhk-xm2vts-arface_p2s"
))
==
cuhk_xm2vts_arface
assert
len
(
bob
.
db
.
cuhk
.
Database
().
objects
(
groups
=
'eval'
,
protocol
=
"cuhk-xm2vts-arface_s2p"
))
==
cuhk_xm2vts_arface
assert
len
(
bob
.
db
.
cuhk
.
Database
().
objects
(
groups
=
'eval'
,
protocol
=
"arface-xm2vts-cuhk_p2s"
))
==
arface_xm2vts_cuhk
assert
len
(
bob
.
db
.
cuhk
.
Database
().
objects
(
groups
=
'eval'
,
protocol
=
"arface-xm2vts-cuhk_s2p"
))
==
arface_xm2vts_cuhk
assert
len
(
bob
.
db
.
cuhk
.
Database
().
objects
(
groups
=
'eval'
,
protocol
=
"arface-cuhk-xm2vts_p2s"
))
==
arface_cuhk_xm2vts
assert
len
(
bob
.
db
.
cuhk
.
Database
().
objects
(
groups
=
'eval'
,
protocol
=
"arface-cuhk-xm2vts_s2p"
))
==
arface_cuhk_xm2vts
assert
len
(
bob
.
db
.
cuhk
.
Database
().
objects
(
groups
=
'eval'
,
protocol
=
"xm2vts-cuhk-arface_p2s"
))
==
xm2vts_cuhk_arface
assert
len
(
bob
.
db
.
cuhk
.
Database
().
objects
(
groups
=
'eval'
,
protocol
=
"xm2vts-cuhk-arface_s2p"
))
==
xm2vts_cuhk_arface
def
test06_dev_enrol_files_protocols
():
cuhk
=
56
arface
=
40
xm2vts
=
88
all_mixed
=
184
cuhk_arface_xm2vts
=
arface
cuhk_xm2vts_arface
=
xm2vts
arface_cuhk_xm2vts
=
cuhk
arface_xm2vts_cuhk
=
xm2vts
xm2vts_cuhk_arface
=
cuhk
xm2vts_arface_cuhk
=
arface
assert
len
(
bob
.
db
.
cuhk
.
Database
().
objects
(
purposes
=
'enroll'
,
groups
=
'dev'
,
protocol
=
"cuhk_p2s"
))
==
len
(
bob
.
db
.
cuhk
.
Database
().
enroll_files
(
protocol
=
"cuhk_p2s"
))
==
cuhk
assert
len
(
bob
.
db
.
cuhk
.
Database
().
objects
(
purposes
=
'enroll'
,
groups
=
'dev'
,
protocol
=
"cuhk_s2p"
))
==
len
(
bob
.
db
.
cuhk
.
Database
().
enroll_files
(
protocol
=
"cuhk_s2p"
))
==
cuhk
assert
len
(
bob
.
db
.
cuhk
.
Database
().
objects
(
purposes
=
'enroll'
,
groups
=
'dev'
,
protocol
=
"arface_p2s"
))
==
len
(
bob
.
db
.
cuhk
.
Database
().
enroll_files
(
protocol
=
"arface_p2s"
))
==
arface
assert
len
(
bob
.
db
.
cuhk
.
Database
().
objects
(
purposes
=
'enroll'
,
groups
=
'dev'
,
protocol
=
"arface_s2p"
))
==
len
(
bob
.
db
.
cuhk
.
Database
().
enroll_files
(
protocol
=
"arface_s2p"
))
==
arface
assert
len
(
bob
.
db
.
cuhk
.
Database
().
objects
(
purposes
=
'enroll'
,
groups
=
'dev'
,
protocol
=
"xm2vts_p2s"
))
==
len
(
bob
.
db
.
cuhk
.
Database
().
enroll_files
(
protocol
=
"xm2vts_p2s"
))
==
xm2vts
assert
len
(
bob
.
db
.
cuhk
.
Database
().
objects
(
purposes
=
'enroll'
,
groups
=
'dev'
,
protocol
=
"xm2vts_s2p"
))
==
len
(
bob
.
db
.
cuhk
.
Database
().
enroll_files
(
protocol
=
"xm2vts_s2p"
))
==
xm2vts
assert
len
(
bob
.
db
.
cuhk
.
Database
().
objects
(
purposes
=
'enroll'
,
groups
=
'dev'
,
protocol
=
"all-mixed_p2s"
))
==
len
(
bob
.
db
.
cuhk
.
Database
().
enroll_files
(
protocol
=
"all-mixed_p2s"
))
==
all_mixed
assert
len
(
bob
.
db
.
cuhk
.
Database
().
objects
(
purposes
=
'enroll'
,
groups
=
'dev'
,
protocol
=
"all-mixed_s2p"
))
==
len
(
bob
.
db
.
cuhk
.
Database
().
enroll_files
(
protocol
=
"all-mixed_s2p"
))
==
all_mixed
assert
len
(
bob
.
db
.
cuhk
.
Database
().
objects
(
purposes
=
'enroll'
,
groups
=
'dev'
,
protocol
=
"cuhk-arface-xm2vts_p2s"
))
==
len
(
bob
.
db
.
cuhk
.
Database
().
enroll_files
(
protocol
=
"cuhk-arface-xm2vts_p2s"
))
==
cuhk_arface_xm2vts
assert
len
(
bob
.
db
.
cuhk
.
Database
().
objects
(
purposes
=
'enroll'
,
groups
=
'dev'
,
protocol
=
"cuhk-arface-xm2vts_s2p"
))
==
len
(
bob
.
db
.
cuhk
.
Database
().
enroll_files
(
protocol
=
"cuhk-arface-xm2vts_s2p"
))
==
cuhk_arface_xm2vts
assert
len
(
bob
.
db
.
cuhk
.
Database
().
objects
(
purposes
=
'enroll'
,
groups
=
'dev'
,
protocol
=
"cuhk-xm2vts-arface_p2s"
))
==
len
(
bob
.
db
.
cuhk
.
Database
().
enroll_files
(
protocol
=
"cuhk-xm2vts-arface_p2s"
))
==
cuhk_xm2vts_arface
assert
len
(
bob
.
db
.
cuhk
.
Database
().
objects
(
purposes
=
'enroll'
,
groups
=
'dev'
,
protocol
=
"cuhk-xm2vts-arface_s2p"
))
==
len
(
bob
.
db
.
cuhk
.
Database
().
enroll_files
(
protocol
=
"cuhk-xm2vts-arface_s2p"
))
==
cuhk_xm2vts_arface
assert
len
(
bob
.
db
.
cuhk
.
Database
().
objects
(
purposes
=
'enroll'
,
groups
=
'dev'
,
protocol
=
"arface-xm2vts-cuhk_p2s"
))
==
len
(
bob
.
db
.
cuhk
.
Database
().
enroll_files
(
protocol
=
"arface-xm2vts-cuhk_p2s"
))
==
arface_xm2vts_cuhk
assert
len
(
bob
.
db
.
cuhk
.
Database
().
objects
(
purposes
=
'enroll'
,
groups
=
'dev'
,
protocol
=
"arface-xm2vts-cuhk_s2p"
))
==
len
(
bob
.
db
.
cuhk
.
Database
().
enroll_files
(
protocol
=
"arface-xm2vts-cuhk_s2p"
))
==
arface_xm2vts_cuhk
assert
len
(
bob
.
db
.
cuhk
.
Database
().
objects
(
purposes
=
'enroll'
,
groups
=
'dev'
,
protocol
=
"arface-cuhk-xm2vts_p2s"
))
==
len
(
bob
.
db
.
cuhk
.
Database
().
enroll_files
(
protocol
=
"arface-cuhk-xm2vts_p2s"
))
==
arface_cuhk_xm2vts
assert
len
(
bob
.
db
.
cuhk
.
Database
().
objects
(
purposes
=
'enroll'
,
groups
=
'dev'
,
protocol
=
"arface-cuhk-xm2vts_s2p"
))
==
len
(
bob
.
db
.
cuhk
.
Database
().
enroll_files
(
protocol
=
"arface-cuhk-xm2vts_s2p"
))
==
arface_cuhk_xm2vts
assert
len
(
bob
.
db
.
cuhk
.
Database
().
objects
(
purposes
=
'enroll'
,
groups
=
'dev'
,
protocol
=
"xm2vts-cuhk-arface_p2s"
))
==
len
(
bob
.
db
.
cuhk
.
Database
().
enroll_files
(
protocol
=
"xm2vts-cuhk-arface_p2s"
))
==
xm2vts_cuhk_arface
assert
len
(
bob
.
db
.
cuhk
.
Database
().
objects
(
purposes
=
'enroll'
,
groups
=
'dev'
,
protocol
=
"xm2vts-cuhk-arface_s2p"
))
==
len
(
bob
.
db
.
cuhk
.
Database
().
enroll_files
(
protocol
=
"xm2vts-cuhk-arface_s2p"
))
==
xm2vts_cuhk_arface
def
test07_eval_enrol_files_protocols
():
cuhk
=
57
arface
=
39
xm2vts
=
89
all_mixed
=
185
cuhk_arface_xm2vts
=
xm2vts
cuhk_xm2vts_arface
=
arface
arface_cuhk_xm2vts
=
xm2vts
arface_xm2vts_cuhk
=
cuhk
xm2vts_cuhk_arface
=
arface
xm2vts_arface_cuhk
=
cuhk
assert
len
(
bob
.
db
.
cuhk
.
Database
().
objects
(
purposes
=
'enroll'
,
groups
=
'eval'
,
protocol
=
"cuhk_p2s"
))
==
len
(
bob
.
db
.
cuhk
.
Database
().
enroll_files
(
protocol
=
"cuhk_p2s"
,
groups
=
'eval'
))
==
cuhk
assert
len
(
bob
.
db
.
cuhk
.
Database
().
objects
(
purposes
=
'enroll'
,
groups
=
'eval'
,
protocol
=
"cuhk_s2p"
))
==
len
(
bob
.
db
.
cuhk
.
Database
().
enroll_files
(
protocol
=
"cuhk_s2p"
,
groups
=
'eval'
))
==
cuhk
assert
len
(
bob
.
db
.
cuhk
.
Database
().
objects
(
purposes
=
'enroll'
,
groups
=
'eval'
,
protocol
=
"arface_p2s"
))
==
len
(
bob
.
db
.
cuhk
.
Database
().
enroll_files
(
protocol
=
"arface_p2s"
,
groups
=
'eval'
))
==
arface
assert
len
(
bob
.
db
.
cuhk
.
Database
().
objects
(
purposes
=
'enroll'
,
groups
=
'eval'
,
protocol
=
"arface_s2p"
))
==
len
(
bob
.
db
.
cuhk
.
Database
().
enroll_files
(
protocol
=
"arface_s2p"
,
groups
=
'eval'
))
==
arface
assert
len
(
bob
.
db
.
cuhk
.
Database
().
objects
(
purposes
=
'enroll'
,
groups
=
'eval'
,
protocol
=
"xm2vts_p2s"
))
==
len
(
bob
.
db
.
cuhk
.
Database
().
enroll_files
(
protocol
=
"xm2vts_p2s"
,
groups
=
'eval'
))
==
xm2vts
assert
len
(
bob
.
db
.
cuhk
.
Database
().
objects
(
purposes
=
'enroll'
,
groups
=
'eval'
,
protocol
=
"xm2vts_s2p"
))
==
len
(
bob
.
db
.
cuhk
.
Database
().
enroll_files
(
protocol
=
"xm2vts_s2p"
,
groups
=
'eval'
))
==
xm2vts
assert
len
(
bob
.
db
.
cuhk
.
Database
().
objects
(
purposes
=
'enroll'
,
groups
=
'eval'
,
protocol
=
"all-mixed_p2s"
))
==
len
(
bob
.
db
.
cuhk
.
Database
().
enroll_files
(
protocol
=
"all-mixed_p2s"
,
groups
=
'eval'
))
==
all_mixed
assert
len
(
bob
.
db
.
cuhk
.
Database
().
objects
(
purposes
=
'enroll'
,
groups
=
'eval'
,
protocol
=
"all-mixed_s2p"
))
==
len
(
bob
.
db
.
cuhk
.
Database
().
enroll_files
(
protocol
=
"all-mixed_s2p"
,
groups
=
'eval'
))
==
all_mixed
assert
len
(
bob
.
db
.
cuhk
.
Database
().
objects
(
purposes
=
'enroll'
,
groups
=
'eval'
,
protocol
=
"cuhk-arface-xm2vts_p2s"
))
==
len
(
bob
.
db
.
cuhk
.
Database
().
enroll_files
(
protocol
=
"cuhk-arface-xm2vts_p2s"
,
groups
=
'eval'
))
==
cuhk_arface_xm2vts
assert
len
(
bob
.
db
.
cuhk
.
Database
().
objects
(
purposes
=
'enroll'
,
groups
=
'eval'
,
protocol
=
"cuhk-arface-xm2vts_s2p"
))
==
len
(
bob
.
db
.
cuhk
.
Database
().
enroll_files
(
protocol
=
"cuhk-arface-xm2vts_s2p"
,
groups
=
'eval'
))
==
cuhk_arface_xm2vts
assert
len
(
bob
.
db
.
cuhk
.
Database
().
objects
(
purposes
=
'enroll'
,
groups
=
'eval'
,
protocol
=
"cuhk-xm2vts-arface_p2s"
))
==
len
(
bob
.
db
.
cuhk
.
Database
().
enroll_files
(
protocol
=
"cuhk-xm2vts-arface_p2s"
,
groups
=
'eval'
))
==
cuhk_xm2vts_arface
assert
len
(
bob
.
db
.
cuhk
.
Database
().
objects
(
purposes
=
'enroll'
,
groups
=
'eval'
,
protocol
=
"cuhk-xm2vts-arface_s2p"
))
==
len
(
bob
.
db
.
cuhk
.
Database
().
enroll_files
(
protocol
=
"cuhk-xm2vts-arface_s2p"
,
groups
=
'eval'
))
==
cuhk_xm2vts_arface
assert
len
(
bob
.
db
.
cuhk
.
Database
().
objects
(
purposes
=
'enroll'
,
groups
=
'eval'
,
protocol
=
"arface-xm2vts-cuhk_p2s"
))
==
len
(
bob
.
db
.
cuhk
.
Database
().
enroll_files
(
protocol
=
"arface-xm2vts-cuhk_p2s"
,
groups
=
'eval'
))
==
arface_xm2vts_cuhk
assert
len
(
bob
.
db
.
cuhk
.
Database
().
objects
(
purposes
=
'enroll'
,
groups
=
'eval'
,
protocol
=
"arface-xm2vts-cuhk_s2p"
))
==
len
(
bob
.
db
.
cuhk
.
Database
().
enroll_files
(
protocol
=
"arface-xm2vts-cuhk_s2p"
,
groups
=
'eval'
))
==
arface_xm2vts_cuhk
assert
len
(
bob
.
db
.
cuhk
.
Database
().
objects
(
purposes
=
'enroll'
,
groups
=
'eval'
,
protocol
=
"arface-cuhk-xm2vts_p2s"
))
==
len
(
bob
.
db
.
cuhk
.
Database
().
enroll_files
(
protocol
=
"arface-cuhk-xm2vts_p2s"
,
groups
=
'eval'
))
==
arface_cuhk_xm2vts
assert
len
(
bob
.
db
.
cuhk
.
Database
().
objects
(
purposes
=
'enroll'
,
groups
=
'eval'
,
protocol
=
"arface-cuhk-xm2vts_s2p"
))
==
len
(
bob
.
db
.
cuhk
.
Database
().
enroll_files
(
protocol
=
"arface-cuhk-xm2vts_s2p"
,
groups
=
'eval'
))
==
arface_cuhk_xm2vts
assert
len
(
bob
.
db
.
cuhk
.
Database
().
objects
(
purposes
=
'enroll'
,
groups
=
'eval'
,
protocol
=
"xm2vts-cuhk-arface_p2s"
))
==
len
(
bob
.
db
.
cuhk
.
Database
().
enroll_files
(
protocol
=
"xm2vts-cuhk-arface_p2s"
,
groups
=
'eval'
))
==
xm2vts_cuhk_arface
assert
len
(
bob
.
db
.
cuhk
.
Database
().
objects
(
purposes
=
'enroll'
,
groups
=
'eval'
,
protocol
=
"xm2vts-cuhk-arface_s2p"
))
==
len
(
bob
.
db
.
cuhk
.
Database
().
enroll_files
(
protocol
=
"xm2vts-cuhk-arface_s2p"
,
groups
=
'eval'
))
==
xm2vts_cuhk_arface
def
test08_strings
():
db
=
bob
.
db
.
cuhk
.
Database
()
for
p
in
PROTOCOLS
:
for
g
in
GROUPS
:
for
u
in
PURPOSES
:
files
=
db
.
objects
(
purposes
=
u
,
groups
=
g
,
protocol
=
p
)
for
f
in
files
:
#Checking if the strings are correct
assert
f
.
purpose
==
u
assert
f
.
protocol
==
p
assert
f
.
group
==
g
def
test09_annotations
():
db
=
bob
.
db
.
cuhk
.
Database
()
for
p
in
PROTOCOLS
:
for
f
in
db
.
objects
(
protocol
=
p
):
assert
len
(
f
.
annotations
(
annotation_type
=
""
))
==
35
#ALL ANNOTATIONS
assert
f
.
annotations
()[
"reye"
][
0
]
>
0
assert
f
.
annotations
()[
"reye"
][
1
]
>
0
assert
f
.
annotations
()[
"leye"
][
0
]
>
0
assert
f
.
annotations
()[
"leye"
][
1
]
>
0
bob/db/cuhk/utils.py
View file @
c48db237
...
...
@@ -109,12 +109,15 @@ class ARFACEWrapper():
#Reading the annotation file
original_annotations
=
read_annotations
(
path
)
index
=
0
for
a
in
original_annotations
:
annotations
.
append
(
bob
.
db
.
cuhk
.
Annotation
(
o
.
id
,
a
[
0
],
a
[
1
]
a
[
1
],
index
=
index
))
index
+=
1
return
annotations
...
...
@@ -292,7 +295,7 @@ class XM2VTSWrapper():
files
=
[]
for
c
in
clients
:
cuhk_files
=
cuhk
.
query
(
bob
.
db
.
cuhk
.
File
).
join
(
bob
.
db
.
cuhk
.
Client
).
filter
(
bob
.
db
.
cuhk
.
Client
.
original_id
==
c
.
id
)
print
"{0} = {1}"
.
format
(
c
.
id
,
cuhk_files
.
count
())
#
print "{0} = {1}".format(c.id, cuhk_files.count())
for
f
in
cuhk_files
:
files
.
append
(
f
)
...
...
@@ -314,18 +317,20 @@ class XM2VTSWrapper():
if
(
o
.
modality
==
"sketch"
):
path
=
os
.
path
.
join
(
annotation_dir
,
o
.
path
)
+
annotation_extension
else
:
#import ipdb; ipdb.set_trace();
file_name
=
o
.
path
.
split
(
"/"
)[
2
]
#THE ORIGINAL XM2VTS RELATIVE PATH IS: XXX\XXX\XXX
path
=
os
.
path
.
join
(
annotation_dir
,
"xm2vts"
,
"photo"
,
file_name
)
+
"_f02"
+
annotation_extension
#FOR SOME REASON THE AUTHORS SET THIS '_f02 IN THE END OF THE FILE'
#Reading the annotation file
original_annotations
=
read_annotations
(
path
)
index
=
0
for
a
in
original_annotations
:
annotations
.
append
(
bob
.
db
.
cuhk
.
Annotation
(
o
.
id
,
a
[
0
],
a
[
1
]
a
[
1
],
index
=
index
))
index
+=
1
return
annotations
...
...
@@ -442,11 +447,15 @@ class CUHKWrapper():
#Reading the annotation file
original_annotations
=
read_annotations
(
path
)
index
=
0
for
a
in
original_annotations
:
annotations
.
append
(
bob
.
db
.
cuhk
.
Annotation
(
o
.
id
,
a
[
0
],
a
[
1
]
a
[
1
],
index
=
index
))
index
+=
1
return
annotations
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment