Commit d0efa448 authored by Philip ABBET

Add the 'DatabaseTester' class

parent e9dfbe17
......@@ -33,13 +33,18 @@ import sys
import six
import simplejson
import itertools
import numpy as np
from . import loader
from . import utils
from .dataformat import DataFormat
from .outputs import OutputList
#----------------------------------------------------------
class Storage(utils.CodeStorage):
"""Resolves paths for databases
......@@ -65,6 +70,8 @@ class Storage(utils.CodeStorage):
super(Storage, self).__init__(path, 'python') #views are coded in Python
#----------------------------------------------------------
class View(object):
'''A special loader class for database views, with specialized methods
......@@ -197,6 +204,8 @@ class View(object):
return getattr(self.obj, key)
#----------------------------------------------------------
class Database(object):
"""Databases define the start point of the dataflow in an experiment.
......@@ -376,3 +385,278 @@ class Database(object):
return View(self._module, self.set(protocol, name), self.prefix,
self.data['root_folder'], exc)
#----------------------------------------------------------
class DatabaseTester(object):
    """Used while developing a new database view, to test its behavior

    This class tests that, for each combination of connected/not connected
    outputs:

      - Data indices seem consistent
      - All the connected outputs produce data
      - All the not connected outputs don't produce data

    It also reports some stats, and can generate a text file detailing the
    data generated by each output.

    By default, outputs are assumed to produce data at constant intervals.
    Those that don't follow this pattern must be declared as 'irregular'.

    Note that no particular check is done about the database declaration or
    the correctness of the generated data with their data formats. This class
    is mainly used to check that the outputs are correctly synchronized.

    All the tests are executed from the constructor.
    """

    class MockOutput(object):
        """Stand-in for a real output, recording everything written to it

        Each call to write() is stored in 'written_data' as a tuple
        (start_index, end_index, data), where 'start_index' is one past the
        end index of the previous write (0 for the first one).
        """

        def __init__(self, name, connected):
            self.name = name
            self.connected = connected
            self.last_written_data_index = -1
            self.written_data = []

        def write(self, data, end_data_index):
            """Records 'data' as covering the indices up to 'end_data_index'"""
            start_data_index = self.last_written_data_index + 1
            self.written_data.append((start_data_index, end_data_index, data))
            self.last_written_data_index = end_data_index

        def isConnected(self):
            """Returns the 'connected' flag given at construction"""
            return self.connected


    class SynchronizedUnit(object):
        """Node of a tree of data units, used to render a textual timeline

        A unit covers the inclusive index range [start, end]. 'data' maps an
        output name to the textual form of the data that output wrote over
        exactly that range; 'children' holds the units nested inside the
        range, in increasing index order.
        """

        def __init__(self, start, end):
            self.start = start
            self.end = end
            self.data = {}
            self.children = []

        def addData(self, output, start, end, data):
            """Inserts the data written by 'output' over [start, end]

            The data is attached to this unit when the range matches exactly,
            otherwise it is dispatched to (or wrapped around) the children.
            A range that fits none of the handled cases is silently ignored.
            """
            if (start == self.start) and (end == self.end):
                self.data[output] = self._dataToString(data)
                return

            # New range located after all the existing children
            if (not self.children) or (self.children[-1].end < start):
                unit = DatabaseTester.SynchronizedUnit(start, end)
                unit.addData(output, start, end, data)
                self.children.append(unit)
                return

            for index, unit in enumerate(self.children):
                # The range fits inside an existing child
                if (unit.start <= start) and (unit.end >= end):
                    unit.addData(output, start, end, data)
                    return

                # The range starts with an existing child but extends past
                # it: create a new unit wrapping that child and every
                # following sibling that ends inside the range
                if (unit.start == start) and (unit.end < end):
                    new_unit = DatabaseTester.SynchronizedUnit(start, end)
                    new_unit.addData(output, start, end, data)

                    # 'stop' is the index of the first child NOT absorbed.
                    # Tracking it explicitly avoids relying on a loop
                    # variable that is unbound when there is no following
                    # sibling, or that still points at an absorbed child
                    # when all the siblings are absorbed.
                    stop = index
                    while (stop < len(self.children)) and \
                          (self.children[stop].end <= end):
                        new_unit.children.append(self.children[stop])
                        stop += 1

                    self.children[index:stop] = [new_unit]
                    return

        def toString(self):
            """Returns a dict mapping each output name to its timeline text

            All the returned strings have the same length. An empty dict is
            returned when the unit has neither data nor children.
            """
            texts = {}

            # Concatenate the timelines of the children, output by output
            for child in self.children:
                for output, text in child.toString().items():
                    if output in texts:
                        texts[output] += ' ' + text
                    else:
                        texts[output] = text

            if self.data:
                # 6 is the size of the decoration around a value: '|- ' ' -|'
                length = max(len(value) + 6 for value in self.data.values())

                if texts:
                    children_length = len(next(iter(texts.values())))

                    if children_length >= length:
                        length = children_length
                    else:
                        # Widen the children timelines to the size of this
                        # unit, replacing their outermost characters
                        left, right = self._splitPadding(length - children_length)
                        for output, text in list(texts.items()):
                            texts[output] = '|%s%s%s|' % ('-' * left, text[1:-1], '-' * right)

                for output, value in self.data.items():
                    left, right = self._splitPadding(length - (len(value) + 6))
                    texts[output] = '|-%s %s %s-|' % ('-' * left, value, '-' * right)

            # Right-pad with spaces so that all the timelines are aligned
            if texts:
                length = max(len(text) for text in texts.values())
                for output, text in list(texts.items()):
                    if len(text) < length:
                        texts[output] = text + ' ' * (length - len(text))

            return texts

        @staticmethod
        def _splitPadding(diff):
            """Splits 'diff' into two integers (left, right), left <= right"""
            left = diff // 2
            return left, diff - left

        def _dataToString(self, data):
            """Returns a short text for 'data', or 'X' when it isn't a single
            simple value (no entry, several entries, array or dict)"""
            if len(data) != 1:
                return 'X'

            value = data[next(iter(data.keys()))]

            if isinstance(value, (np.ndarray, dict)):
                return 'X'

            return str(value)


    def __init__(self, name, view_class, outputs_declaration, parameters,
                 irregular_outputs=None, all_combinations=True):
        """Runs the tests

        Parameters:

          name: Name of the test, displayed in the messages and used to name
            the generated text file

          view_class: Class of the view to test, instantiated without
            argument

          outputs_declaration: Iterable over the names of the outputs of the
            view

          parameters: Parameters given to the setup() method of the view

          irregular_outputs: Names of the outputs that don't produce data at
            constant intervals (default: none)

          all_combinations: If True, every non-empty combination of connected
            outputs is tested, otherwise only the one with all the outputs
            connected
        """
        self.name = name
        self.view_class = view_class
        self.outputs_declaration = {}
        self.parameters = parameters

        # Copied, so that the list of the caller is never shared
        if irregular_outputs is None:
            self.irregular_outputs = []
        else:
            self.irregular_outputs = list(irregular_outputs)

        self.determine_regular_intervals(outputs_declaration)

        names = list(self.outputs_declaration.keys())

        if all_combinations:
            for size in range(0, len(names) + 1):
                for subset in itertools.combinations(names, size):
                    self.run(subset)
        else:
            self.run(names)

        # Blank line closing the report
        print('')

    def determine_regular_intervals(self, outputs_declaration):
        """Fills 'self.outputs_declaration' with the interval of each output

        A view with all its outputs connected is asked for data once; the
        interval of a regular output is the number of indices covered by
        what it wrote during that call. Irregular outputs are given the
        interval None.
        """
        outputs = OutputList()
        for name in outputs_declaration:
            outputs.add(DatabaseTester.MockOutput(name, True))

        view = self.view_class()
        view.setup('', outputs, self.parameters)

        view.next()

        for output in outputs:
            if output.name in self.irregular_outputs:
                self.outputs_declaration[output.name] = None
            else:
                self.outputs_declaration[output.name] = output.last_written_data_index + 1

    def run(self, connected_outputs):
        """Tests the view with only the given outputs connected

        Nothing is done when 'connected_outputs' is empty. An AssertionError
        is raised when an inconsistency is detected. When all the outputs
        are connected, a text file detailing the generated data is written.
        """
        connected_names = list(connected_outputs)
        if not connected_names:
            return

        print("Testing '%s', with %d output(s): %s" % (self.name, len(connected_names),
                                                       ', '.join(connected_names)))

        # Interval of each connected output, names of the other ones
        connected = {name: interval
                     for name, interval in self.outputs_declaration.items()
                     if name in connected_names}

        not_connected = [name for name in self.outputs_declaration
                         if name not in connected]

        # Create the mock outputs
        outputs = OutputList()
        for name in self.outputs_declaration:
            outputs.add(DatabaseTester.MockOutput(name, name in connected))

        # Create the view
        view = self.view_class()
        view.setup('', outputs, self.parameters)

        # Initialisations
        next_expected_indices = dict.fromkeys(connected, 0)
        next_index = 0

        def _done():
            return all(view.done(output.last_written_data_index)
                       for output in outputs if output.isConnected())

        # Ask for all the data
        while not _done():
            view.next()

            # Check the indices for the connected outputs
            for name, interval in connected.items():
                start, end = outputs[name].written_data[-1][:2]
                expected = next_expected_indices[name]

                assert start == expected, \
                    "Output '%s': unexpected start index" % name

                if name not in self.irregular_outputs:
                    assert end == expected + interval - 1, \
                        "Output '%s': unexpected end index" % name
                else:
                    assert end >= expected, \
                        "Output '%s': unexpected end index" % name

            # Check that the not connected outputs didn't produce data
            for name in not_connected:
                assert len(outputs[name].written_data) == 0, \
                    "Output '%s': data produced while not connected" % name

            # Compute the next data index that should be produced
            next_index = 1 + min(output.written_data[-1][1]
                                 for output in outputs if output.isConnected())

            # Compute the next data index that should be produced by each
            # connected output
            for name, interval in connected.items():
                if name not in self.irregular_outputs:
                    if next_index == next_expected_indices[name] + interval:
                        next_expected_indices[name] += interval
                else:
                    last_end = outputs[name].written_data[-1][1]
                    if next_index > last_end:
                        next_expected_indices[name] = last_end + 1

        # Check the number of data produced on the regular outputs
        for name, interval in connected.items():
            nb_data = len(outputs[name].written_data)
            print(' - %s: %d data' % (name, nb_data))
            if name not in self.irregular_outputs:
                # Integer division, as an index count is expected
                assert nb_data == next_index // interval, \
                    "Output '%s': unexpected number of data" % name

        # Check that all outputs end on the same index
        for name in connected:
            assert outputs[name].written_data[-1][1] == next_index - 1, \
                "Output '%s': doesn't end on the last index" % name

        # Generate a text file with lots of details (only if all outputs are
        # connected)
        if len(connected) == len(self.outputs_declaration):
            self._write_report(outputs)

    def _write_report(self, outputs):
        """Writes '<name>.txt' (spaces in the name replaced by underscores),
        containing one timeline per output, the one with the most data first"""
        sorted_outputs = sorted(outputs, key=lambda x: len(x.written_data))

        # The outputs with the fewest data (the largest units) are inserted
        # first, so that the smaller units get nested inside them
        unit = DatabaseTester.SynchronizedUnit(0, sorted_outputs[0].written_data[-1][1])

        for output in sorted_outputs:
            for start, end, data in output.written_data:
                unit.addData(output.name, start, end, data)

        texts = unit.toString()

        outputs_max_length = max(len(name) for name in self.outputs_declaration)

        with open(self.name.replace(' ', '_') + '.txt', 'w') as f:
            for output in reversed(sorted_outputs):
                padding = ' ' * (outputs_max_length - len(output.name))
                f.write(output.name + ': ' + padding + texts[output.name] + '\n')
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment