Commit 32b5f73c authored by Philip ABBET's avatar Philip ABBET

Update atnt/4, atvskeystroke/3, avspoof/3 to use the 'DatabaseTester' class of...

Update atnt/4, atvskeystroke/3, avspoof/3 to use the 'DatabaseTester' class of beat.backend.python v1.4.2
parent fb1bf9d4
......@@ -24,6 +24,7 @@
import numpy as np
import bob.io.base
import bob.io.image
import bob.db.atnt
......@@ -490,166 +491,83 @@ class Probes:
#----------------------------------------------------------
# Test the behavior of the views (on fake data)
if __name__ == '__main__':
# Install a mock load method for the images
def setup_tests():
# Install a mock load function for the images
def mock_load(root_folder):
return np.ndarray((92, 112), dtype=np.uint8)
bob.io.base.load = mock_load
# Mock output class
class MockOutput:
def __init__(self, name, connected):
self.name = name
self.connected = connected
self.last_written_data_index = -1
self.written_data = []
def write(self, data, end_data_index):
self.written_data.append(( self.last_written_data_index + 1, end_data_index, data ))
self.last_written_data_index = end_data_index
def isConnected(self):
return self.connected
# Tester utility class
from beat.backend.python.outputs import OutputList
import itertools
class Tester:
def __init__(self, name, view_class, outputs_declaration):
self.name = name
self.view_class = view_class
self.outputs_declaration = {}
self.determine_increments(outputs_declaration)
for L in range(0, len(self.outputs_declaration) + 1):
for subset in itertools.combinations(self.outputs_declaration.keys(), L):
self.run(subset)
def determine_increments(self, outputs_declaration):
outputs = OutputList()
for name in outputs_declaration:
outputs.add(MockOutput(name, True))
parameters = dict()
view = self.view_class()
view.setup('', outputs, parameters)
view.next()
print "View '%s', increments found:" % self.name
for output in outputs:
self.outputs_declaration[output.name] = output.last_written_data_index + 1
print ' - %s: %d' % (output.name, output.last_written_data_index + 1)
def run(self, connected_outputs):
if len(connected_outputs) == 0:
return
print "Testing '%s', with %d output(s): %s" % (self.name, len(connected_outputs),
', '.join(connected_outputs))
connected_outputs = dict([ x for x in self.outputs_declaration.items()
if x[0] in connected_outputs ])
not_connected_outputs = dict([ x for x in self.outputs_declaration.items()
if x[0] not in connected_outputs ])
outputs = OutputList()
for name in self.outputs_declaration.keys():
outputs.add(MockOutput(name, name in connected_outputs))
parameters = dict()
view = self.view_class()
view.setup('', outputs, parameters)
biggest_index_increment = max(connected_outputs.values())
biggest_index_output = [ x[0] for x in connected_outputs.items()
if x[1] == biggest_index_increment ][0]
lowest_index_increment = min(connected_outputs.values())
lowest_index_output = [ x[0] for x in connected_outputs.items()
if x[1] == lowest_index_increment ][0]
next_expected_indices = {}
for name, increment in connected_outputs.items():
next_expected_indices[name] = 0
current_index = 0
while not(view.done(outputs[lowest_index_output].last_written_data_index)):
view.next()
for name in connected_outputs.keys():
assert(outputs[name].written_data[-1][0] == next_expected_indices[name])
assert(outputs[name].written_data[-1][1] == next_expected_indices[name] + connected_outputs[name] - 1)
for name in not_connected_outputs.keys():
assert(len(outputs[name].written_data) == 0)
current_index += lowest_index_increment
for name in connected_outputs.keys():
if current_index == next_expected_indices[name] + connected_outputs[name]:
next_expected_indices[name] += connected_outputs[name]
for name in connected_outputs.keys():
assert(len(outputs[name].written_data) == current_index / connected_outputs[name])
# The actual tests
Tester('Train', Train, [
'client_id',
'file_id',
'image',
])
Tester('Train (with eye centers)', Train, [
'client_id',
'file_id',
'eye_centers',
'image',
])
#----------------------------------------------------------
Tester('Templates', Templates, [
'template_id',
'client_id',
'file_id',
'image',
])
Tester('Templates (with eye centers)', Templates, [
'template_id',
'client_id',
'file_id',
'eye_centers',
'image',
])
# Test the behavior of the views (on fake data)
if __name__ == '__main__':
Tester('Probes', Probes, [
'template_ids',
'client_id',
'probe_id',
'file_id',
'image',
])
Tester('Probes (with eye centers)', Probes, [
'template_ids',
'client_id',
'probe_id',
'file_id',
'eye_centers',
'image',
])
setup_tests()
from beat.backend.python.database import DatabaseTester
DatabaseTester('Train', Train,
[
'client_id',
'file_id',
'image',
],
parameters=dict(),
)
DatabaseTester('Train (with eye centers)', Train,
[
'client_id',
'file_id',
'eye_centers',
'image',
],
parameters=dict(),
)
DatabaseTester('Templates', Templates,
[
'template_id',
'client_id',
'file_id',
'image',
],
parameters=dict(),
)
DatabaseTester('Templates (with eye centers)', Templates,
[
'template_id',
'client_id',
'file_id',
'eye_centers',
'image',
],
parameters=dict(),
)
DatabaseTester('Probes', Probes,
[
'template_ids',
'client_id',
'probe_id',
'file_id',
'image',
],
parameters=dict(),
)
DatabaseTester('Probes (with eye centers)', Probes,
[
'template_ids',
'client_id',
'probe_id',
'file_id',
'eye_centers',
'image',
],
parameters=dict(),
)
......@@ -25,9 +25,9 @@ The AT&T Database of Faces
Changelog
=========
* **Version 4**, 27/Oct/2017:
* **Version 4**, 30/Oct/2017:
- Port to beat.backend.python v1.4.1
- Port to beat.backend.python v1.4.2
* **Version 3**, 20/Jan/2016:
......
......@@ -76,17 +76,17 @@ def get_client_end_index(objs, client_id, client_start_index,
#----------------------------------------------------------
def get_template_end_index(objs, template_id, template_start_index,
start_index, end_index):
template_end_index = template_start_index
def get_value_end_index(objs, value, index_in_tuple, value_start_index,
start_index, end_index):
value_end_index = value_start_index
while template_end_index + 1 <= end_index:
id = objs[template_end_index + 1 - start_index][0]
while value_end_index + 1 <= end_index:
id = objs[value_end_index + 1 - start_index][index_in_tuple]
if id != template_id:
return template_end_index
if id != value:
return value_end_index
template_end_index += 1
value_end_index += 1
return end_index
......@@ -171,10 +171,10 @@ class Templates:
if self.outputs['template_id'].isConnected() and \
self.outputs['template_id'].last_written_data_index < self.next_index:
template_end_index = get_template_end_index(self.objs, template_id,
self.next_index,
self.start_index,
self.end_index)
template_end_index = get_value_end_index(self.objs, template_id, 0,
self.next_index,
self.start_index,
self.end_index)
self.outputs['template_id'].write(
{
......@@ -271,20 +271,32 @@ class Probes:
# Open the database and load the objects to provide via the outputs
self.db = bob.db.atvskeystroke.Database()
template_ids = sorted(self.db.model_ids(groups='eval',
protocol=parameters['protocol']),
template_ids = sorted(self.db.model_ids(protocol=parameters['protocol'],
groups='eval'),
key=lambda x: int(x))
self.objs = []
template_probes = {}
for template_id in template_ids:
objs = sorted(self.db.objects(groups='eval',
protocol=self.parameters['protocol'],
objs = sorted(self.db.objects(protocol=parameters['protocol'],
groups='eval',
purposes='probe',
model_ids=[template_id]),
key=lambda x: (x.client_id, x.id))
key=lambda x: (x.client_id, x.id))
self.objs.extend([ (template_id, obj) for obj in objs ])
template_probes[template_id] = [ p.id for p in objs ]
objs = sorted(self.db.objects(protocol=parameters['protocol'],
groups='eval',
purposes='probe'),
key=lambda x: (x.client_id, x.id))
self.objs = []
for obj in objs:
templates = [ template_id for template_id in template_ids
if obj.id in template_probes[template_id] ]
self.objs.append( (templates, obj) )
self.objs = sorted(self.objs, key=lambda x: (x[0], x[1].client_id, x[1].id))
# Determine the range of indices that must be provided
self.start_index = force_start_index if force_start_index is not None else 0
......@@ -302,23 +314,23 @@ class Probes:
def next(self):
(template_id, obj) = self.objs[self.next_index - self.start_index]
(template_ids, obj) = self.objs[self.next_index - self.start_index]
# Output: template_ids (only provide data when the template_ids change)
if self.outputs['template_ids'].isConnected() and \
self.outputs['template_ids'].last_written_data_index < self.next_index:
template_end_index = get_template_end_index(self.objs, template_id,
self.next_index,
self.start_index,
self.end_index)
template_ids_end_index = get_value_end_index(self.objs, template_ids, 0,
self.next_index,
self.start_index,
self.end_index)
self.outputs['template_ids'].write(
{
'text': [ str(template_id) ]
'text': [ str(x) for x in template_ids ]
},
template_end_index
template_ids_end_index
)
......@@ -377,136 +389,26 @@ class Probes:
#----------------------------------------------------------
# Test the behavior of the views (on fake data)
if __name__ == '__main__':
# Install a mock load method for the keystrokes
def setup_tests():
# Install a mock load function for the keystrokes
def mock_keystroke_reader(filename):
return {}
global keystroke_reader
keystroke_reader = mock_keystroke_reader
# Mock output class
class MockOutput:
def __init__(self, name, connected):
self.name = name
self.connected = connected
self.last_written_data_index = -1
self.written_data = []
def write(self, data, end_data_index):
self.written_data.append(( self.last_written_data_index + 1, end_data_index, data ))
self.last_written_data_index = end_data_index
def isConnected(self):
return self.connected
# Tester utility class
from beat.backend.python.outputs import OutputList
import itertools
class Tester:
def __init__(self, name, view_class, outputs_declaration, parameters,
irregular_outputs=[]):
self.name = name
self.view_class = view_class
self.outputs_declaration = {}
self.parameters = parameters
self.irregular_outputs = irregular_outputs
self.determine_increments(outputs_declaration)
for L in range(0, len(self.outputs_declaration) + 1):
for subset in itertools.combinations(self.outputs_declaration.keys(), L):
self.run(subset)
def determine_increments(self, outputs_declaration):
outputs = OutputList()
for name in outputs_declaration:
outputs.add(MockOutput(name, True))
view = self.view_class()
view.setup('', outputs, self.parameters)
view.next()
print "View '%s', increments found:" % self.name
for output in outputs:
self.outputs_declaration[output.name] = output.last_written_data_index + 1
print ' - %s: %d' % (output.name, output.last_written_data_index + 1)
def run(self, connected_outputs):
if len(connected_outputs) == 0:
return
print "Testing '%s', with %d output(s): %s" % (self.name, len(connected_outputs),
', '.join(connected_outputs))
connected_outputs = dict([ x for x in self.outputs_declaration.items()
if x[0] in connected_outputs ])
not_connected_outputs = dict([ x for x in self.outputs_declaration.items()
if x[0] not in connected_outputs ])
outputs = OutputList()
for name in self.outputs_declaration.keys():
outputs.add(MockOutput(name, name in connected_outputs))
parameters = dict()
view = self.view_class()
view.setup('', outputs, self.parameters)
next_expected_indices = {}
for name, increment in connected_outputs.items():
next_expected_indices[name] = 0
next_index = 0
def _done():
for output in outputs:
if output.isConnected() and not view.done(output.last_written_data_index):
return False
return True
while not(_done()):
view.next()
for name in connected_outputs.keys():
if name not in self.irregular_outputs:
assert(outputs[name].written_data[-1][0] == next_expected_indices[name])
assert(outputs[name].written_data[-1][1] == next_expected_indices[name] + connected_outputs[name] - 1)
else:
assert(outputs[name].written_data[-1][0] == next_expected_indices[name])
assert(outputs[name].written_data[-1][1] >= next_expected_indices[name])
for name in not_connected_outputs.keys():
assert(len(outputs[name].written_data) == 0)
#----------------------------------------------------------
next_index = 1 + min([ x.written_data[-1][1] for x in outputs if x.isConnected() ])
for name in connected_outputs.keys():
if name not in self.irregular_outputs:
if next_index == next_expected_indices[name] + connected_outputs[name]:
next_expected_indices[name] += connected_outputs[name]
else:
if next_index > outputs[name].written_data[-1][1]:
next_expected_indices[name] = outputs[name].written_data[-1][1] + 1
# Test the behavior of the views (on fake data)
if __name__ == '__main__':
for name in connected_outputs.keys():
if name not in self.irregular_outputs:
assert(len(outputs[name].written_data) == next_index / connected_outputs[name])
else:
print " Irregular output '%s': %s" % (name, str([ (x[0], x[1]) for x in outputs[name].written_data ]))
setup_tests()
from beat.backend.python.database import DatabaseTester
# The actual tests
Tester('Templates', Templates,
DatabaseTester('Templates', Templates,
[
'template_id',
'client_id',
......@@ -518,7 +420,7 @@ if __name__ == '__main__':
)
)
Tester('Probes', Probes,
DatabaseTester('Probes', Probes,
[
'template_ids',
'client_id',
......
......@@ -25,9 +25,9 @@ The ATVS-Keystroke Database
Changelog
=========
* **Version 3**, 27/Oct/2017:
* **Version 3**, 30/Oct/2017:
- Port to beat.backend.python v1.4.1
- Port to beat.backend.python v1.4.2
* **Version 2**, 26/Jan/2016:
......
......@@ -656,136 +656,25 @@ class SimpleAntispoofing:
#----------------------------------------------------------
# Test the behavior of the views (on fake data)
if __name__ == '__main__':
# Install a mock read method for the audio files
def setup_tests():
# Install a mock read function for the audio files
def mock_read(filename):
return 44100, np.ndarray((128,))
scipy.io.wavfile.read = mock_read
# Mock output class
class MockOutput:
def __init__(self, name, connected):
self.name = name
self.connected = connected
self.last_written_data_index = -1
self.written_data = []
def write(self, data, end_data_index):
self.written_data.append(( self.last_written_data_index + 1, end_data_index, data ))
self.last_written_data_index = end_data_index
def isConnected(self):
return self.connected
# Tester utility class
from beat.backend.python.outputs import OutputList
import itertools
class Tester:
def __init__(self, name, view_class, outputs_declaration, parameters,
irregular_outputs=[]):
self.name = name
self.view_class = view_class
self.outputs_declaration = {}
self.parameters = parameters
self.irregular_outputs = irregular_outputs
self.determine_increments(outputs_declaration)
for L in range(0, len(self.outputs_declaration) + 1):
for subset in itertools.combinations(self.outputs_declaration.keys(), L):
self.run(subset)
def determine_increments(self, outputs_declaration):
outputs = OutputList()
for name in outputs_declaration:
outputs.add(MockOutput(name, True))
view = self.view_class()
view.setup('', outputs, self.parameters)
view.next()
print "View '%s', increments found:" % self.name
for output in outputs:
self.outputs_declaration[output.name] = output.last_written_data_index + 1
print ' - %s: %d' % (output.name, output.last_written_data_index + 1)
def run(self, connected_outputs):
if len(connected_outputs) == 0:
return
print "Testing '%s', with %d output(s): %s" % (self.name, len(connected_outputs),
', '.join(connected_outputs))
connected_outputs = dict([ x for x in self.outputs_declaration.items()
if x[0] in connected_outputs ])
not_connected_outputs = dict([ x for x in self.outputs_declaration.items()
if x[0] not in connected_outputs ])
outputs = OutputList()
for name in self.outputs_declaration.keys():
outputs.add(MockOutput(name, name in connected_outputs))
parameters = dict()
view = self.view_class()
view.setup('', outputs, self.parameters)
next_expected_indices = {}
for name, increment in connected_outputs.items():
next_expected_indices[name] = 0
next_index = 0
def _done():
for output in outputs:
if output.isConnected() and not view.done(output.last_written_data_index):
return False
return True
while not(_done()):
view.next()
for name in connected_outputs.keys():
if name not in self.irregular_outputs:
assert(outputs[name].written_data[-1][0] == next_expected_indices[name])
assert(outputs[name].written_data[-1][1] == next_expected_indices[name] + connected_outputs[name] - 1)
else:
assert(outputs[name].written_data[-1][0] == next_expected_indices[name])
assert(outputs[name].written_data[-1][1] >= next_expected_indices[name])
for name in not_connected_outputs.keys():
assert(len(outputs[name].written_data) == 0)
#----------------------------------------------------------
next_index = 1 + min([ x.written_data[-1][1] for x in outputs if x.isConnected() ])
for name in connected_outputs.keys():
if name not in self.irregular_outputs:
if next_index == next_expected_indices[name] + connected_outputs[name]:
next_expected_indices[name] += connected_outputs[name]
else:
if next_index > outputs[name].written_data[-1][1]:
next_expected_indices[name] = outputs[name].written_data[-1][1] + 1
# Test the behavior of the views (on fake data)
if __name__ == '__main__':
for name in connected_outputs.keys():
if name not in self.irregular_outputs:
assert(len(outputs[name].written_data) == next_index / connected_outputs[name])
else:
print " Irregular output '%s': %s" % (name, str([ (x[0], x[1]) for x in outputs[name].written_data ]))
setup_tests()
from beat.backend.python.database import DatabaseTester
# The actual tests
Tester('RecognitionTraining', RecognitionTraining,
DatabaseTester('RecognitionTraining', RecognitionTraining,
[
'client_id',
'file_id',
......@@ -800,7 +689,7 @@ if __name__ == '__main__':
]
)
Tester('RecognitionTemplates', RecognitionTemplates,
DatabaseTester('RecognitionTemplates', RecognitionTemplates,
[
'template_id',
'client_id',
......@@ -818,7 +707,7 @@ if __name__ == '__main__':
]
)
Tester('Probes (with probe_id)', Probes,
DatabaseTester('Probes (with probe_id)', Probes,
[
'template_ids',
'client_id',
......@@ -836,7 +725,7 @@ if __name__ == '__main__':