Commit 32b5f73c authored by Philip ABBET's avatar Philip ABBET

Update atnt/4, atvskeystroke/3, avspoof/3 to use the 'DatabaseTester' class of...

Update atnt/4, atvskeystroke/3, avspoof/3 to use the 'DatabaseTester' class of beat.backend.python v1.4.2
parent fb1bf9d4
......@@ -24,6 +24,7 @@
import numpy as np
import bob.io.base
import bob.io.image
import bob.db.atnt
......@@ -490,166 +491,83 @@ class Probes:
#----------------------------------------------------------
# Test the behavior of the views (on fake data)
if __name__ == '__main__':
# Install a mock load method for the images
def setup_tests():
# Install a mock load function for the images
def mock_load(root_folder):
return np.ndarray((92, 112), dtype=np.uint8)
bob.io.base.load = mock_load
# Mock output class
class MockOutput:
def __init__(self, name, connected):
self.name = name
self.connected = connected
self.last_written_data_index = -1
self.written_data = []
def write(self, data, end_data_index):
self.written_data.append(( self.last_written_data_index + 1, end_data_index, data ))
self.last_written_data_index = end_data_index
def isConnected(self):
return self.connected
# Tester utility class
from beat.backend.python.outputs import OutputList
import itertools
class Tester:
def __init__(self, name, view_class, outputs_declaration):
self.name = name
self.view_class = view_class
self.outputs_declaration = {}
self.determine_increments(outputs_declaration)
for L in range(0, len(self.outputs_declaration) + 1):
for subset in itertools.combinations(self.outputs_declaration.keys(), L):
self.run(subset)
def determine_increments(self, outputs_declaration):
outputs = OutputList()
for name in outputs_declaration:
outputs.add(MockOutput(name, True))
parameters = dict()
view = self.view_class()
view.setup('', outputs, parameters)
view.next()
print "View '%s', increments found:" % self.name
for output in outputs:
self.outputs_declaration[output.name] = output.last_written_data_index + 1
print ' - %s: %d' % (output.name, output.last_written_data_index + 1)
def run(self, connected_outputs):
if len(connected_outputs) == 0:
return
print "Testing '%s', with %d output(s): %s" % (self.name, len(connected_outputs),
', '.join(connected_outputs))
connected_outputs = dict([ x for x in self.outputs_declaration.items()
if x[0] in connected_outputs ])
not_connected_outputs = dict([ x for x in self.outputs_declaration.items()
if x[0] not in connected_outputs ])
outputs = OutputList()
for name in self.outputs_declaration.keys():
outputs.add(MockOutput(name, name in connected_outputs))
parameters = dict()
view = self.view_class()
view.setup('', outputs, parameters)
biggest_index_increment = max(connected_outputs.values())
biggest_index_output = [ x[0] for x in connected_outputs.items()
if x[1] == biggest_index_increment ][0]
lowest_index_increment = min(connected_outputs.values())
lowest_index_output = [ x[0] for x in connected_outputs.items()
if x[1] == lowest_index_increment ][0]
next_expected_indices = {}
for name, increment in connected_outputs.items():
next_expected_indices[name] = 0
current_index = 0
while not(view.done(outputs[lowest_index_output].last_written_data_index)):
view.next()
for name in connected_outputs.keys():
assert(outputs[name].written_data[-1][0] == next_expected_indices[name])
assert(outputs[name].written_data[-1][1] == next_expected_indices[name] + connected_outputs[name] - 1)
for name in not_connected_outputs.keys():
assert(len(outputs[name].written_data) == 0)
#----------------------------------------------------------
current_index += lowest_index_increment
for name in connected_outputs.keys():
if current_index == next_expected_indices[name] + connected_outputs[name]:
next_expected_indices[name] += connected_outputs[name]
# Test the behavior of the views (on fake data)
if __name__ == '__main__':
for name in connected_outputs.keys():
assert(len(outputs[name].written_data) == current_index / connected_outputs[name])
setup_tests()
from beat.backend.python.database import DatabaseTester
# The actual tests
Tester('Train', Train, [
DatabaseTester('Train', Train,
[
'client_id',
'file_id',
'image',
])
],
parameters=dict(),
)
Tester('Train (with eye centers)', Train, [
DatabaseTester('Train (with eye centers)', Train,
[
'client_id',
'file_id',
'eye_centers',
'image',
])
],
parameters=dict(),
)
Tester('Templates', Templates, [
DatabaseTester('Templates', Templates,
[
'template_id',
'client_id',
'file_id',
'image',
])
],
parameters=dict(),
)
Tester('Templates (with eye centers)', Templates, [
DatabaseTester('Templates (with eye centers)', Templates,
[
'template_id',
'client_id',
'file_id',
'eye_centers',
'image',
])
],
parameters=dict(),
)
Tester('Probes', Probes, [
DatabaseTester('Probes', Probes,
[
'template_ids',
'client_id',
'probe_id',
'file_id',
'image',
])
],
parameters=dict(),
)
Tester('Probes (with eye centers)', Probes, [
DatabaseTester('Probes (with eye centers)', Probes,
[
'template_ids',
'client_id',
'probe_id',
'file_id',
'eye_centers',
'image',
])
],
parameters=dict(),
)
......@@ -25,9 +25,9 @@ The AT&T Database of Faces
Changelog
=========
* **Version 4**, 27/Oct/2017:
* **Version 4**, 30/Oct/2017:
- Port to beat.backend.python v1.4.1
- Port to beat.backend.python v1.4.2
* **Version 3**, 20/Jan/2016:
......
This diff is collapsed.
......@@ -25,9 +25,9 @@ The ATVS-Keystroke Database
Changelog
=========
* **Version 3**, 27/Oct/2017:
* **Version 3**, 30/Oct/2017:
- Port to beat.backend.python v1.4.1
- Port to beat.backend.python v1.4.2
* **Version 2**, 26/Jan/2016:
......
......@@ -656,136 +656,25 @@ class SimpleAntispoofing:
#----------------------------------------------------------
# Test the behavior of the views (on fake data)
if __name__ == '__main__':
# Install a mock read method for the audio files
def setup_tests():
# Install a mock read function for the audio files
def mock_read(filename):
return 44100, np.ndarray((128,))
scipy.io.wavfile.read = mock_read
# Mock output class
class MockOutput:
def __init__(self, name, connected):
self.name = name
self.connected = connected
self.last_written_data_index = -1
self.written_data = []
def write(self, data, end_data_index):
self.written_data.append(( self.last_written_data_index + 1, end_data_index, data ))
self.last_written_data_index = end_data_index
def isConnected(self):
return self.connected
# Tester utility class
from beat.backend.python.outputs import OutputList
import itertools
class Tester:
def __init__(self, name, view_class, outputs_declaration, parameters,
irregular_outputs=[]):
self.name = name
self.view_class = view_class
self.outputs_declaration = {}
self.parameters = parameters
self.irregular_outputs = irregular_outputs
self.determine_increments(outputs_declaration)
for L in range(0, len(self.outputs_declaration) + 1):
for subset in itertools.combinations(self.outputs_declaration.keys(), L):
self.run(subset)
def determine_increments(self, outputs_declaration):
outputs = OutputList()
for name in outputs_declaration:
outputs.add(MockOutput(name, True))
view = self.view_class()
view.setup('', outputs, self.parameters)
view.next()
print "View '%s', increments found:" % self.name
for output in outputs:
self.outputs_declaration[output.name] = output.last_written_data_index + 1
print ' - %s: %d' % (output.name, output.last_written_data_index + 1)
def run(self, connected_outputs):
if len(connected_outputs) == 0:
return
print "Testing '%s', with %d output(s): %s" % (self.name, len(connected_outputs),
', '.join(connected_outputs))
connected_outputs = dict([ x for x in self.outputs_declaration.items()
if x[0] in connected_outputs ])
not_connected_outputs = dict([ x for x in self.outputs_declaration.items()
if x[0] not in connected_outputs ])
outputs = OutputList()
for name in self.outputs_declaration.keys():
outputs.add(MockOutput(name, name in connected_outputs))
parameters = dict()
view = self.view_class()
view.setup('', outputs, self.parameters)
next_expected_indices = {}
for name, increment in connected_outputs.items():
next_expected_indices[name] = 0
next_index = 0
def _done():
for output in outputs:
if output.isConnected() and not view.done(output.last_written_data_index):
return False
return True
while not(_done()):
view.next()
for name in connected_outputs.keys():
if name not in self.irregular_outputs:
assert(outputs[name].written_data[-1][0] == next_expected_indices[name])
assert(outputs[name].written_data[-1][1] == next_expected_indices[name] + connected_outputs[name] - 1)
else:
assert(outputs[name].written_data[-1][0] == next_expected_indices[name])
assert(outputs[name].written_data[-1][1] >= next_expected_indices[name])
for name in not_connected_outputs.keys():
assert(len(outputs[name].written_data) == 0)
#----------------------------------------------------------
next_index = 1 + min([ x.written_data[-1][1] for x in outputs if x.isConnected() ])
for name in connected_outputs.keys():
if name not in self.irregular_outputs:
if next_index == next_expected_indices[name] + connected_outputs[name]:
next_expected_indices[name] += connected_outputs[name]
else:
if next_index > outputs[name].written_data[-1][1]:
next_expected_indices[name] = outputs[name].written_data[-1][1] + 1
# Test the behavior of the views (on fake data)
if __name__ == '__main__':
for name in connected_outputs.keys():
if name not in self.irregular_outputs:
assert(len(outputs[name].written_data) == next_index / connected_outputs[name])
else:
print " Irregular output '%s': %s" % (name, str([ (x[0], x[1]) for x in outputs[name].written_data ]))
setup_tests()
from beat.backend.python.database import DatabaseTester
# The actual tests
Tester('RecognitionTraining', RecognitionTraining,
DatabaseTester('RecognitionTraining', RecognitionTraining,
[
'client_id',
'file_id',
......@@ -800,7 +689,7 @@ if __name__ == '__main__':
]
)
Tester('RecognitionTemplates', RecognitionTemplates,
DatabaseTester('RecognitionTemplates', RecognitionTemplates,
[
'template_id',
'client_id',
......@@ -818,7 +707,7 @@ if __name__ == '__main__':
]
)
Tester('Probes (with probe_id)', Probes,
DatabaseTester('Probes (with probe_id)', Probes,
[
'template_ids',
'client_id',
......@@ -836,7 +725,7 @@ if __name__ == '__main__':
]
)
Tester('Probes (with attack_id)', Probes,
DatabaseTester('Probes (with attack_id)', Probes,
[
'template_ids',
'client_id',
......@@ -855,7 +744,7 @@ if __name__ == '__main__':
]
)
Tester('SimpleAntispoofing', SimpleAntispoofing,
DatabaseTester('SimpleAntispoofing', SimpleAntispoofing,
[
'class',
'client_id',
......
......@@ -25,9 +25,9 @@ The AVspoof Database
Changelog
=========
* **Version 3**, 27/Oct/2017:
* **Version 3**, 30/Oct/2017:
- Port to beat.backend.python v1.4.1
- Port to beat.backend.python v1.4.2
* **Version 2**, 24/Mar/2016:
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment