Skip to content
GitLab
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
bob
bob.db.utfvp
Commits
ef9b1c84
Commit
ef9b1c84
authored
May 11, 2017
by
Amir MOHAMMADI
Browse files
The SQLiteDatabase now accepts original_directory and original_extension
parent
16cb7f30
Pipeline
#9191
passed with stages
in 40 minutes and 29 seconds
Changes
1
Pipelines
1
Hide whitespace changes
Inline
Side-by-side
bob/db/utfvp/query.py
View file @
ef9b1c84
...
...
@@ -15,6 +15,7 @@ import bob.db.base
SQLITE_FILE
=
Interface
().
files
()[
0
]
class
Database
(
bob
.
db
.
base
.
SQLiteDatabase
):
"""The dataset class opens and maintains a connection opened to the Database.
...
...
@@ -22,16 +23,17 @@ class Database(bob.db.base.SQLiteDatabase):
and for the data itself inside the database.
"""
def
__init__
(
self
):
bob
.
db
.
base
.
SQLiteDatabase
.
__init__
(
self
,
SQLITE_FILE
,
File
)
def
__init__
(
self
,
original_directory
=
None
,
original_extension
=
None
):
bob
.
db
.
base
.
SQLiteDatabase
.
__init__
(
self
,
SQLITE_FILE
,
File
,
original_directory
,
original_extension
)
def
groups
(
self
,
protocol
=
None
):
"""Returns the names of all registered groups"""
if
protocol
==
'1vsall'
:
return
(
'world'
,
'dev'
)
else
:
return
(
'world'
,
'dev'
,
'eval'
)
if
protocol
==
'1vsall'
:
return
(
'world'
,
'dev'
)
else
:
return
(
'world'
,
'dev'
,
'eval'
)
def
clients
(
self
,
protocol
=
None
,
groups
=
None
):
"""Returns a set of clients for the specific query by the user.
...
...
@@ -60,18 +62,21 @@ class Database(bob.db.base.SQLiteDatabase):
"""
protocols
=
self
.
check_parameters_for_validity
(
protocol
,
"protocol"
,
self
.
protocol_names
())
protocols
=
self
.
check_parameters_for_validity
(
protocol
,
"protocol"
,
self
.
protocol_names
())
groups
=
self
.
check_parameters_for_validity
(
groups
,
"group"
,
self
.
groups
())
retval
=
[]
# List of the clients
if
'world'
in
groups
:
q
=
self
.
query
(
Client
).
join
((
File
,
Client
.
files
)).
join
((
Protocol
,
File
.
protocols_train
)).
filter
(
Protocol
.
name
.
in_
(
protocols
))
q
=
self
.
query
(
Client
).
join
((
File
,
Client
.
files
)).
join
(
(
Protocol
,
File
.
protocols_train
)).
filter
(
Protocol
.
name
.
in_
(
protocols
))
q
=
q
.
order_by
(
Client
.
id
)
retval
+=
list
(
q
)
if
'dev'
in
groups
or
'eval'
in
groups
:
q
=
self
.
query
(
Client
).
join
((
Model
,
Client
.
models
)).
join
((
Protocol
,
Model
.
protocol
)).
filter
(
Protocol
.
name
.
in_
(
protocols
))
q
=
self
.
query
(
Client
).
join
((
Model
,
Client
.
models
)).
join
(
(
Protocol
,
Model
.
protocol
)).
filter
(
Protocol
.
name
.
in_
(
protocols
))
q
=
q
.
filter
(
Model
.
sgroup
.
in_
(
groups
))
q
=
q
.
order_by
(
Client
.
id
)
retval
+=
list
(
q
)
...
...
@@ -81,7 +86,6 @@ class Database(bob.db.base.SQLiteDatabase):
return
retval
def
client_ids
(
self
,
protocol
=
None
,
groups
=
None
):
"""Returns a set of client ids for the specific query by the user.
...
...
@@ -114,7 +118,6 @@ class Database(bob.db.base.SQLiteDatabase):
return
[
client
.
id
for
client
in
self
.
clients
(
protocol
,
groups
)]
def
models
(
self
,
protocol
=
None
,
groups
=
None
):
"""Returns a set of models for the specific query by the user.
...
...
@@ -144,19 +147,20 @@ class Database(bob.db.base.SQLiteDatabase):
"""
protocols
=
self
.
check_parameters_for_validity
(
protocol
,
"protocol"
,
self
.
protocol_names
())
protocols
=
self
.
check_parameters_for_validity
(
protocol
,
"protocol"
,
self
.
protocol_names
())
groups
=
self
.
check_parameters_for_validity
(
groups
,
"group"
,
self
.
groups
())
retval
=
[]
if
'dev'
in
groups
or
'eval'
in
groups
:
# List of the clients
q
=
self
.
query
(
Model
).
join
((
Protocol
,
Model
.
protocol
)).
filter
(
Protocol
.
name
.
in_
(
protocols
))
q
=
self
.
query
(
Model
).
join
((
Protocol
,
Model
.
protocol
)
).
filter
(
Protocol
.
name
.
in_
(
protocols
))
q
=
q
.
filter
(
Model
.
sgroup
.
in_
(
groups
)).
order_by
(
Model
.
name
)
retval
+=
list
(
q
)
return
retval
def
model_ids
(
self
,
protocol
=
None
,
groups
=
None
):
"""Returns a set of models ids for the specific query by the user.
...
...
@@ -189,19 +193,16 @@ class Database(bob.db.base.SQLiteDatabase):
return
[
model
.
name
for
model
in
self
.
models
(
protocol
,
groups
)]
def
has_client_id
(
self
,
id
):
"""Returns True if we have a client with a certain integer identifier"""
return
self
.
query
(
Client
).
filter
(
Client
.
id
==
id
).
count
()
!=
0
return
self
.
query
(
Client
).
filter
(
Client
.
id
==
id
).
count
()
!=
0
def
client
(
self
,
id
):
"""Returns the client object in the database given a certain id. Raises
an error if that does not exist."""
return
self
.
query
(
Client
).
filter
(
Client
.
id
==
id
).
one
()
return
self
.
query
(
Client
).
filter
(
Client
.
id
==
id
).
one
()
def
get_client_id_from_model_id
(
self
,
model_id
):
"""Returns the client_id attached to the given model_id
...
...
@@ -217,11 +218,10 @@ class Database(bob.db.base.SQLiteDatabase):
"""
return
self
.
query
(
Model
).
filter
(
Model
.
name
==
model_id
).
first
().
client_id
return
self
.
query
(
Model
).
filter
(
Model
.
name
==
model_id
).
first
().
client_id
def
objects
(
self
,
protocol
=
None
,
purposes
=
None
,
model_ids
=
None
,
groups
=
None
,
classes
=
None
,
finger_ids
=
None
,
session_ids
=
None
):
classes
=
None
,
finger_ids
=
None
,
session_ids
=
None
):
"""Returns a set of Files for the specific query by the user.
...
...
@@ -277,10 +277,13 @@ class Database(bob.db.base.SQLiteDatabase):
"""
protocols
=
self
.
check_parameters_for_validity
(
protocol
,
"protocol"
,
self
.
protocol_names
())
purposes
=
self
.
check_parameters_for_validity
(
purposes
,
"purpose"
,
self
.
purposes
())
protocols
=
self
.
check_parameters_for_validity
(
protocol
,
"protocol"
,
self
.
protocol_names
())
purposes
=
self
.
check_parameters_for_validity
(
purposes
,
"purpose"
,
self
.
purposes
())
groups
=
self
.
check_parameters_for_validity
(
groups
,
"group"
,
self
.
groups
())
classes
=
self
.
check_parameters_for_validity
(
classes
,
"class"
,
(
'client'
,
'impostor'
))
classes
=
self
.
check_parameters_for_validity
(
classes
,
"class"
,
(
'client'
,
'impostor'
))
from
six
import
string_types
if
model_ids
is
None
:
...
...
@@ -301,50 +304,60 @@ class Database(bob.db.base.SQLiteDatabase):
retval
=
[]
if
'world'
in
groups
:
q
=
self
.
query
(
File
).
join
((
Protocol
,
File
.
protocols_train
)).
\
filter
(
Protocol
.
name
.
in_
(
protocols
))
if
finger_ids
:
q
=
q
.
filter
(
File
.
finger_id
.
in_
(
finger_ids
))
if
session_ids
:
q
=
q
.
filter
(
File
.
session_id
.
in_
(
session_ids
))
filter
(
Protocol
.
name
.
in_
(
protocols
))
if
finger_ids
:
q
=
q
.
filter
(
File
.
finger_id
.
in_
(
finger_ids
))
if
session_ids
:
q
=
q
.
filter
(
File
.
session_id
.
in_
(
session_ids
))
q
=
q
.
order_by
(
File
.
client_id
,
File
.
finger_id
,
File
.
session_id
)
retval
+=
list
(
q
)
if
'dev'
in
groups
or
'eval'
in
groups
:
sgroups
=
[]
if
'dev'
in
groups
:
sgroups
.
append
(
'dev'
)
if
'eval'
in
groups
:
sgroups
.
append
(
'eval'
)
if
'dev'
in
groups
:
sgroups
.
append
(
'dev'
)
if
'eval'
in
groups
:
sgroups
.
append
(
'eval'
)
if
'enroll'
in
purposes
:
q
=
self
.
query
(
File
).
join
(
Client
).
join
((
Model
,
File
.
models_enroll
)).
join
((
Protocol
,
Model
.
protocol
)).
\
filter
(
and_
(
Protocol
.
name
.
in_
(
protocols
),
Model
.
sgroup
.
in_
(
sgroups
)))
filter
(
and_
(
Protocol
.
name
.
in_
(
protocols
),
Model
.
sgroup
.
in_
(
sgroups
)))
if
model_ids
:
q
=
q
.
filter
(
Model
.
name
.
in_
(
model_ids
))
if
finger_ids
:
q
=
q
.
filter
(
File
.
finger_id
.
in_
(
finger_ids
))
if
session_ids
:
q
=
q
.
filter
(
File
.
session_id
.
in_
(
session_ids
))
if
finger_ids
:
q
=
q
.
filter
(
File
.
finger_id
.
in_
(
finger_ids
))
if
session_ids
:
q
=
q
.
filter
(
File
.
session_id
.
in_
(
session_ids
))
q
=
q
.
order_by
(
File
.
client_id
,
File
.
finger_id
,
File
.
session_id
)
retval
+=
list
(
q
)
if
'probe'
in
purposes
:
if
'client'
in
classes
:
q
=
self
.
query
(
File
).
join
(
Client
).
join
((
Model
,
File
.
models_probe
)).
join
((
Protocol
,
Model
.
protocol
)).
\
filter
(
and_
(
Protocol
.
name
.
in_
(
protocols
),
Model
.
sgroup
.
in_
(
sgroups
),
File
.
client_id
==
Model
.
client_id
))
filter
(
and_
(
Protocol
.
name
.
in_
(
protocols
),
Model
.
sgroup
.
in_
(
sgroups
),
File
.
client_id
==
Model
.
client_id
))
if
model_ids
:
q
=
q
.
filter
(
Model
.
name
.
in_
(
model_ids
))
if
finger_ids
:
q
=
q
.
filter
(
File
.
finger_id
.
in_
(
finger_ids
))
if
session_ids
:
q
=
q
.
filter
(
File
.
session_id
.
in_
(
session_ids
))
if
finger_ids
:
q
=
q
.
filter
(
File
.
finger_id
.
in_
(
finger_ids
))
if
session_ids
:
q
=
q
.
filter
(
File
.
session_id
.
in_
(
session_ids
))
q
=
q
.
order_by
(
File
.
client_id
,
File
.
finger_id
,
File
.
session_id
)
retval
+=
list
(
q
)
if
'impostor'
in
classes
:
q
=
self
.
query
(
File
).
join
(
Client
).
join
((
Model
,
File
.
models_probe
)).
join
((
Protocol
,
Model
.
protocol
)).
\
filter
(
and_
(
Protocol
.
name
.
in_
(
protocols
),
Model
.
sgroup
.
in_
(
sgroups
),
File
.
client_id
!=
Model
.
client_id
))
filter
(
and_
(
Protocol
.
name
.
in_
(
protocols
),
Model
.
sgroup
.
in_
(
sgroups
),
File
.
client_id
!=
Model
.
client_id
))
if
len
(
model_ids
)
!=
0
:
q
=
q
.
filter
(
Model
.
name
.
in_
(
model_ids
))
if
finger_ids
:
q
=
q
.
filter
(
File
.
finger_id
.
in_
(
finger_ids
))
if
session_ids
:
q
=
q
.
filter
(
File
.
session_id
.
in_
(
session_ids
))
if
finger_ids
:
q
=
q
.
filter
(
File
.
finger_id
.
in_
(
finger_ids
))
if
session_ids
:
q
=
q
.
filter
(
File
.
session_id
.
in_
(
session_ids
))
q
=
q
.
order_by
(
File
.
client_id
,
File
.
finger_id
,
File
.
session_id
)
retval
+=
list
(
q
)
return
list
(
set
(
retval
))
# To remove duplicates
return
list
(
set
(
retval
))
# To remove duplicates
def
protocol_names
(
self
):
"""Returns all registered protocol names"""
...
...
@@ -353,25 +366,21 @@ class Database(bob.db.base.SQLiteDatabase):
retval
=
[
str
(
k
.
name
)
for
k
in
l
]
return
retval
def
protocols
(
self
):
"""Returns all registered protocols"""
return
list
(
self
.
query
(
Protocol
))
def
has_protocol
(
self
,
name
):
"""Tells if a certain protocol is available"""
return
self
.
query
(
Protocol
).
filter
(
Protocol
.
name
==
name
).
count
()
!=
0
return
self
.
query
(
Protocol
).
filter
(
Protocol
.
name
==
name
).
count
()
!=
0
def
protocol
(
self
,
name
):
"""Returns the protocol object in the database given a certain name. Raises
an error if that does not exist."""
return
self
.
query
(
Protocol
).
filter
(
Protocol
.
name
==
name
).
one
()
return
self
.
query
(
Protocol
).
filter
(
Protocol
.
name
==
name
).
one
()
def
purposes
(
self
):
return
(
'train'
,
'enroll'
,
'probe'
)
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment