Commit 6caf44ab authored by André Anjos's avatar André Anjos 💬
Browse files

Tests are now passing

parent f934dbf1
......@@ -96,7 +96,7 @@ class Database(bob.db.base.SQLiteDatabase):
"""
if 'train' in groups:
if groups and 'train' in groups:
# there are no models in the training set
if len(groups) == 1: return [] #only group required, so return empty
groups = tuple(k for k in groups if k != 'train')
......@@ -113,9 +113,9 @@ class Database(bob.db.base.SQLiteDatabase):
retval = retval.filter(Protocol.name.in_(protocols))
if groups:
filters.append(Model.group.in_(groups))
retval = retval.filter(Model.group.in_(groups))
retval = retval.filter(*filters).distinct().order_by('id')
retval = retval.distinct().order_by('id')
return [k.id for k in retval]
......@@ -243,10 +243,10 @@ class Database(bob.db.base.SQLiteDatabase):
retval = q if retval is None else retval.union(q)
if 'probe' in purposes:
q = self.query(File).join(Probe.file).join(Model.protocol)
q = self.query(File).join(Probe.file).join(Protocol)
q = q.join(Finger).join(Client)
q = q.filter(Probe.group.in_(groups))
q = q.filter(File.protocol.name.in_(protocols))
q = q.filter(Protocol.name.in_(protocols))
q = q.filter(Client.gender.in_(genders))
q = q.filter(Finger.side.in_(sides))
q = q.filter(Finger.name.in_(fingers))
......
......@@ -48,6 +48,7 @@ def db_available(test):
return wrapper
@nose.tools.nottest
@metadata_available
def test_recreate():
......@@ -56,83 +57,58 @@ def test_recreate():
@metadata_available
def test_counts():
def test_basic_queries():
# test whether the correct number of clients is returned
db = Database()
nose.tools.eq_(db.groups(), ('train', 'dev', 'eval'))
protocols = db.protocol_names()
nose.tools.eq_(len(protocols), 1)
assert 'central' in protocols
nose.tools.eq_(db.groups(), ('train', 'dev', 'eval'))
nose.tools.eq_(db.purposes(), ('train', 'enroll', 'probe'))
nose.tools.eq_(db.genders(), ('m', 'f'))
nose.tools.eq_(db.sides(), ('l', 'r'))
nose.tools.eq_(db.fingers(), ('t', 'i', 'm', 'r', 'l'))
@metadata_available
def test_central():
# test whether the correct number of clients is returned
db = Database()
# FDV: 89 subjects * 2 fingers * 5 snapshots * 1 attempt = 890
# IDI: 2 subjects * 6 fingers * 2 snapshots = 48
# Total: 938 images
nose.tools.eq_(len(db.objects(protocol='central', groups='train')), 938)
train_samples = db.objects(protocol='central', groups='train')
nose.tools.eq_(len(train_samples), 938)
# IDI: 50 subjects * 6 fingers * 2 snapshots * 2 attempts = 1200 images
dev_enroll_samples = db.objects(protocol='central', groups='dev',
purposes='enroll')
nose.tools.eq_(len(dev_enroll_samples), 1200)
model_ids = db.model_ids(protocol='central')
nose.tools.eq_(len(dev_enroll_samples), len(model_ids))
# IDI: 50 subjects * 6 fingers * 2 snapshots * 2 attempts * 2 sessions
# = 2400 images
dev_probe_samples = db.objects(protocol='central', groups='dev',
purposes='probe')
nose.tools.eq_(len(dev_probe_samples), 2400)
# filtering by model ids on probes, returns all
nose.tools.eq_(len(db.objects(protocol='central', groups='dev',
purposes='probe', model_ids = model_ids[0])), 2400)
# 1 image per model
# tests that we can filter by model ids
nose.tools.eq_(len(db.objects(protocol='central', groups='dev',
purposes='enroll')), 1200)
# test model ids
model_ids = db.model_ids()
nose.tools.eq_(len(model_ids), 440)
model_ids = db.model_ids(protocol='Nom')
nose.tools.eq_(len(model_ids), 220)
model_ids = db.model_ids(protocol='Fifty')
nose.tools.eq_(len(model_ids), 100)
model_ids = db.model_ids(protocol='B')
nose.tools.eq_(len(model_ids), 216)
model_ids = db.model_ids(protocol='Full')
nose.tools.eq_(len(model_ids), 440)
# test database sizes
nose.tools.eq_(len(db.objects(protocol='Nom', groups='train')), 0)
nose.tools.eq_(len(db.objects(protocol='Nom', groups='dev')), 440)
nose.tools.eq_(len(db.objects(protocol='Nom', groups='dev',
purposes='enroll')), 220)
nose.tools.eq_(len(db.objects(protocol='Nom', groups='dev',
purposes='probe')), 220)
nose.tools.eq_(len(db.objects(protocol='Fifty', groups='train')), 240)
nose.tools.eq_(len(db.objects(protocol='Fifty', groups='dev')), 200)
nose.tools.eq_(len(db.objects(protocol='Fifty', groups='dev',
purposes='enroll')), 100)
nose.tools.eq_(len(db.objects(protocol='Fifty', groups='dev',
purposes='probe')), 100)
nose.tools.eq_(len(db.objects(protocol='B', groups='train')), 224)
nose.tools.eq_(len(db.objects(protocol='B', groups='dev')), 216)
nose.tools.eq_(len(db.objects(protocol='B', groups='dev',
purposes='enroll')), 216)
nose.tools.eq_(len(db.objects(protocol='B', groups='dev',
purposes='probe')), 216)
nose.tools.eq_(len(db.objects(protocol='Full', groups='train')), 0)
nose.tools.eq_(len(db.objects(protocol='Full', groups='dev')), 440)
nose.tools.eq_(len(db.objects(protocol='Full', groups='dev',
purposes='enroll')), 440)
nose.tools.eq_(len(db.objects(protocol='Full', groups='dev',
purposes='probe')), 440)
# make sure that we can filter by model ids
nose.tools.eq_(len(db.objects(protocol='Full', groups='dev',
purposes='enroll', model_ids=model_ids[:10])), 10)
# filtering by model ids on probes, returns all
nose.tools.eq_(len(db.objects(protocol='Full', groups='dev',
purposes='probe', model_ids=model_ids[0])), 440)
# check file ids for train, dev enroll and dev probe are exclusive
assert len(set(train_samples+dev_enroll_samples+dev_probe_samples)) == 4538
@nose.tools.nottest
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment