Commit fbc21880 authored by Manuel Günther
Browse files

Added functions to make database API consistent with other databases; added...

Added functions to make database API consistent with other databases; added tests for the new functions.
parent 89b044a0
......@@ -30,7 +30,7 @@ def dumplist(args):
from .__init__ import Database
db = Database()
r = db.objects(groups=args.groups, purposes=args.purposes)
r = db.objects(groups=args.groups, purposes=args.purposes, model_ids=args.client)
output = sys.stdout
if args.selftest:
......@@ -48,7 +48,7 @@ def checkfiles(args):
from .__init__ import Database
db = Database()
r = db.objects()
r = db.objects(groups=args.groups, purposes=args.purposes, model_ids=args.client)
# go through all files, check if they are available
good = {}
......@@ -81,10 +81,7 @@ class Interface(BaseInterface):
return pkg_resources.require('xbob.db.%s' % self.name())[0].version
def files(self):
    """Returns the data files that are shipped with this database interface.

    This interface lists no data files, so the returned tuple is always
    empty.
    """
    # A previous implementation returned a resource path built from a single
    # empty file name *before* this statement, which left this return
    # unreachable and required an otherwise unused local import.
    return ()
def type(self):
    """Returns the kind of this database interface as a string."""
    interface_kind = 'python_integrated'
    return interface_kind
......@@ -102,10 +99,13 @@ class Interface(BaseInterface):
from argparse import SUPPRESS
from .models import Client
# add the dumplist command
dump_parser = subparsers.add_parser('dumplist', help="Dumps list of files based on your criteria")
dump_parser.add_argument('-d', '--directory', default=None, help="if given, this path will be prepended to every entry returned")
dump_parser.add_argument('-e', '--extension', default=None, help="if given, this extension will be appended to every entry returned")
dump_parser.add_argument('-C', '--client', dest="client", default=None, type=int, help="if given, limits the dump to a particular client (defaults to '%(default)s')", choices=Client.m_valid_client_ids)
dump_parser.add_argument('-g', '--groups', default=None, help="if given, this value will limit the output files to those belonging to a particular group.", choices=db.m_groups)
dump_parser.add_argument('-p', '--purposes', default=None, help="if given, this value will limit the output files to those belonging to a particular purpose.", choices=db.m_purposes)
dump_parser.add_argument('--self-test', dest="selftest", action='store_true', help=SUPPRESS)
......@@ -115,6 +115,9 @@ class Interface(BaseInterface):
check_parser = subparsers.add_parser('checkfiles', help="Check if the files exist, based on your criteria")
check_parser.add_argument('-d', '--directory', required=True, help="The path to the AT&T images")
check_parser.add_argument('-e', '--extension', default=".pgm", help="The extension of the AT&T images default: '.pgm'")
check_parser.add_argument('-C', '--client', dest="client", default=None, type=int, help="if given, limits the test to a particular client (defaults to '%(default)s')", choices=Client.m_valid_client_ids)
check_parser.add_argument('-g', '--groups', default=None, help="if given, this value will limit the tested files to those belonging to a particular group.", choices=db.m_groups)
check_parser.add_argument('-p', '--purposes', default=None, help="if given, this value will limit the tested files to those belonging to a particular purpose.", choices=db.m_purposes)
check_parser.add_argument('--self-test', dest="selftest", default=False, action='store_true', help=SUPPRESS)
check_parser.set_defaults(func=checkfiles) #action
......@@ -27,24 +27,47 @@ import bob
class Client:
    """The clients of this database contain ONLY client ids. Nothing special."""

    # All ids that denote a valid client: 1 to 40 (inclusive).
    m_valid_client_ids = set(range(1, 41))

    def __init__(self, client_id):
        """Creates the client with the given id.

        Keyword Parameters:

        client_id
          The id of the client; it has to be one of ``m_valid_client_ids``.

        Raises an ``AssertionError`` when the given id is not valid.
        """
        # Raised explicitly instead of through an ``assert`` statement, so the
        # check is still performed when Python runs with ``-O``.  The exception
        # type is the one an ``assert`` statement would raise.
        if client_id not in self.m_valid_client_ids:
            raise AssertionError("Invalid client id: %r" % (client_id,))
        self.id = client_id
class File:
    """Files of this database are composed from the client id and a file id."""

    # All valid per-client file ids: 1 to 10 (inclusive).
    m_valid_file_ids = set(range(1, 11))
    # Number of files per client.  Kept for backward compatibility and derived
    # from the set above so that the two values cannot drift apart.
    file_count_per_id = len(m_valid_file_ids)

    def __init__(self, client_id, client_file_id):
        """Creates the file of the given client.

        Keyword Parameters:

        client_id
          The id of the client this file belongs to.

        client_file_id
          The per-client id of the file; it has to be one of
          ``m_valid_file_ids``.

        Raises an ``AssertionError`` when ``client_file_id`` is not valid.
        """
        # Raised explicitly instead of through an ``assert`` statement, so the
        # check is still performed when Python runs with ``-O``.
        if client_file_id not in self.m_valid_file_ids:
            raise AssertionError("Invalid client file id: %r" % (client_file_id,))
        # compute the unique file id on the fly
        self.id = (client_id - 1) * len(self.m_valid_file_ids) + client_file_id
        # copy client id
        self.client_id = client_id
        # generate path on the fly: "s<client_id>/<client_file_id>"
        self.path = os.path.join("s" + str(client_id), str(client_file_id))

    @staticmethod
    def from_file_id(file_id):
        """Returns the File object for a given file_id

        This is the inverse of the id computation done in the constructor.
        """
        # divmod performs a floor division; a plain ``/`` is a true division
        # on Python 3 and would produce a float client id (and with it paths
        # such as "s2.0/3").
        client_index, file_index = divmod(file_id - 1, len(File.m_valid_file_ids))
        return File(client_index + 1, file_index + 1)

    @staticmethod
    def from_path(path):
        """Returns the File object for a given path

        Only the last two components of the path are evaluated: the directory
        "s<client_id>" and the file name "<client_file_id>", of which any
        extension is ignored.

        Raises an ``AssertionError`` when the directory of the file does not
        start with 's'.
        """
        # split off the file name and strip its extension
        directory, file_name = os.path.split(path)
        file_stem = os.path.splitext(file_name)[0]
        # the client directory is the last component of what remains
        client_directory = os.path.split(directory)[1]
        # startswith also handles a missing directory, for which an index
        # access to the first character would raise an IndexError
        if not client_directory.startswith('s'):
            raise AssertionError(
                "The directory of path %r does not start with 's'" % (path,))
        return File(int(client_directory[1:]), int(file_stem))
def make_path(self, directory=None, extension=None):
"""Wraps the current path so that a complete path is formed
......
......@@ -27,8 +27,6 @@ class Database(object):
def __init__(self):
self.m_groups = ('world', 'dev')
self.m_purposes = ('enrol', 'probe')
self.m_client_ids = set(range(1, 41))
self.m_files = set(range(1, 11))
self.m_training_clients = set([1,2,5,6,10,11,12,14,16,17,20,21,24,26,27,29,33,34,36,39])
self.m_enrol_files = set([2,4,5,7,9])
......@@ -62,7 +60,7 @@ class Database(object):
if 'world' in groups:
ids |= self.m_training_clients
if 'dev' in groups:
ids |= self.m_client_ids - self.m_training_clients
ids |= Client.m_valid_client_ids - self.m_training_clients
return [Client(id) for id in ids]
......@@ -85,7 +83,7 @@ class Database(object):
if 'world' in groups:
ids |= self.m_training_clients
if 'dev' in groups:
ids |= self.m_client_ids - self.m_training_clients
ids |= Client.m_valid_client_ids - self.m_training_clients
return sorted(list(ids))
......@@ -122,7 +120,7 @@ class Database(object):
def get_client_id_from_file_id(self, file_id):
    """Returns the client id from the given image id

    Keyword Parameters:

    file_id
      The unique id of a File object of this database.
    """
    # File ids are assigned client by client, each client owning
    # len(File.m_valid_file_ids) consecutive ids.  Floor division is required:
    # a plain ``/`` is a true division on Python 3 and would return a float.
    return (file_id - 1) // len(File.m_valid_file_ids) + 1
def get_client_id_from_model_id(self, model_id):
......@@ -159,7 +157,7 @@ class Database(object):
ids = set(self.client_ids(groups))
# check the desired client ids for sanity
VALID_IDS = self.m_client_ids
VALID_IDS = Client.m_valid_client_ids
model_ids = self.__check_validity__(model_ids, "model", VALID_IDS, VALID_IDS)
# calculate the intersection between the ids and the desired client ids
......@@ -181,7 +179,7 @@ class Database(object):
retval.append(File(client_id, file_id))
if 'probe' in purposes:
file_ids = self.m_files - self.m_enrol_files
file_ids = File.m_valid_file_ids - self.m_enrol_files
# for probe, we use all clients of the given groups
for client_id in self.client_ids(groups):
for file_id in file_ids:
......@@ -190,3 +188,41 @@ class Database(object):
return retval
def paths(self, file_ids, prefix='', suffix=''):
    """Returns the full file paths for the given file ids, built with the
    given directory and extension

    Keyword Parameters:

    file_ids
      The list of ids of the File objects in the database.

    prefix
      The bit of path to be prepended to the filename stem

    suffix
      The extension that will be appended to the filename stem.

    Returns a list (that may be empty) of the fully constructed paths, one for
    each of the given file ids.
    """
    # all File objects are resolved first; afterwards their paths are generated
    return [
        file_object.make_path(prefix, suffix)
        for file_object in [File.from_file_id(file_id) for file_id in file_ids]
    ]
def reverse(self, paths):
    """Reverses the lookup: from the given paths to the ids of their files

    Keyword Parameters:

    paths
      The paths to look up.  This object should be a python iterable (such as
      a tuple or list)

    Returns a list (that may be empty) with one file id for each given path.
    """
    file_ids = []
    for path in paths:
        file_ids.append(File.from_path(path).id)
    return file_ids
......@@ -67,6 +67,21 @@ class ATNTDatabaseTest(unittest.TestCase):
f2 = db.objects(groups = 'dev', purposes = 'probe')
self.assertEqual(set([x.id for x in f]),set([x.id for x in f2]))
# test the path function
f = db.objects(groups='dev', purposes = 'enrol', model_ids = [7])
ids = [x.id for x in f]
paths = db.paths(ids, 'test', '.tmp')
self.assertEqual(len(f), len(paths))
for path in paths:
parts = os.path.split(path)
self.assertEqual(parts[0], os.path.join('test', 's7'))
self.assertEqual(os.path.splitext(parts[1])[1], '.tmp')
# test the reverse function
tested_ids = db.reverse(paths)
self.assertEqual(ids, tested_ids)
def test02_manage_dumplist_1(self):
from bob.db.script.dbmanage import main
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment