Commit 98ea7fb6 authored by Manuel Günther's avatar Manuel Günther
Browse files

Based database on the novel xbob.db.verification.utils interface; some cleaned up.

parent fbc21880
......@@ -25,6 +25,7 @@ setup(
install_requires=[
'setuptools',
'bob', # base signal proc./machine learning library
'xbob.db.verification.utils' # defines a set of utilities for face verification databases like this one.
],
namespace_packages = [
......
......@@ -25,6 +25,8 @@ with other xbob.db databases.
import os
import bob
import xbob.db.verification.utils
class Client:
"""The clients of this database contain ONLY client ids. Nothing special."""
m_valid_client_ids = set(range(1, 41))
......@@ -35,18 +37,18 @@ class Client:
class File:
class File (xbob.db.verification.utils.File):
"""Files of this database are composed from the client id and a file id."""
m_valid_file_ids = set(range(1, 11))
def __init__(self, client_id, client_file_id):
assert client_file_id in self.m_valid_file_ids
# compute the file id on the fly
self.id = (client_id-1) * len(self.m_valid_file_ids) + client_file_id
# copy client id
self.client_id = client_id
file_id = (client_id-1) * len(self.m_valid_file_ids) + client_file_id
# generate path on the fly
self.path = os.path.join("s" + str(client_id), str(client_file_id))
path = os.path.join("s" + str(client_id), str(client_file_id))
# call base class constructor
xbob.db.verification.utils.File.__init__(self, client_id = client_id, file_id = file_id, path = path)
@staticmethod
......@@ -67,48 +69,3 @@ class File:
assert paths[1][0] == 's'
return File(int(paths[1][1:]), int(file_name))
def make_path(self, directory=None, extension=None):
"""Wraps the current path so that a complete path is formed
Keyword parameters:
directory
An optional directory name that will be prefixed to the returned result.
extension
An optional extension that will be suffixed to the returned filename. The
extension normally includes the leading ``.`` character as in ``.jpg`` or
``.hdf5``.
Returns a string containing the newly generated file path.
"""
if not directory: directory = ''
if not extension: extension = ''
return os.path.join(directory, self.path + extension)
def save(self, data, directory=None, extension='.hdf5'):
"""Saves the input data at the specified location and using the given
extension.
Keyword parameters:
data
The data blob to be saved (normally a :py:class:`numpy.ndarray`).
directory
If not empty or None, this directory is prefixed to the final file
destination
extension
The extension of the filename - this will control the type of output and
the codec for saving the input blob.
"""
path = self.make_path(directory, extension)
bob.utils.makedirs_safe(os.path.dirname(path))
bob.io.save(data, path)
......@@ -19,28 +19,23 @@
from .models import Client, File
class Database(object):
import xbob.db.verification.utils
class Database(xbob.db.verification.utils.Database):
"""Wrapper class for the AT&T (aka ORL) database of faces (http://www.cl.cam.ac.uk/research/dtg/attarchive/facedatabase.html).
This class defines a simple protocol for training, enrollment and probe by splitting the few images of the database in a reasonable manner.
Due to the small size of the database, there is only a 'dev' group, and I did not define an 'eval' group."""
def __init__(self):
# call base class constructor
xbob.db.verification.utils.Database.__init__(self)
# initialize members
self.m_groups = ('world', 'dev')
self.m_purposes = ('enrol', 'probe')
self.m_training_clients = set([1,2,5,6,10,11,12,14,16,17,20,21,24,26,27,29,33,34,36,39])
self.m_enrol_files = set([2,4,5,7,9])
def __check_validity__(self, l, obj, valid, default):
"""Checks validity of user input data against a set of valid values."""
if not l: return default
elif isinstance(l, str) or isinstance(l, int): return self.__check_validity__([l], obj, valid, default)
for k in l:
if k not in valid:
raise RuntimeError, 'Invalid %s "%s". Valid values are %s, or lists/tuples of those' % (obj, k, valid)
return l
def clients(self, groups = None, protocol = None):
"""Returns the vector of clients used in a given group
......@@ -53,8 +48,7 @@ class Database(object):
Ignored.
"""
VALID_GROUPS = self.m_groups
groups = self.__check_validity__(groups, "group", VALID_GROUPS, VALID_GROUPS)
groups = self.check_parameters_for_validity(groups, "group", self.m_groups)
ids = set()
if 'world' in groups:
......@@ -76,8 +70,7 @@ class Database(object):
Ignored.
"""
VALID_GROUPS = self.m_groups
groups = self.__check_validity__(groups, "group", VALID_GROUPS, VALID_GROUPS)
groups = self.check_parameters_for_validity(groups, "group", self.m_groups)
ids = set()
if 'world' in groups:
......@@ -150,25 +143,22 @@ class Database(object):
"""
# check if groups set are valid
VALID_GROUPS = self.m_groups
groups = self.__check_validity__(groups, "group", VALID_GROUPS, VALID_GROUPS)
groups = self.check_parameters_for_validity(groups, "group", self.m_groups)
# collect the ids to retrieve
ids = set(self.client_ids(groups))
# check the desired client ids for sanity
VALID_IDS = Client.m_valid_client_ids
model_ids = self.__check_validity__(model_ids, "model", VALID_IDS, VALID_IDS)
model_ids = self.check_parameters_for_validity(model_ids, "model", list(Client.m_valid_client_ids))
# calculate the intersection between the ids and the desired client ids
ids = ids & set(model_ids)
# check that the groups are valid
VALID_PURPOSES = self.m_purposes
# check that the purposes are valid
if 'dev' in groups:
purposes = self.__check_validity__(purposes, "purpose", VALID_PURPOSES, VALID_PURPOSES)
purposes = self.check_parameters_for_validity(purposes, "purpose", self.m_purposes)
else:
purposes = VALID_PURPOSES
purposes = self.m_purposes
# go through the dataset and collect all desired files
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment