Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
bob
bob.db.atnt
Commits
98ea7fb6
Commit
98ea7fb6
authored
Dec 07, 2012
by
Manuel Günther
Browse files
Based database on the novel xbob.db.verification.utils interface; some cleaned up.
parent
fbc21880
Changes
3
Hide whitespace changes
Inline
Side-by-side
setup.py
View file @
98ea7fb6
...
...
@@ -25,6 +25,7 @@ setup(
install_requires
=
[
'setuptools'
,
'bob'
,
# base signal proc./machine learning library
'xbob.db.verification.utils'
# defines a set of utilities for face verification databases like this one.
],
namespace_packages
=
[
...
...
xbob/db/atnt/models.py
View file @
98ea7fb6
...
...
@@ -25,6 +25,8 @@ with other xbob.db databases.
import
os
import
bob
import
xbob.db.verification.utils
class
Client
:
"""The clients of this database contain ONLY client ids. Nothing special."""
m_valid_client_ids
=
set
(
range
(
1
,
41
))
...
...
@@ -35,18 +37,18 @@ class Client:
class
File
:
class
File
(
xbob
.
db
.
verification
.
utils
.
File
)
:
"""Files of this database are composed from the client id and a file id."""
m_valid_file_ids
=
set
(
range
(
1
,
11
))
def
__init__
(
self
,
client_id
,
client_file_id
):
assert
client_file_id
in
self
.
m_valid_file_ids
# compute the file id on the fly
self
.
id
=
(
client_id
-
1
)
*
len
(
self
.
m_valid_file_ids
)
+
client_file_id
# copy client id
self
.
client_id
=
client_id
file_id
=
(
client_id
-
1
)
*
len
(
self
.
m_valid_file_ids
)
+
client_file_id
# generate path on the fly
self
.
path
=
os
.
path
.
join
(
"s"
+
str
(
client_id
),
str
(
client_file_id
))
path
=
os
.
path
.
join
(
"s"
+
str
(
client_id
),
str
(
client_file_id
))
# call base class constructor
xbob
.
db
.
verification
.
utils
.
File
.
__init__
(
self
,
client_id
=
client_id
,
file_id
=
file_id
,
path
=
path
)
@
staticmethod
...
...
@@ -67,48 +69,3 @@ class File:
assert
paths
[
1
][
0
]
==
's'
return
File
(
int
(
paths
[
1
][
1
:]),
int
(
file_name
))
def
make_path
(
self
,
directory
=
None
,
extension
=
None
):
"""Wraps the current path so that a complete path is formed
Keyword parameters:
directory
An optional directory name that will be prefixed to the returned result.
extension
An optional extension that will be suffixed to the returned filename. The
extension normally includes the leading ``.`` character as in ``.jpg`` or
``.hdf5``.
Returns a string containing the newly generated file path.
"""
if
not
directory
:
directory
=
''
if
not
extension
:
extension
=
''
return
os
.
path
.
join
(
directory
,
self
.
path
+
extension
)
def
save
(
self
,
data
,
directory
=
None
,
extension
=
'.hdf5'
):
"""Saves the input data at the specified location and using the given
extension.
Keyword parameters:
data
The data blob to be saved (normally a :py:class:`numpy.ndarray`).
directory
If not empty or None, this directory is prefixed to the final file
destination
extension
The extension of the filename - this will control the type of output and
the codec for saving the input blob.
"""
path
=
self
.
make_path
(
directory
,
extension
)
bob
.
utils
.
makedirs_safe
(
os
.
path
.
dirname
(
path
))
bob
.
io
.
save
(
data
,
path
)
xbob/db/atnt/query.py
View file @
98ea7fb6
...
...
@@ -19,28 +19,23 @@
from
.models
import
Client
,
File
class
Database
(
object
):
import
xbob.db.verification.utils
class
Database
(
xbob
.
db
.
verification
.
utils
.
Database
):
"""Wrapper class for the AT&T (aka ORL) database of faces (http://www.cl.cam.ac.uk/research/dtg/attarchive/facedatabase.html).
This class defines a simple protocol for training, enrollment and probe by splitting the few images of the database in a reasonable manner.
Due to the small size of the database, there is only a 'dev' group, and I did not define an 'eval' group."""
def
__init__
(
self
):
# call base class constructor
xbob
.
db
.
verification
.
utils
.
Database
.
__init__
(
self
)
# initialize members
self
.
m_groups
=
(
'world'
,
'dev'
)
self
.
m_purposes
=
(
'enrol'
,
'probe'
)
self
.
m_training_clients
=
set
([
1
,
2
,
5
,
6
,
10
,
11
,
12
,
14
,
16
,
17
,
20
,
21
,
24
,
26
,
27
,
29
,
33
,
34
,
36
,
39
])
self
.
m_enrol_files
=
set
([
2
,
4
,
5
,
7
,
9
])
def
__check_validity__
(
self
,
l
,
obj
,
valid
,
default
):
"""Checks validity of user input data against a set of valid values."""
if
not
l
:
return
default
elif
isinstance
(
l
,
str
)
or
isinstance
(
l
,
int
):
return
self
.
__check_validity__
([
l
],
obj
,
valid
,
default
)
for
k
in
l
:
if
k
not
in
valid
:
raise
RuntimeError
,
'Invalid %s "%s". Valid values are %s, or lists/tuples of those'
%
(
obj
,
k
,
valid
)
return
l
def
clients
(
self
,
groups
=
None
,
protocol
=
None
):
"""Returns the vector of clients used in a given group
...
...
@@ -53,8 +48,7 @@ class Database(object):
Ignored.
"""
VALID_GROUPS
=
self
.
m_groups
groups
=
self
.
__check_validity__
(
groups
,
"group"
,
VALID_GROUPS
,
VALID_GROUPS
)
groups
=
self
.
check_parameters_for_validity
(
groups
,
"group"
,
self
.
m_groups
)
ids
=
set
()
if
'world'
in
groups
:
...
...
@@ -76,8 +70,7 @@ class Database(object):
Ignored.
"""
VALID_GROUPS
=
self
.
m_groups
groups
=
self
.
__check_validity__
(
groups
,
"group"
,
VALID_GROUPS
,
VALID_GROUPS
)
groups
=
self
.
check_parameters_for_validity
(
groups
,
"group"
,
self
.
m_groups
)
ids
=
set
()
if
'world'
in
groups
:
...
...
@@ -150,25 +143,22 @@ class Database(object):
"""
# check if groups set are valid
VALID_GROUPS
=
self
.
m_groups
groups
=
self
.
__check_validity__
(
groups
,
"group"
,
VALID_GROUPS
,
VALID_GROUPS
)
groups
=
self
.
check_parameters_for_validity
(
groups
,
"group"
,
self
.
m_groups
)
# collect the ids to retrieve
ids
=
set
(
self
.
client_ids
(
groups
))
# check the desired client ids for sanity
VALID_IDS
=
Client
.
m_valid_client_ids
model_ids
=
self
.
__check_validity__
(
model_ids
,
"model"
,
VALID_IDS
,
VALID_IDS
)
model_ids
=
self
.
check_parameters_for_validity
(
model_ids
,
"model"
,
list
(
Client
.
m_valid_client_ids
))
# calculate the intersection between the ids and the desired client ids
ids
=
ids
&
set
(
model_ids
)
# check that the groups are valid
VALID_PURPOSES
=
self
.
m_purposes
# check that the purposes are valid
if
'dev'
in
groups
:
purposes
=
self
.
__
check_validity
__
(
purposes
,
"purpose"
,
VALID_PURPOSES
,
VALID_PURPOSES
)
purposes
=
self
.
check_
parameters_for_
validity
(
purposes
,
"purpose"
,
self
.
m_purposes
)
else
:
purposes
=
VALID_PURPOSES
purposes
=
self
.
m_purposes
# go through the dataset and collect all desired files
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment