Commit 2417e7b2 authored by Amir MOHAMMADI

Add an option to filter samples in queries

This allows protocols to be customized, e.g. to get Idiap-only files.
See the tests for an example.
Fixes #3
parent 65ba6e80
Pipeline #37015 passed with stages
in 14 minutes and 33 seconds
......@@ -43,8 +43,12 @@ class Database(bob.bio.base.database.FileListBioDatabase, SwanVideoDatabase):
)
def objects(self, groups=None, protocol=None, purposes=None,
model_ids=None, classes=None, **kwargs):
model_ids=None, classes=None, filter_samples=None, **kwargs):
files = super(Database, self).objects(
groups=groups, protocol=protocol, purposes=purposes,
model_ids=model_ids, classes=classes, **kwargs)
return self.update_files(files)
files = self.update_files(files)
if filter_samples is None:
return files
files = list(filter(filter_samples, files))
return files
......@@ -40,8 +40,12 @@ class Database(FileListPadDatabase, SwanVideoDatabase):
)
def objects(self, groups=None, protocol=None, purposes=None,
model_ids=None, classes=None, **kwargs):
model_ids=None, classes=None, filter_samples=None, **kwargs):
files = super(Database, self).objects(
groups=groups, protocol=protocol, purposes=purposes,
model_ids=model_ids, classes=classes, **kwargs)
return self.update_files(files)
files = self.update_files(files)
if filter_samples is None:
return files
files = list(filter(filter_samples, files))
return files
......@@ -4,7 +4,6 @@
"""Test Units
"""
from .query_bio import Database
import logging
logger = logging.getLogger(__name__)
......@@ -45,69 +44,21 @@ def _test_annotation(db, files):
"missing?", exc_info=True)
# def test_idiap0_voice():
# protocol = 'idiap0-voice-bio'
# db = Database(protocol=protocol)
def test_pad_protocols():
from .query_pad import Database
# files = db.objects(protocol=protocol, groups='world')
# # 20 clients, 8 recordings, (2 devices in session 1 and 1 device in
# # sessions 2-6) == like it is 1 device and 7 sessions
# _test_numbers(files, 20 * 8 * 1 * 7, 20, 8, 2, 6, range(1, 7), ['IDIAP'])
# assert all(int(f.client.id_in_site) < 25 for f in files)
protocol = 'pad_p2_face_f1'
db = Database(protocol=protocol)
# files = db.objects(protocol=protocol, groups='dev', purposes='enroll')
# _test_numbers(files, 15 * 8 * 2 * 1, 15, 8, 2, 1, range(1, 2), ['IDIAP'])
# assert all(int(f.client.id_in_site) >=
# 25 and int(f.client.id_in_site) < 41 for f in files)
bf, pa = db.all_files(groups='train')
assert len(bf) == 750, len(bf)
assert len(pa) == 1251, len(pa)
# files = db.objects(protocol=protocol, groups='dev', purposes='probe')
# _test_numbers(files, 15 * 8 * 1 * 5, 15, 8, 1, 5, range(2, 7), ['IDIAP'])
# assert all(int(f.client.id_in_site) >=
# 25 and int(f.client.id_in_site) < 41 for f in files)
# check the filter argument
def filter_samples(sample):
return "IDIAP" in sample.client_id
# files = db.objects(protocol=protocol, groups='eval', purposes='enroll')
# _test_numbers(files, 15 * 8 * 2 * 1, 15, 8, 2, 1, range(1, 2), ['IDIAP'])
# assert all(int(f.client.id_in_site) >=
# 41 and int(f.client.id_in_site) < 61 for f in files)
# files = db.objects(protocol=protocol, groups='eval', purposes='probe')
# _test_numbers(files, 15 * 8 * 1 * 5, 15, 8, 1, 5, range(2, 7), ['IDIAP'])
# assert all(int(f.client.id_in_site) >=
# 41 and int(f.client.id_in_site) < 61 for f in files)
# model_ids = db.model_ids_with_protocol(groups='world', protocol=protocol)
# assert len(model_ids) == 20, len(model_ids)
# model_ids = db.model_ids_with_protocol(groups='dev', protocol=protocol)
# assert len(model_ids) == 15, len(model_ids)
# model_ids = db.model_ids_with_protocol(groups='eval', protocol=protocol)
# assert len(model_ids) == 15, len(model_ids)
# _test_annotation(db, files)
# def test_grandtest0_voice():
# protocol = 'grandtest0-voice-bio'
# db = Database(protocol=protocol)
# files = db.objects(protocol=protocol, groups='dev', purposes='enroll')
# _test_numbers(files, 56 * 8 * 1 * 1, 56, 8, 1, 1,
# [2], ['IDIAP', 'MPH-FRA'])
# files = db.objects(protocol=protocol, groups='dev', purposes='probe')
# _test_numbers(files, 2640, 56, 8, 2, 5,
# [1, 3, 4, 5, 6], ['IDIAP', 'MPH-FRA'])
# files = db.objects(protocol=protocol, groups='eval', purposes='enroll')
# _test_numbers(files, 94 * 8 * 1 * 1, 94, 8, 1, 1,
# [2], ['MPH-IND', 'NTNU'])
# files = db.objects(protocol=protocol, groups='eval', purposes='probe')
# _test_numbers(files, 4512, 94, 8, 2, 5,
# [1, 3, 4, 5, 6], ['MPH-IND', 'NTNU'])
# model_ids = db.model_ids_with_protocol(groups='dev', protocol=protocol)
# assert len(model_ids) == 56, len(model_ids)
# model_ids = db.model_ids_with_protocol(groups='eval', protocol=protocol)
# assert len(model_ids) == 94, len(model_ids)
# _test_annotation(db, files)
db.all_files_options = dict(filter_samples=filter_samples)
bf, pa = db.all_files(groups='train')
assert len(bf) == 230, len(bf)
assert len(pa) == 391, len(pa)
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment