Commit 89b044a0 authored by Manuel Günther's avatar Manuel Günther
Browse files

Modified the AT&T database to return a list of File objects on a query; updated the tests.

parent f4d7e25f
......@@ -23,215 +23,10 @@ recognition and verification algorithms on. It is also known by its former name
"The ORL Database of Faces". You can download the AT&T database from:
http://www.cl.cam.ac.uk/research/dtg/attarchive/facedatabase.html
"""
import os
import sys
import numpy
from bob.db import utils
__all__ = ['Database',]
from .models import File, Client
from .query import Database
class Database(object):
    """Wrapper class for the AT&T (aka ORL) database of faces (http://www.cl.cam.ac.uk/research/dtg/attarchive/facedatabase.html).

    This class defines a simple protocol for training, enrollment and probe by
    splitting the few images of the database in a reasonable manner.
    """

    # NOTE(review): this looks like a module-level statement that ended up
    # inside the class; it is kept so the class namespace is unchanged -- confirm.
    __all__ = dir()

    def __init__(self):
        # the two groups and the two purposes known to the protocol
        self.m_groups = ('world', 'dev')
        self.m_purposes = ('enrol', 'probe')
        # 40 clients with 10 files each, both numbered starting at 1
        self.m_client_ids = set(range(1, 41))
        self.m_files = set(range(1, 11))
        # clients of the 'world' group; all remaining clients form 'dev'
        self.m_training_clients = set([1,2,5,6,10,11,12,14,16,17,20,21,24,26,27,29,33,34,36,39])
        # per-client file ids used for enrollment; the others are probes
        self.m_enrol_files = set([2,4,5,7,9])

    def dbname(self):
        """Calculates my own name automatically (name of the directory containing this file)."""
        return os.path.basename(os.path.dirname(__file__))

    def __check_validity__(self, l, obj, valid, default):
        """Checks validity of user input data against a set of valid values.

        Keyword Parameters:

        l
          The user input: a single str/int, an iterable of those, or an empty
          value / None.

        obj
          Human readable name of what is checked; used in the error message only.

        valid
          The container of valid values.

        default
          The value returned when ``l`` is empty or None.

        Returns ``default`` for empty input, a one-element list for a single
        str/int, and ``l`` itself otherwise.

        Raises a RuntimeError when any element of ``l`` is not in ``valid``.
        """
        if not l:
            return default
        if isinstance(l, (str, int)):
            # wrap a single value so that it is validated like a list
            return self.__check_validity__([l], obj, valid, default)
        for k in l:
            if k not in valid:
                # call syntax instead of "raise E, msg": valid in Python 2 and 3
                raise RuntimeError('Invalid %s "%s". Valid values are %s, or lists/tuples of those' % (obj, k, valid))
        return l

    def __make_path__(self, client_id, file_id, directory, extension):
        """Generates the file name for the given client id and file id of the AT&T database.

        The stem is ``s<client_id>/<file_id>``; ``extension`` is appended and
        ``directory`` is prepended when they are given (non-empty).
        """
        stem = os.path.join("s" + str(client_id), str(file_id))
        filename = stem + (extension or '')
        if directory:
            return os.path.join(directory, filename)
        return filename

    def clients(self, groups=None, protocol=None):
        """Returns the sorted list of ids of the clients that belong to the given groups.

        Keyword Parameters:

        groups
          One of the groups 'world', 'dev' or a tuple with both of them (which is the default).

        protocol
          Ignored.
        """
        groups = self.__check_validity__(groups, "group", self.m_groups, self.m_groups)
        ids = set()
        if 'world' in groups:
            ids |= self.m_training_clients
        if 'dev' in groups:
            ids |= self.m_client_ids - self.m_training_clients
        return sorted(ids)

    def models(self, groups=None, protocol=None):
        """Returns the sorted list of ids of the models that belong to the given groups.

        Model ids are computed exactly like client ids, see :py:meth:`clients`.

        Keyword Parameters:

        groups
          One of the groups 'world', 'dev' or a tuple with both of them (which is the default).

        protocol
          Ignored.
        """
        return self.clients(groups, protocol)

    def get_client_id_from_file_id(self, file_id):
        """Returns the client id from the given image id"""
        # floor division keeps the result integral under Python 2 and 3
        return (file_id - 1) // len(self.m_files) + 1

    def objects(self, directory=None, extension=None, model_ids=None, groups=None, purposes=None, protocol=None):
        """Returns a set of objects for the specific query by the user.

        Keyword Parameters:

        directory
          A directory name that will be prepended to the final filepath returned

        extension
          A filename extension that will be appended to the final filepath returned

        model_ids
          The ids of the clients whose files need to be retrieved. Should be a list of integral numbers from [1,40]

        groups
          One of the groups 'world' or 'dev' or a list with both of them (which is the default).

        purposes
          One of the purposes 'enrol' or 'probe' or a list with both of them (which is the default).
          This field is ignored (both purposes are used) unless the group 'dev' is selected.

        protocol
          Ignored.

        Returns: A dictionary whose values are tuples containing:

          * 0: the resolved filename
          * 1: the client id
          * 2: the client id
          * 3: the client id; for probe files, the requested model id when
               exactly one model id was requested
          * 4: the file id (identical to the dictionary key)

        considering all the filtering criteria. The keys of the dictionary are
        unique identities for each file in the AT&T database. Conserve these
        numbers if you wish to save processing results later on.

        Enrollment files are limited to the requested ``model_ids``; probe
        files are returned for all clients of the selected groups.
        """
        # check if groups set are valid
        groups = self.__check_validity__(groups, "group", self.m_groups, self.m_groups)

        # check the desired client ids for sanity
        model_ids = self.__check_validity__(model_ids, "model", self.m_client_ids, self.m_client_ids)

        # clients of the selected groups, and those of them that were requested
        group_clients = self.clients(groups)
        enrol_clients = set(group_clients) & set(model_ids)

        # purposes are only validated when the 'dev' group takes part
        if 'dev' in groups:
            purposes = self.__check_validity__(purposes, "purpose", self.m_purposes, self.m_purposes)
        else:
            purposes = self.m_purposes

        files_per_client = len(self.m_files)

        # go through the dataset and collect all desired files
        retval = {}
        if 'enrol' in purposes:
            for client_id in enrol_clients:
                for file_id in self.m_enrol_files:
                    key = (client_id - 1) * files_per_client + file_id
                    retval[key] = (
                        self.__make_path__(client_id, file_id, directory, extension),
                        client_id,
                        client_id,
                        client_id,
                        key)

        if 'probe' in purposes:
            probe_files = self.m_files - self.m_enrol_files
            single_model = len(model_ids) == 1
            for client_id in group_clients:
                for file_id in probe_files:
                    key = (client_id - 1) * files_per_client + file_id
                    retval[key] = (
                        self.__make_path__(client_id, file_id, directory, extension),
                        client_id,
                        client_id,
                        model_ids[0] if single_model else client_id,
                        key)

        return retval

    def files(self, directory=None, extension=None, model_ids=None, groups=None, purposes=None, protocol=None):
        """Returns a set of filenames for the specific query by the user.

        Keyword Parameters:

        directory
          A directory name that will be prepended to the final filepath returned

        extension
          A filename extension that will be appended to the final filepath returned

        model_ids
          The ids of the clients whose files need to be retrieved. Should be a list of integral numbers from [1,40]

        groups
          One of the groups 'world' or 'dev' or a list with both of them (which is the default).

        purposes
          One of the purposes 'enrol' or 'probe' or a list with both of them (which is the default).
          This field is ignored (both purposes are used) unless the group 'dev' is selected.

        protocol
          Ignored.

        Returns: A dictionary mapping the unique file id to the resolved filename.
        """
        o = self.objects(directory, extension, model_ids, groups, purposes)
        # items() instead of iteritems(): available in Python 2 and 3
        return dict((k, v[0]) for k, v in o.items())
#!/usr/bin/env python
# vim: set fileencoding=utf-8 :
# @author: Manuel Guenther <Manuel.Guenther@idiap.ch>
# @author: Manuel Guenther <Manuel.Guenther@idiap.ch>
# @date: Fri Apr 20 12:04:44 CEST 2012
#
# Copyright (C) 2011-2012 Idiap Research Institute, Martigny, Switzerland
#
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, version 3 of the License.
#
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
......@@ -21,7 +21,7 @@
"""
import os
import sys
import sys
from bob.db.driver import Interface as BaseInterface
def dumplist(args):
......@@ -30,32 +30,32 @@ def dumplist(args):
from .__init__ import Database
db = Database()
r = db.files(directory=args.directory, extension=args.extension, groups=args.groups, purposes=args.purposes)
r = db.objects(groups=args.groups, purposes=args.purposes)
output = sys.stdout
if args.selftest:
from bob.db.utils import null
output = null()
for id, f in r.items():
output.write('%s\n' % (f,))
for f in r:
output.write('%s\n' % f.make_path(directory=args.directory, extension=args.extension))
return 0
def checkfiles(args):
"""Checks the existence of the files based on your criteria."""
"""Checks the existence of the files based on your criteria."""
from .__init__ import Database
db = Database()
r = db.files(directory=args.directory, extension=args.extension)
r = db.objects()
# go through all files, check if they are available
good = {}
bad = {}
for id, f in r.items():
if os.path.exists(f): good[id] = f
else: bad[id] = f
for f in r:
if os.path.exists(f.make_path(directory=args.directory, extension=args.extension)): good[f.id] = f.make_path(directory=args.directory, extension=args.extension)
else: bad[f.id] = f.make_path(directory=args.directory, extension=args.extension)
# report
output = sys.stdout
......@@ -64,7 +64,7 @@ def checkfiles(args):
output = null()
if bad:
for id, f in bad.items():
for f in bad:
output.write('Cannot find file "%s"\n' % (f,))
output.write('%d files (out of %d) were not found at "%s"\n' % \
(len(bad), len(r), args.directory))
......@@ -72,14 +72,14 @@ def checkfiles(args):
return 0
class Interface(BaseInterface):
def name(self):
return 'atnt'
def version(self):
import pkg_resources # part of setuptools
return pkg_resources.require('xbob.db.%s' % self.name())[0].version
def files(self):
from pkg_resources import resource_filename
......@@ -92,7 +92,7 @@ class Interface(BaseInterface):
def add_commands(self, parser):
from . import __doc__ as docs
subparsers = self.setup_parser(parser,
"AT&T/ORL Face database", docs)
......
#!/usr/bin/env python
# vim: set fileencoding=utf-8 :
# @author: Manuel Guenther <Manuel.Guenther@idiap.ch>
# @date: Wed Oct 17 15:59:25 CEST 2012
#
# Copyright (C) 2011-2012 Idiap Research Institute, Martigny, Switzerland
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, version 3 of the License.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
"""
This file defines simple Client and File interfaces that should be comparable
with other xbob.db databases.
"""
import os
import bob
class Client:
    """A client of this database; it carries nothing but the client id."""

    def __init__(self, client_id):
        # the id is the only piece of information stored for a client
        self.id = client_id
class File:
    """A file of this database, identified by a client id and a per-client file id."""

    # number of files stored for each client
    file_count_per_id = 10

    def __init__(self, client_id, client_file_id):
        # per-client file ids are numbered 1 .. file_count_per_id
        allowed_file_ids = range(1, self.file_count_per_id + 1)
        assert client_file_id in allowed_file_ids

        # the unique file id is derived from the client id and the per-client file id
        self.id = (client_id-1) * self.file_count_per_id + client_file_id

        # remember which client this file belongs to
        self.client_id = client_id

        # relative path without extension: s<client_id>/<client_file_id>
        client_directory = "s" + str(client_id)
        self.path = os.path.join(client_directory, str(client_file_id))

    def make_path(self, directory=None, extension=None):
        """Builds a complete file name from the relative path of this file.

        Keyword parameters:

        directory
          An optional directory name that will be prefixed to the returned result.

        extension
          An optional extension that will be suffixed to the returned filename. The
          extension normally includes the leading ``.`` character as in ``.jpg`` or
          ``.hdf5``.

        Returns a string containing the newly generated file path.
        """
        # empty or missing parts contribute nothing to the result
        prefix = directory or ''
        suffix = extension or ''
        return os.path.join(prefix, self.path + suffix)

    def save(self, data, directory=None, extension='.hdf5'):
        """Writes the given data to the location derived from the directory,
        the path of this file and the extension.

        Keyword parameters:

        data
          The data blob to be saved (normally a :py:class:`numpy.ndarray`).

        directory
          If not empty or None, this directory is prefixed to the final file
          destination

        extension
          The extension of the filename - this will control the type of output and
          the codec for saving the input blob.
        """
        target = self.make_path(directory, extension)
        # the directory of the target file is created before writing
        bob.utils.makedirs_safe(os.path.dirname(target))
        bob.io.save(data, target)
#!/usr/bin/env python
# vim: set fileencoding=utf-8 :
# @author: Manuel Guenther <Manuel.Guenther@idiap.ch>
# @date: Wed Oct 17 15:59:25 CEST 2012
#
# Copyright (C) 2011-2012 Idiap Research Institute, Martigny, Switzerland
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, version 3 of the License.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
from .models import Client, File
class Database(object):
    """Wrapper class for the AT&T (aka ORL) database of faces (http://www.cl.cam.ac.uk/research/dtg/attarchive/facedatabase.html).

    This class defines a simple protocol for training, enrollment and probe by splitting the few images of the database in a reasonable manner.
    Due to the small size of the database, there is only a 'dev' group, and I did not define an 'eval' group."""

    def __init__(self):
        # the two groups and the two purposes known to the protocol
        self.m_groups = ('world', 'dev')
        self.m_purposes = ('enrol', 'probe')
        # 40 clients with 10 files each, both numbered starting at 1
        self.m_client_ids = set(range(1, 41))
        self.m_files = set(range(1, 11))
        # clients of the 'world' group; all remaining clients form 'dev'
        self.m_training_clients = set([1,2,5,6,10,11,12,14,16,17,20,21,24,26,27,29,33,34,36,39])
        # per-client file ids used for enrollment; the others are probes
        self.m_enrol_files = set([2,4,5,7,9])

    def __check_validity__(self, l, obj, valid, default):
        """Checks validity of user input data against a set of valid values.

        Keyword Parameters:

        l
          The user input: a single str/int, an iterable of those, or an empty
          value / None.

        obj
          Human readable name of what is checked; used in the error message only.

        valid
          The container of valid values.

        default
          The value returned when ``l`` is empty or None.

        Returns ``default`` for empty input, a one-element list for a single
        str/int, and ``l`` itself otherwise.

        Raises a RuntimeError when any element of ``l`` is not in ``valid``.
        """
        if not l:
            return default
        if isinstance(l, (str, int)):
            # wrap a single value so that it is validated like a list
            return self.__check_validity__([l], obj, valid, default)
        for k in l:
            if k not in valid:
                # call syntax instead of "raise E, msg": valid in Python 2 and 3
                raise RuntimeError('Invalid %s "%s". Valid values are %s, or lists/tuples of those' % (obj, k, valid))
        return l

    def clients(self, groups=None, protocol=None):
        """Returns the vector of clients used in a given group, ordered by client id

        Keyword Parameters:

        groups
          One of the groups 'world', 'dev' or a tuple with both of them (which is the default).

        protocol
          Ignored.
        """
        # share the group handling with client_ids() instead of duplicating it
        return [Client(client_id) for client_id in self.client_ids(groups, protocol)]

    def client_ids(self, groups=None, protocol=None):
        """Returns the sorted vector of ids of the clients used in a given group

        Keyword Parameters:

        groups
          One of the groups 'world', 'dev' or a tuple with both of them (which is the default).

        protocol
          Ignored.
        """
        groups = self.__check_validity__(groups, "group", self.m_groups, self.m_groups)
        ids = set()
        if 'world' in groups:
            ids |= self.m_training_clients
        if 'dev' in groups:
            ids |= self.m_client_ids - self.m_training_clients
        return sorted(ids)

    def models(self, groups=None, protocol=None):
        """Returns the vector of models ( == clients ) used in a given group

        Keyword Parameters:

        groups
          One of the groups 'world', 'dev' or a tuple with both of them (which is the default).

        protocol
          Ignored.
        """
        return self.clients(groups, protocol)

    def model_ids(self, groups=None, protocol=None):
        """Returns the vector of ids of the models (i.e., the client ids) used in a given group

        Keyword Parameters:

        groups
          One of the groups 'world', 'dev' or a tuple with both of them (which is the default).

        protocol
          Ignored.
        """
        return self.client_ids(groups, protocol)

    def get_client_id_from_file_id(self, file_id):
        """Returns the client id from the given image id"""
        # floor division keeps the result integral under Python 2 and 3
        return (file_id - 1) // len(self.m_files) + 1

    def get_client_id_from_model_id(self, model_id):
        """Returns the client id from the given model id"""
        # models and clients share the same ids in this database
        return model_id

    def objects(self, model_ids=None, groups=None, purposes=None, protocol=None):
        """Returns a set of File objects for the specific query by the user.

        Keyword Parameters:

        model_ids
          The ids of the clients whose files need to be retrieved. Should be a list of integral numbers from [1,40]

        groups
          One of the groups 'world' or 'dev' or a list with both of them (which is the default).

        purposes
          One of the purposes 'enrol' or 'probe' or a list with both of them (which is the default).
          This field is ignored (both purposes are used) unless the group 'dev' is selected.

        protocol
          Ignored.

        Returns: A list of File's considering all the filtering criteria.
        Enrollment files (limited to the requested ``model_ids``) come first,
        followed by the probe files of all clients of the selected groups; each
        part is ordered by client id and file id.
        """
        # check if groups set are valid
        groups = self.__check_validity__(groups, "group", self.m_groups, self.m_groups)

        # collect the (sorted) client ids of the selected groups
        group_ids = self.client_ids(groups)

        # check the desired client ids for sanity
        model_ids = self.__check_validity__(model_ids, "model", self.m_client_ids, self.m_client_ids)

        # intersection between the group clients and the desired client ids
        requested = set(model_ids)
        enrol_ids = [client_id for client_id in group_ids if client_id in requested]

        # purposes are only validated when the 'dev' group takes part
        if 'dev' in groups:
            purposes = self.__check_validity__(purposes, "purpose", self.m_purposes, self.m_purposes)
        else:
            purposes = self.m_purposes

        # go through the dataset and collect all desired files
        retval = []
        if 'enrol' in purposes:
            enrol_files = sorted(self.m_enrol_files)
            retval.extend(File(client_id, file_id)
                          for client_id in enrol_ids
                          for file_id in enrol_files)

        if 'probe' in purposes:
            probe_files = sorted(self.m_files - self.m_enrol_files)
            # for probe, we use all clients of the given groups
            retval.extend(File(client_id, file_id)
                          for client_id in group_ids
                          for file_id in probe_files)

        return retval
......@@ -29,20 +29,20 @@ class ATNTDatabaseTest(unittest.TestCase):
def test01_query(self):
db = Database()
f = db.files()
self.assertEqual(len(f.values()), 400) # number of all files in the database
f = db.objects()
self.assertEqual(len(f), 400) # number of all files in the database
f =