Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
bob
bob.db.atnt
Commits
89b044a0
Commit
89b044a0
authored
Oct 17, 2012
by
Manuel Günther
Browse files
Modified the AT&T database to return list of File's on a query; updated tests.
parent
f4d7e25f
Changes
5
Hide whitespace changes
Inline
Side-by-side
xbob/db/atnt/__init__.py
View file @
89b044a0
...
...
@@ -23,215 +23,10 @@ recognition and verification algorithms on. It is also known by its former name
"The ORL Database of Faces". You can download the AT&T database from:
http://www.cl.cam.ac.uk/research/dtg/attarchive/facedatabase.html
"""
import
os
import
sys
import
numpy
from
bob.db
import
utils
__all__
=
[
'Database'
,]
from
.models
import
File
,
Client
from
.query
import
Database
class
Database
(
object
):
"""Wrapper class for the AT&T (aka ORL) database of faces (http://www.cl.cam.ac.uk/research/dtg/attarchive/facedatabase.html).
This class defines a simple protocol for training, enrollment and probe by splitting the few images of the database in a reasonable manner."""
__all__
=
dir
()
def
__init__
(
self
):
self
.
m_groups
=
(
'world'
,
'dev'
)
self
.
m_purposes
=
(
'enrol'
,
'probe'
)
self
.
m_client_ids
=
set
(
range
(
1
,
41
))
self
.
m_files
=
set
(
range
(
1
,
11
))
self
.
m_training_clients
=
set
([
1
,
2
,
5
,
6
,
10
,
11
,
12
,
14
,
16
,
17
,
20
,
21
,
24
,
26
,
27
,
29
,
33
,
34
,
36
,
39
])
self
.
m_enrol_files
=
set
([
2
,
4
,
5
,
7
,
9
])
def
dbname
(
self
):
"""Calculates my own name automatically."""
return
os
.
path
.
basename
(
os
.
path
.
dirname
(
__file__
))
def
__check_validity__
(
self
,
l
,
obj
,
valid
,
default
):
"""Checks validity of user input data against a set of valid values."""
if
not
l
:
return
default
elif
isinstance
(
l
,
str
)
or
isinstance
(
l
,
int
):
return
self
.
__check_validity__
([
l
],
obj
,
valid
,
default
)
for
k
in
l
:
if
k
not
in
valid
:
raise
RuntimeError
,
'Invalid %s "%s". Valid values are %s, or lists/tuples of those'
%
(
obj
,
k
,
valid
)
return
l
def
__make_path__
(
self
,
client_id
,
file_id
,
directory
,
extension
):
"""Generates the file name for the given client id and file id of the AT&T database."""
stem
=
os
.
path
.
join
(
"s"
+
str
(
client_id
),
str
(
file_id
))
if
not
extension
:
extension
=
''
if
directory
:
return
os
.
path
.
join
(
directory
,
stem
+
extension
)
return
stem
+
extension
def
clients
(
self
,
groups
=
None
,
protocol
=
None
):
"""Returns the vector of ids of the clients used in a given purpose
Keyword Parameters:
groups
One of the groups 'world', 'dev' or a tuple with both of them (which is the default).
protocol
Ignored.
"""
VALID_GROUPS
=
self
.
m_groups
groups
=
self
.
__check_validity__
(
groups
,
"group"
,
VALID_GROUPS
,
VALID_GROUPS
)
ids
=
set
()
if
'world'
in
groups
:
ids
|=
self
.
m_training_clients
if
'dev'
in
groups
:
ids
|=
self
.
m_client_ids
-
self
.
m_training_clients
return
list
(
sorted
(
ids
))
def
models
(
self
,
groups
=
None
,
protocol
=
None
):
"""Returns the vector of ids of the models used in a given purpose
Keyword Parameters:
groups
One of the groups 'world', 'dev' or a tuple with both of them (which is the default).
protocol
Ignored.
"""
VALID_GROUPS
=
self
.
m_groups
groups
=
self
.
__check_validity__
(
groups
,
"group"
,
VALID_GROUPS
,
VALID_GROUPS
)
ids
=
set
()
if
'world'
in
groups
:
ids
|=
self
.
m_training_clients
if
'dev'
in
groups
:
ids
|=
self
.
m_client_ids
-
self
.
m_training_clients
return
list
(
sorted
(
ids
))
def
get_client_id_from_file_id
(
self
,
file_id
):
"""Returns the client id from the given image id"""
return
(
file_id
-
1
)
/
len
(
self
.
m_files
)
+
1
def
objects
(
self
,
directory
=
None
,
extension
=
None
,
model_ids
=
None
,
groups
=
None
,
purposes
=
None
,
protocol
=
None
):
"""Returns a set of objects for the specific query by the user.
Keyword Parameters:
directory
A directory name that will be prepended to the final filepath returned
extension
A filename extension that will be appended to the final filepath returned
model_ids
The ids of the clients whose files need to be retrieved. Should be a list of integral numbers from [1,40]
groups
One of the groups 'world' or 'dev' or a list with both of them (which is the default).
purposes
One of the purposes 'enrol' or 'probe' or a list with both of them (which is the default).
This field is ignored when the group 'train' is selected.
protocol
Ignored.
Returns: A dictionary containing:
* 0: the resolved filenames
* 1: the model id
* 2: the claimed id attached to the model
* 3: the real id
* 4: the "stem" path (basename of the file)
considering allthe filtering criteria. The keys of the dictionary are
unique identities for each file in the BANCA database. Conserve these
numbers if you wish to save processing results later on.
"""
# check if groups set are valid
VALID_GROUPS
=
self
.
m_groups
groups
=
self
.
__check_validity__
(
groups
,
"group"
,
VALID_GROUPS
,
VALID_GROUPS
)
# collect the ids to retrieve
ids
=
set
(
self
.
clients
(
groups
))
# check the desired client ids for sanity
VALID_IDS
=
self
.
m_client_ids
model_ids
=
self
.
__check_validity__
(
model_ids
,
"model"
,
VALID_IDS
,
VALID_IDS
)
# calculate the intersection between the ids and the desired client ids
ids
=
ids
&
set
(
model_ids
)
# check that the groups are valid
VALID_PURPOSES
=
self
.
m_purposes
if
'dev'
in
groups
:
purposes
=
self
.
__check_validity__
(
purposes
,
"purpose"
,
VALID_PURPOSES
,
VALID_PURPOSES
)
else
:
purposes
=
VALID_PURPOSES
# go through the dataset and collect all desired files
retval
=
{}
if
'enrol'
in
purposes
:
for
client_id
in
ids
:
for
file_id
in
self
.
m_enrol_files
:
retval
[(
client_id
-
1
)
*
len
(
self
.
m_files
)
+
file_id
]
=
(
self
.
__make_path__
(
client_id
,
file_id
,
directory
,
extension
),
client_id
,
client_id
,
client_id
,
(
client_id
-
1
)
*
len
(
self
.
m_files
)
+
file_id
)
if
'probe'
in
purposes
:
file_ids
=
self
.
m_files
-
self
.
m_enrol_files
for
client_id
in
self
.
clients
(
groups
):
for
file_id
in
file_ids
:
retval
[(
client_id
-
1
)
*
len
(
self
.
m_files
)
+
file_id
]
=
(
self
.
__make_path__
(
client_id
,
file_id
,
directory
,
extension
),
client_id
,
client_id
,
model_ids
[
0
]
if
len
(
model_ids
)
==
1
else
client_id
,
(
client_id
-
1
)
*
len
(
self
.
m_files
)
+
file_id
)
return
retval
def
files
(
self
,
directory
=
None
,
extension
=
None
,
model_ids
=
None
,
groups
=
None
,
purposes
=
None
,
protocol
=
None
):
"""Returns a set of filenames for the specific query by the user.
Keyword Parameters:
directory
A directory name that will be prepended to the final filepath returned
extension
A filename extension that will be appended to the final filepath returned
model_ids
The ids of the clients whose files need to be retrieved. Should be a list of integral numbers from [1,40]
groups
One of the groups 'world' or 'dev' or a list with both of them (which is the default).
purposes
One of the purposes 'enrol' or 'probe' or a list with both of them (which is the default).
This field is ignored when the group 'train' is selected.
protocol
Ignored.
"""
retval
=
{}
o
=
self
.
objects
(
directory
,
extension
,
model_ids
,
groups
,
purposes
)
for
k
,
v
in
o
.
iteritems
():
retval
[
k
]
=
v
[
0
]
return
retval
xbob/db/atnt/driver.py
View file @
89b044a0
#!/usr/bin/env python
# vim: set fileencoding=utf-8 :
# @author: Manuel Guenther <Manuel.Guenther@idiap.ch>
# @author: Manuel Guenther <Manuel.Guenther@idiap.ch>
# @date: Fri Apr 20 12:04:44 CEST 2012
#
# Copyright (C) 2011-2012 Idiap Research Institute, Martigny, Switzerland
#
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, version 3 of the License.
#
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
...
...
@@ -21,7 +21,7 @@
"""
import
os
import
sys
import
sys
from
bob.db.driver
import
Interface
as
BaseInterface
def
dumplist
(
args
):
...
...
@@ -30,32 +30,32 @@ def dumplist(args):
from
.__init__
import
Database
db
=
Database
()
r
=
db
.
files
(
directory
=
args
.
directory
,
extension
=
args
.
extension
,
groups
=
args
.
groups
,
purposes
=
args
.
purposes
)
r
=
db
.
objects
(
groups
=
args
.
groups
,
purposes
=
args
.
purposes
)
output
=
sys
.
stdout
if
args
.
selftest
:
from
bob.db.utils
import
null
output
=
null
()
for
id
,
f
in
r
.
items
()
:
output
.
write
(
'%s
\n
'
%
(
f
,
))
for
f
in
r
:
output
.
write
(
'%s
\n
'
%
f
.
make_path
(
directory
=
args
.
directory
,
extension
=
args
.
extension
))
return
0
def
checkfiles
(
args
):
"""Checks the existence of the files based on your criteria."""
"""Checks the existence of the files based on your criteria."""
from
.__init__
import
Database
db
=
Database
()
r
=
db
.
files
(
directory
=
args
.
directory
,
extension
=
args
.
extension
)
r
=
db
.
objects
(
)
# go through all files, check if they are available
good
=
{}
bad
=
{}
for
id
,
f
in
r
.
items
()
:
if
os
.
path
.
exists
(
f
):
good
[
id
]
=
f
else
:
bad
[
id
]
=
f
for
f
in
r
:
if
os
.
path
.
exists
(
f
.
make_path
(
directory
=
args
.
directory
,
extension
=
args
.
extension
)
):
good
[
f
.
id
]
=
f
.
make_path
(
directory
=
args
.
directory
,
extension
=
args
.
extension
)
else
:
bad
[
f
.
id
]
=
f
.
make_path
(
directory
=
args
.
directory
,
extension
=
args
.
extension
)
# report
output
=
sys
.
stdout
...
...
@@ -64,7 +64,7 @@ def checkfiles(args):
output
=
null
()
if
bad
:
for
id
,
f
in
bad
.
items
()
:
for
f
in
bad
:
output
.
write
(
'Cannot find file "%s"
\n
'
%
(
f
,))
output
.
write
(
'%d files (out of %d) were not found at "%s"
\n
'
%
\
(
len
(
bad
),
len
(
r
),
args
.
directory
))
...
...
@@ -72,14 +72,14 @@ def checkfiles(args):
return
0
class
Interface
(
BaseInterface
):
def
name
(
self
):
return
'atnt'
def
version
(
self
):
import
pkg_resources
# part of setuptools
return
pkg_resources
.
require
(
'xbob.db.%s'
%
self
.
name
())[
0
].
version
def
files
(
self
):
from
pkg_resources
import
resource_filename
...
...
@@ -92,7 +92,7 @@ class Interface(BaseInterface):
def
add_commands
(
self
,
parser
):
from
.
import
__doc__
as
docs
subparsers
=
self
.
setup_parser
(
parser
,
"AT&T/ORL Face database"
,
docs
)
...
...
xbob/db/atnt/models.py
0 → 100644
View file @
89b044a0
#!/usr/bin/env python
# vim: set fileencoding=utf-8 :
# @author: Manuel Guenther <Manuel.Guenther@idiap.ch>
# @date: Wed Oct 17 15:59:25 CEST 2012
#
# Copyright (C) 2011-2012 Idiap Research Institute, Martigny, Switzerland
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, version 3 of the License.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
"""
This file defines simple Client and File interfaces that should be comparable
with other xbob.db databases.
"""
import
os
import
bob
class
Client
:
"""The clients of this database contain ONLY client ids. Nothing special."""
def
__init__
(
self
,
client_id
):
self
.
id
=
client_id
class
File
:
"""Files of this database are composed from the client id and a file id."""
file_count_per_id
=
10
def
__init__
(
self
,
client_id
,
client_file_id
):
assert
client_file_id
in
range
(
1
,
self
.
file_count_per_id
+
1
)
# compute the file id on the fly
self
.
id
=
(
client_id
-
1
)
*
self
.
file_count_per_id
+
client_file_id
# copy client id
self
.
client_id
=
client_id
# generate path on the fly
self
.
path
=
os
.
path
.
join
(
"s"
+
str
(
client_id
),
str
(
client_file_id
))
def
make_path
(
self
,
directory
=
None
,
extension
=
None
):
"""Wraps the current path so that a complete path is formed
Keyword parameters:
directory
An optional directory name that will be prefixed to the returned result.
extension
An optional extension that will be suffixed to the returned filename. The
extension normally includes the leading ``.`` character as in ``.jpg`` or
``.hdf5``.
Returns a string containing the newly generated file path.
"""
if
not
directory
:
directory
=
''
if
not
extension
:
extension
=
''
return
os
.
path
.
join
(
directory
,
self
.
path
+
extension
)
def
save
(
self
,
data
,
directory
=
None
,
extension
=
'.hdf5'
):
"""Saves the input data at the specified location and using the given
extension.
Keyword parameters:
data
The data blob to be saved (normally a :py:class:`numpy.ndarray`).
directory
If not empty or None, this directory is prefixed to the final file
destination
extension
The extension of the filename - this will control the type of output and
the codec for saving the input blob.
"""
path
=
self
.
make_path
(
directory
,
extension
)
bob
.
utils
.
makedirs_safe
(
os
.
path
.
dirname
(
path
))
bob
.
io
.
save
(
data
,
path
)
xbob/db/atnt/query.py
0 → 100644
View file @
89b044a0
#!/usr/bin/env python
# vim: set fileencoding=utf-8 :
# @author: Manuel Guenther <Manuel.Guenther@idiap.ch>
# @date: Wed Oct 17 15:59:25 CEST 2012
#
# Copyright (C) 2011-2012 Idiap Research Institute, Martigny, Switzerland
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, version 3 of the License.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
from
.models
import
Client
,
File
class
Database
(
object
):
"""Wrapper class for the AT&T (aka ORL) database of faces (http://www.cl.cam.ac.uk/research/dtg/attarchive/facedatabase.html).
This class defines a simple protocol for training, enrollment and probe by splitting the few images of the database in a reasonable manner.
Due to the small size of the database, there is only a 'dev' group, and I did not define an 'eval' group."""
def
__init__
(
self
):
self
.
m_groups
=
(
'world'
,
'dev'
)
self
.
m_purposes
=
(
'enrol'
,
'probe'
)
self
.
m_client_ids
=
set
(
range
(
1
,
41
))
self
.
m_files
=
set
(
range
(
1
,
11
))
self
.
m_training_clients
=
set
([
1
,
2
,
5
,
6
,
10
,
11
,
12
,
14
,
16
,
17
,
20
,
21
,
24
,
26
,
27
,
29
,
33
,
34
,
36
,
39
])
self
.
m_enrol_files
=
set
([
2
,
4
,
5
,
7
,
9
])
def
__check_validity__
(
self
,
l
,
obj
,
valid
,
default
):
"""Checks validity of user input data against a set of valid values."""
if
not
l
:
return
default
elif
isinstance
(
l
,
str
)
or
isinstance
(
l
,
int
):
return
self
.
__check_validity__
([
l
],
obj
,
valid
,
default
)
for
k
in
l
:
if
k
not
in
valid
:
raise
RuntimeError
,
'Invalid %s "%s". Valid values are %s, or lists/tuples of those'
%
(
obj
,
k
,
valid
)
return
l
def
clients
(
self
,
groups
=
None
,
protocol
=
None
):
"""Returns the vector of clients used in a given group
Keyword Parameters:
groups
One of the groups 'world', 'dev' or a tuple with both of them (which is the default).
protocol
Ignored.
"""
VALID_GROUPS
=
self
.
m_groups
groups
=
self
.
__check_validity__
(
groups
,
"group"
,
VALID_GROUPS
,
VALID_GROUPS
)
ids
=
set
()
if
'world'
in
groups
:
ids
|=
self
.
m_training_clients
if
'dev'
in
groups
:
ids
|=
self
.
m_client_ids
-
self
.
m_training_clients
return
[
Client
(
id
)
for
id
in
ids
]
def
client_ids
(
self
,
groups
=
None
,
protocol
=
None
):
"""Returns the vector of ids of the clients used in a given group
Keyword Parameters:
groups
One of the groups 'world', 'dev' or a tuple with both of them (which is the default).
protocol
Ignored.
"""
VALID_GROUPS
=
self
.
m_groups
groups
=
self
.
__check_validity__
(
groups
,
"group"
,
VALID_GROUPS
,
VALID_GROUPS
)
ids
=
set
()
if
'world'
in
groups
:
ids
|=
self
.
m_training_clients
if
'dev'
in
groups
:
ids
|=
self
.
m_client_ids
-
self
.
m_training_clients
return
sorted
(
list
(
ids
))
def
models
(
self
,
groups
=
None
,
protocol
=
None
):
"""Returns the vector of models ( == clients ) used in a given group
Keyword Parameters:
groups
One of the groups 'world', 'dev' or a tuple with both of them (which is the default).
protocol
Ignored.
"""
return
self
.
clients
(
groups
,
protocol
)
def
model_ids
(
self
,
groups
=
None
,
protocol
=
None
):
"""Returns the vector of ids of the models (i.e., the client ids) used in a given group
Keyword Parameters:
groups
One of the groups 'world', 'dev' or a tuple with both of them (which is the default).
protocol
Ignored.
"""
return
self
.
client_ids
(
groups
,
protocol
)
def
get_client_id_from_file_id
(
self
,
file_id
):
"""Returns the client id from the given image id"""
return
(
file_id
-
1
)
/
len
(
self
.
m_files
)
+
1
def
get_client_id_from_model_id
(
self
,
model_id
):
"""Returns the client id from the given model id"""
return
model_id
def
objects
(
self
,
model_ids
=
None
,
groups
=
None
,
purposes
=
None
,
protocol
=
None
):
"""Returns a set of File objects for the specific query by the user.
Keyword Parameters:
model_ids
The ids of the clients whose files need to be retrieved. Should be a list of integral numbers from [1,40]
groups
One of the groups 'world' or 'dev' or a list with both of them (which is the default).
purposes
One of the purposes 'enrol' or 'probe' or a list with both of them (which is the default).
This field is ignored when the group 'world' is selected.
protocol
Ignored.
Returns: A list of File's considering all the filtering criteria.
"""
# check if groups set are valid
VALID_GROUPS
=
self
.
m_groups
groups
=
self
.
__check_validity__
(
groups
,
"group"
,
VALID_GROUPS
,
VALID_GROUPS
)
# collect the ids to retrieve
ids
=
set
(
self
.
client_ids
(
groups
))
# check the desired client ids for sanity
VALID_IDS
=
self
.
m_client_ids
model_ids
=
self
.
__check_validity__
(
model_ids
,
"model"
,
VALID_IDS
,
VALID_IDS
)
# calculate the intersection between the ids and the desired client ids
ids
=
ids
&
set
(
model_ids
)
# check that the groups are valid
VALID_PURPOSES
=
self
.
m_purposes
if
'dev'
in
groups
:
purposes
=
self
.
__check_validity__
(
purposes
,
"purpose"
,
VALID_PURPOSES
,
VALID_PURPOSES
)
else
:
purposes
=
VALID_PURPOSES
# go through the dataset and collect all desired files
retval
=
[]
if
'enrol'
in
purposes
:
for
client_id
in
ids
:
for
file_id
in
self
.
m_enrol_files
:
retval
.
append
(
File
(
client_id
,
file_id
))
if
'probe'
in
purposes
:
file_ids
=
self
.
m_files
-
self
.
m_enrol_files
# for probe, we use all clients of the given groups
for
client_id
in
self
.
client_ids
(
groups
):
for
file_id
in
file_ids
:
retval
.
append
(
File
(
client_id
,
file_id
))
return
retval
xbob/db/atnt/test.py
View file @
89b044a0
...
...
@@ -29,20 +29,20 @@ class ATNTDatabaseTest(unittest.TestCase):
def
test01_query
(
self
):
db
=
Database
()
f
=
db
.
file
s
()
self
.
assertEqual
(
len
(
f
.
values
()
),
400
)
# number of all files in the database
f
=
db
.
object
s
()
self
.
assertEqual
(
len
(
f
),
400
)
# number of all files in the database
f
=