diff --git a/beat/backend/python/database.py b/beat/backend/python/database.py index 6d2474419187b774b3fc320f6bea8d69c390b2c4..9a35200d8b92d5996284f8639241605afe6112b4 100644 --- a/beat/backend/python/database.py +++ b/beat/backend/python/database.py @@ -54,8 +54,10 @@ from collections import namedtuple from . import loader from . import utils +from .protocoltemplate import ProtocolTemplate from .dataformat import DataFormat from .outputs import OutputList +from .exceptions import OutputError # ---------------------------------------------------------- @@ -75,24 +77,24 @@ class Storage(utils.CodeStorage): def __init__(self, prefix, name): - if name.count('/') != 1: + if name.count("/") != 1: raise RuntimeError("invalid database name: `%s'" % name) - self.name, self.version = name.split('/') + self.name, self.version = name.split("/") self.fullname = name self.prefix = prefix - path = os.path.join(self.prefix, 'databases', name + '.json') + path = os.path.join(self.prefix, "databases", name + ".json") path = path[:-5] # views are coded in Python - super(Storage, self).__init__(path, 'python') + super(Storage, self).__init__(path, "python") # ---------------------------------------------------------- class Runner(object): - '''A special loader class for database views, with specialized methods + """A special loader class for database views, with specialized methods Parameters: @@ -113,34 +115,33 @@ class Runner(object): **kwargs: Constructor parameters for the database view. Normally, none. - ''' + """ def __init__(self, module, definition, prefix, root_folder, exc=None): try: - class_ = getattr(module, definition['view']) - except Exception as e: + class_ = getattr(module, definition["view"]) + except Exception: if exc is not None: type, value, traceback = sys.exc_info() six.reraise(exc, exc(value), traceback) else: raise # just re-raise the user exception - self.obj = loader.run(class_, '__new__', exc) - self.ready = False - self.prefix = prefix - self.root_folder = root_folder - self.definition = definition - self.exc = exc or RuntimeError + self.obj = loader.run(class_, "__new__", exc) + self.ready = False + self.prefix = prefix + self.root_folder = root_folder + self.definition = definition + self.exc = exc or RuntimeError self.data_sources = None - def index(self, filename): - '''Index the content of the view''' + """Index the content of the view""" - parameters = self.definition.get('parameters', {}) + parameters = self.definition.get("parameters", {}) - objs = loader.run(self.obj, 'index', self.exc, self.root_folder, parameters) + objs = loader.run(self.obj, "index", self.exc, self.root_folder, parameters) if not isinstance(objs, list): raise self.exc("index() didn't return a list") @@ -148,54 +149,64 @@ class Runner(object): if not os.path.exists(os.path.dirname(filename)): os.makedirs(os.path.dirname(filename)) - with open(filename, 'wb') as f: + with open(filename, "wb") as f: data = simplejson.dumps(objs, cls=utils.NumpyJSONEncoder) - f.write(data.encode('utf-8')) - + f.write(data.encode("utf-8")) def setup(self, filename, start_index=None, end_index=None, pack=True): - '''Sets up the view''' + """Sets up the view""" if self.ready: return - with open(filename, 'rb') as f: - objs = simplejson.loads(f.read().decode('utf-8')) + with open(filename, "rb") as f: + objs = simplejson.loads(f.read().decode("utf-8")) - Entry = namedtuple('Entry', sorted(objs[0].keys())) - objs = [ Entry(**x) for x in objs ] + Entry = namedtuple("Entry", sorted(objs[0].keys())) + objs = [Entry(**x) for x in 
objs] - parameters = self.definition.get('parameters', {}) - - loader.run(self.obj, 'setup', self.exc, self.root_folder, parameters, - objs, start_index=start_index, end_index=end_index) + parameters = self.definition.get("parameters", {}) + loader.run( + self.obj, + "setup", + self.exc, + self.root_folder, + parameters, + objs, + start_index=start_index, + end_index=end_index, + ) # Create data sources for the outputs from .data import DatabaseOutputDataSource - from .dataformat import DataFormat self.data_sources = {} - for output_name, output_format in self.definition.get('outputs', {}).items(): + for output_name, output_format in self.definition.get("outputs", {}).items(): data_source = DatabaseOutputDataSource() - data_source.setup(self, output_name, output_format, self.prefix, - start_index=start_index, end_index=end_index, pack=pack) + data_source.setup( + self, + output_name, + output_format, + self.prefix, + start_index=start_index, + end_index=end_index, + pack=pack, + ) self.data_sources[output_name] = data_source self.ready = True - def get(self, output, index): - '''Returns the data of the provided output at the provided index''' + """Returns the data of the provided output at the provided index""" if not self.ready: raise self.exc("Database view not yet setup") - return loader.run(self.obj, 'get', self.exc, output, index) - + return loader.run(self.obj, "get", self.exc, output, index) def get_output_mapping(self, output): - return loader.run(self.obj, 'get_output_mapping', self.exc, output) + return loader.run(self.obj, "get_output_mapping", self.exc, output) def objects(self): return self.obj.objs @@ -240,12 +251,43 @@ class Database(object): self.errors = [] self.data = None + self.is_v1 = False # if the user has not provided a cache, still use one for performance dataformat_cache = dataformat_cache if dataformat_cache is not None else {} self._load(name, dataformat_cache) + def _update_dataformat_cache(self, outputs, dataformat_cache): + for key, value in outputs.items(): + + if value in self.dataformats: + continue + + if value in dataformat_cache: + dataformat = dataformat_cache[value] + else: + dataformat = DataFormat(self.prefix, value) + dataformat_cache[value] = dataformat + + self.dataformats[value] = dataformat + + def _load_v1(self, dataformat_cache): + """Loads a v1 database and fills the dataformat cache""" + + for protocol in self.data["protocols"]: + for set_ in protocol["sets"]: + self._update_dataformat_cache(set_["outputs"], dataformat_cache) + + def _load_v2(self, dataformat_cache): + """Loads a v2 database and fills the dataformat cache""" + + for protocol in self.data["protocols"]: + protocol_template = ProtocolTemplate( + self.prefix, protocol["template"], dataformat_cache + ) + for set_ in protocol_template.sets(): + self._update_dataformat_cache(set_["outputs"], dataformat_cache) def _load(self, data, dataformat_cache): """Loads the database""" @@ -255,56 +297,45 @@ class Database(object): self.storage = Storage(self.prefix, self._name) json_path = self.storage.json.path if not self.storage.json.exists(): - self.errors.append('Database declaration file not found: %s' % json_path) + self.errors.append("Database declaration file not found: %s" % json_path) return - with open(json_path, 'rb') as f: - self.data = simplejson.loads(f.read().decode('utf-8')) + with open(json_path, "rb") as f: + self.data = simplejson.loads(f.read().decode("utf-8")) self.code_path = self.storage.code.path self.code = self.storage.code.load() - for protocol in 
self.data['protocols']: - for _set in protocol['sets']: - - for key, value in _set['outputs'].items(): - - if value in self.dataformats: - continue - - if value in dataformat_cache: - dataformat = dataformat_cache[value] - else: - dataformat = DataFormat(self.prefix, value) - dataformat_cache[value] = dataformat - - self.dataformats[value] = dataformat + schema_version = int(self.data.get("schema_version", 1)) + if schema_version == 1: + self.is_v1 = True + self._load_v1(dataformat_cache) + elif schema_version == 2: + self._load_v2(dataformat_cache) + else: + raise RuntimeError(f"Invalid schema version {schema_version}") @property def name(self): """Returns the name of this object """ - return self._name or '__unnamed_database__' - + return self._name or "__unnamed_database__" @name.setter def name(self, value): self._name = value self.storage = Storage(self.prefix, value) - @property def description(self): """The short description for this object""" - return self.data.get('description', None) - + return self.data.get("description", None) @description.setter def description(self, value): """Sets the short description for this object""" - self.data['description'] = value - + self.data["description"] = value @property def documentation(self): @@ -317,7 +348,6 @@ class Database(object): return self.storage.doc.load() return None - @documentation.setter def documentation(self, value): """Sets the full-length description for this object""" @@ -325,12 +355,11 @@ class Database(object): if not self._name: raise RuntimeError("database has no name") - if hasattr(value, 'read'): + if hasattr(value, "read"): self.storage.doc.save(value.read()) else: self.storage.doc.save(value) - def hash(self): """Returns the hexadecimal hash for its declaration""" @@ -339,12 +368,10 @@ class Database(object): return self.storage.hash() - @property def schema_version(self): """Returns the schema version""" - return self.data.get('schema_version', 1) - + return self.data.get("schema_version", 1) @property def valid(self): @@ -356,43 +383,80 @@ class Database(object): def protocols(self): """The declaration of all the protocols of the database""" - data = self.data['protocols'] - return dict(zip([k['name'] for k in data], data)) - + data = self.data["protocols"] + return dict(zip([k["name"] for k in data], data)) def protocol(self, name): """The declaration of a specific protocol in the database""" return self.protocols[name] - @property def protocol_names(self): """Names of protocols declared for this database""" - data = self.data['protocols'] - return [k['name'] for k in data] - + data = self.data["protocols"] + return [k["name"] for k in data] def sets(self, protocol): """The declaration of a specific set in the database protocol""" - data = self.protocol(protocol)['sets'] - return dict(zip([k['name'] for k in data], data)) + if self.is_v1: + data = self.protocol(protocol)["sets"] + else: + protocol = self.protocol(protocol) + protocol_template = ProtocolTemplate(self.prefix, protocol["template"]) + if not protocol_template.valid: + raise RuntimeError( + "\n * {}".format("\n * ".join(protocol_template.errors)) + ) + data = protocol_template.sets() + return dict(zip([k["name"] for k in data], data)) def set(self, protocol, name): """The declaration of all the protocols of the database""" return self.sets(protocol)[name] - def set_names(self, protocol): """The names of sets in a given protocol for this database""" - data = self.protocol(protocol)['sets'] - return [k['name'] for k in data] + if self.is_v1: + data = 
self.protocol(protocol)["sets"] + else: + protocol = self.protocol(protocol) + protocol_template = ProtocolTemplate(self.prefix, protocol["template"]) + if not protocol_template.valid: + raise RuntimeError( + "\n * {}".format("\n * ".join(protocol_template.errors)) + ) + data = protocol_template.sets() + + return [k["name"] for k in data] + + def view_definition(self, protocol_name, set_name): + """Returns the definition of a view + Parameters: + protocol_name (str): The name of the protocol where to retrieve the view + from + + set_name (str): The name of the set in the protocol where to retrieve the + view from + + """ + + if self.is_v1: + view_definition = self.set(protocol_name, set_name) + else: + protocol = self.protocol(protocol_name) + template_name = protocol["template"] + protocol_template = ProtocolTemplate(self.prefix, template_name) + view_definition = protocol_template.set(set_name) + view_definition["view"] = protocol["views"][set_name]["view"] + + return view_definition def view(self, protocol, name, exc=None, root_folder=None): """Returns the database view, given the protocol and the set name @@ -421,8 +485,11 @@ class Database(object): raise exc("database has no name") if not self.valid: - message = "cannot load view for set `%s' of protocol `%s' " \ - "from invalid database (%s)" % (protocol, name, self.name) + message = ( + "cannot load view for set `%s' of protocol `%s' " + "from invalid database (%s)\n%s" + % (protocol, name, self.name, " \n".join(self.errors)) + ) if exc: raise exc(message) @@ -431,10 +498,11 @@ class Database(object): # loads the module only once through the lifetime of the database # object try: - if not hasattr(self, '_module'): - self._module = loader.load_module(self.name.replace(os.sep, '_'), - self.storage.code.path, {}) - except Exception as e: + if not hasattr(self, "_module"): + self._module = loader.load_module( + self.name.replace(os.sep, "_"), self.storage.code.path, {} + ) + except Exception: if exc is not None: type, value, traceback = sys.exc_info() six.reraise(exc, exc(value), traceback) @@ -442,11 +510,15 @@ class Database(object): raise # just re-raise the user exception if root_folder is None: - root_folder = self.data['root_folder'] - - return Runner(self._module, self.set(protocol, name), - self.prefix, root_folder, exc) + root_folder = self.data["root_folder"] + return Runner( + self._module, + self.view_definition(protocol, name), + self.prefix, + root_folder, + exc, + ) def json_dumps(self, indent=4): """Dumps the JSON declaration of this object in a string @@ -464,14 +536,11 @@ class Database(object): """ - return simplejson.dumps(self.data, indent=indent, - cls=utils.NumpyJSONEncoder) - + return simplejson.dumps(self.data, indent=indent, cls=utils.NumpyJSONEncoder) def __str__(self): return self.json_dumps() - def write(self, storage=None): """Writes contents to prefix location @@ -490,7 +559,6 @@ class Database(object): storage.save(str(self), self.code, self.description) - def export(self, prefix): """Recursively exports itself into another prefix @@ -520,12 +588,18 @@ class Database(object): raise RuntimeError("database is not valid") if prefix == self.prefix: - raise RuntimeError("Cannot export database to the same prefix (" - "%s)" % prefix) + raise RuntimeError( + "Cannot export database to the same prefix (" "%s)" % prefix + ) for k in self.dataformats.values(): k.export(prefix) + if not self.is_v1: + for protocol in self.protocols.values(): + protocol_template = ProtocolTemplate(self.prefix, protocol["template"]) + 
protocol_template.export(prefix) + self.write(Storage(prefix, self.name)) @@ -533,8 +607,6 @@ class Database(object): class View(object): - - def __init__(self): # Current databases definitions uses named tuple to store information. # This has one limitation, python keywords like `class` cannot be used. @@ -587,7 +659,6 @@ class View(object): raise NotImplementedError - def setup(self, root_folder, parameters, objs, start_index=None, end_index=None): # Initialisation @@ -599,8 +670,7 @@ class View(object): self.start_index = start_index if start_index is not None else 0 self.end_index = end_index if end_index is not None else len(self.objs) - 1 - self.objs = self.objs[self.start_index : self.end_index + 1] - + self.objs = self.objs[self.start_index : self.end_index + 1] # noqa def get(self, output, index): """Returns the data of the provided output at the provided index in the @@ -610,13 +680,13 @@ class View(object): raise NotImplementedError - def get_output_mapping(self, output): """Returns the object member to use for given output if any otherwise the member name is the output name. """ return self.output_member_map.get(output, output) + # ---------------------------------------------------------- @@ -643,7 +713,6 @@ class DatabaseTester: # Mock output class class MockOutput: - def __init__(self, name, connected): self.name = name self.connected = connected @@ -651,15 +720,15 @@ class DatabaseTester: self.written_data = [] def write(self, data, end_data_index): - self.written_data.append(( self.last_written_data_index + 1, end_data_index, data )) + self.written_data.append( + (self.last_written_data_index + 1, end_data_index, data) + ) self.last_written_data_index = end_data_index def isConnected(self): return self.connected - class SynchronizedUnit: - def __init__(self, start, end): self.start = start self.end = end @@ -685,12 +754,14 @@ class DatabaseTester: for i in range(index + 1, len(self.children)): unit = self.children[i] - if (unit.end <= end): + if unit.end <= end: new_unit.children.append(unit) else: break - self.children = self.children[:index] + [new_unit] + self.children[i:] + self.children = ( + self.children[:index] + [new_unit] + self.children[i:] + ) break def toString(self): @@ -700,12 +771,12 @@ class DatabaseTester: child_texts = child.toString() for output, text in child_texts.items(): if output in texts: - texts[output] += ' ' + text + texts[output] += " " + text else: texts[output] = text if len(self.data) > 0: - length = max([ len(x) + 6 for x in self.data.values() ]) + length = max([len(x) + 6 for x in self.data.values()]) if len(texts) > 0: children_length = len(texts.values()[0]) @@ -721,7 +792,7 @@ class DatabaseTester: diff2 = diff - diff1 for k, v in texts.items(): - texts[k] = '|%s%s%s|' % ('-' * diff1, v[1:-1], '-' * diff2) + texts[k] = "|%s%s%s|" % ("-" * diff1, v[1:-1], "-" * diff2) for output, value in self.data.items(): output_length = len(value) + 6 @@ -732,29 +803,35 @@ class DatabaseTester: else: diff1 = diff // 2 diff2 = diff - diff1 - texts[output] = '|-%s %s %s-|' % ('-' * diff1, value, '-' * diff2) + texts[output] = "|-%s %s %s-|" % ("-" * diff1, value, "-" * diff2) length = max(len(x) for x in texts.values()) for k, v in texts.items(): if len(v) < length: - texts[k] += ' ' * (length - len(v)) + texts[k] += " " * (length - len(v)) return texts def _dataToString(self, data): if (len(data) > 1) or (len(data) == 0): - return 'X' + return "X" value = data[data.keys()[0]] if isinstance(value, np.ndarray) or isinstance(value, dict): - return 'X' + 
return "X" return str(value) - - def __init__(self, name, view_class, outputs_declaration, parameters, - irregular_outputs=[], all_combinations=True): + def __init__( + self, + name, + view_class, + outputs_declaration, + parameters, + irregular_outputs=[], + all_combinations=True, + ): self.name = name self.view_class = view_class self.outputs_declaration = {} @@ -765,52 +842,60 @@ class DatabaseTester: if all_combinations: for L in range(0, len(self.outputs_declaration) + 1): - for subset in itertools.combinations(self.outputs_declaration.keys(), L): + for subset in itertools.combinations( + self.outputs_declaration.keys(), L + ): self.run(subset) else: self.run(self.outputs_declaration.keys()) - def determine_regular_intervals(self, outputs_declaration): outputs = OutputList() for name in outputs_declaration: outputs.add(DatabaseTester.MockOutput(name, True)) view = self.view_class() - view.setup('', outputs, self.parameters) + view.setup("", outputs, self.parameters) view.next() for output in outputs: if output.name not in self.irregular_outputs: - self.outputs_declaration[output.name] = output.last_written_data_index + 1 + self.outputs_declaration[output.name] = ( + output.last_written_data_index + 1 + ) else: self.outputs_declaration[output.name] = None - def run(self, connected_outputs): if len(connected_outputs) == 0: return - print("Testing '%s', with %d output(s): %s" % \ - (self.name, len(connected_outputs), ', '.join(connected_outputs))) + print( + "Testing '%s', with %d output(s): %s" + % (self.name, len(connected_outputs), ", ".join(connected_outputs)) + ) # Create the mock outputs - connected_outputs = dict([ x for x in self.outputs_declaration.items() - if x[0] in connected_outputs ]) - - not_connected_outputs = dict([ x for x in self.outputs_declaration.items() - if x[0] not in connected_outputs ]) + connected_outputs = dict( + [x for x in self.outputs_declaration.items() if x[0] in connected_outputs] + ) + + not_connected_outputs = dict( + [ + x + for x in self.outputs_declaration.items() + if x[0] not in connected_outputs + ] + ) outputs = OutputList() for name in self.outputs_declaration.keys(): outputs.add(DatabaseTester.MockOutput(name, name in connected_outputs)) - # Create the view view = self.view_class() - view.setup('', outputs, self.parameters) - + view.setup("", outputs, self.parameters) # Initialisations next_expected_indices = {} @@ -821,58 +906,86 @@ class DatabaseTester: def _done(): for output in outputs: - if output.isConnected() and not view.done(output.last_written_data_index): + if output.isConnected() and not view.done( + output.last_written_data_index + ): return False return True - # Ask for all the data - while not(_done()): + while not (_done()): view.next() # Check the indices for the connected outputs for name in connected_outputs.keys(): if name not in self.irregular_outputs: - assert(outputs[name].written_data[-1][0] == next_expected_indices[name]) - assert(outputs[name].written_data[-1][1] == next_expected_indices[name] + connected_outputs[name] - 1) + if not ( + outputs[name].written_data[-1][0] == next_expected_indices[name] + ): + raise OutputError("Wrong current index") + if not ( + outputs[name].written_data[-1][1] + == next_expected_indices[name] + connected_outputs[name] - 1 + ): + raise OutputError("Wrong next index") else: - assert(outputs[name].written_data[-1][0] == next_expected_indices[name]) - assert(outputs[name].written_data[-1][1] >= next_expected_indices[name]) + if not ( + outputs[name].written_data[-1][0] == 
next_expected_indices[name] + ): + raise OutputError("Wrong current index") + if not ( + outputs[name].written_data[-1][1] >= next_expected_indices[name] + ): + raise OutputError("Wrong next index") # Check that the not connected outputs didn't produce data for name in not_connected_outputs.keys(): - assert(len(outputs[name].written_data) == 0) + if len(outputs[name].written_data) != 0: + raise OutputError("Data written on unconnected output") # Compute the next data index that should be produced - next_index = 1 + min([ x.written_data[-1][1] for x in outputs if x.isConnected() ]) + next_index = 1 + min( + [x.written_data[-1][1] for x in outputs if x.isConnected()] + ) # Compute the next data index that should be produced by each # connected output for name in connected_outputs.keys(): if name not in self.irregular_outputs: - if next_index == next_expected_indices[name] + connected_outputs[name]: + if ( + next_index + == next_expected_indices[name] + connected_outputs[name] + ): next_expected_indices[name] += connected_outputs[name] else: if next_index > outputs[name].written_data[-1][1]: - next_expected_indices[name] = outputs[name].written_data[-1][1] + 1 + next_expected_indices[name] = ( + outputs[name].written_data[-1][1] + 1 + ) # Check the number of data produced on the regular outputs for name in connected_outputs.keys(): - print(' - %s: %d data' % (name, len(outputs[name].written_data))) + print(" - %s: %d data" % (name, len(outputs[name].written_data))) if name not in self.irregular_outputs: - assert(len(outputs[name].written_data) == next_index / connected_outputs[name]) + if not ( + len(outputs[name].written_data) + == next_index / connected_outputs[name] + ): + raise OutputError("Invalid number of data produced") # Check that all outputs ends on the same index for name in connected_outputs.keys(): - assert(outputs[name].written_data[-1][1] == next_index - 1) - + if not outputs[name].written_data[-1][1] == next_index - 1: + raise OutputError("Outputs not on same index") # Generate a text file with lots of details (only if all outputs are # connected) if len(connected_outputs) == len(self.outputs_declaration): sorted_outputs = sorted(outputs, key=lambda x: len(x.written_data)) - unit = DatabaseTester.SynchronizedUnit(0, sorted_outputs[0].written_data[-1][1]) + unit = DatabaseTester.SynchronizedUnit( + 0, sorted_outputs[0].written_data[-1][1] + ) for output in sorted_outputs: for data in output.written_data: @@ -880,14 +993,14 @@ class DatabaseTester: texts = unit.toString() - outputs_max_length = max([ len(x) for x in self.outputs_declaration.keys() ]) + outputs_max_length = max([len(x) for x in self.outputs_declaration.keys()]) - with open(self.name.replace(' ', '_') + '.txt', 'w') as f: + with open(self.name.replace(" ", "_") + ".txt", "w") as f: for i in range(1, len(sorted_outputs) + 1): output_name = sorted_outputs[-i].name - f.write(output_name + ': ') + f.write(output_name + ": ") if len(output_name) < outputs_max_length: - f.write(' ' * (outputs_max_length - len(output_name))) + f.write(" " * (outputs_max_length - len(output_name))) - f.write(texts[output_name] + '\n') + f.write(texts[output_name] + "\n") diff --git a/beat/backend/python/exceptions.py b/beat/backend/python/exceptions.py index 89c9c3fbf347e7f4326cdbeb948b4cfe77f5b23a..8d654367715cfa421d37661c16225fd8e1670dbc 100644 --- a/beat/backend/python/exceptions.py +++ b/beat/backend/python/exceptions.py @@ -42,24 +42,25 @@ exceptions Custom exceptions """ + class RemoteException(Exception): """Exception happening on 
a remote location""" def __init__(self, kind, message): super(RemoteException, self).__init__() - if kind == 'sys': + if kind == "sys": self.system_error = message - self.user_error = '' + self.user_error = "" else: - self.system_error = '' + self.system_error = "" self.user_error = message def __str__(self): if self.system_error: - return '(sys) {}'.format(self.system_error) + return "(sys) {}".format(self.system_error) else: - return '(usr) {}'.format(self.user_error) + return "(usr) {}".format(self.user_error) class UserError(Exception): @@ -70,3 +71,9 @@ class UserError(Exception): def __str__(self): return repr(self.value) + + +class OutputError(Exception): + """Error happening on output""" + + pass diff --git a/beat/backend/python/test/prefix/databases/integers_db/2.json b/beat/backend/python/test/prefix/databases/integers_db/2.json new file mode 100644 index 0000000000000000000000000000000000000000..c48d561613c2a84d558fe739c413ff8f46647183 --- /dev/null +++ b/beat/backend/python/test/prefix/databases/integers_db/2.json @@ -0,0 +1,55 @@ +{ + "schema_version": 2, + + "root_folder": "/tmp/path/not/set", + "protocols": [ + { + "name": "double", + "template": "double/1", + "views": { + "double": { + "view": "Double" + } + } + }, + { + "name": "triple", + "template": "triple/1", + "views": { + "triple": { + "view": "Triple" + } + } + }, + { + "name": "two_sets", + "template": "two_sets/1", + "views": { + "double": { + "view": "Double" + }, + "triple": { + "view": "Triple" + } + } + }, + { + "name": "labelled", + "template": "labelled/1", + "views": { + "labelled": { + "view": "Labelled" + } + } + }, + { + "name": "different_frequencies", + "template": "different_frequencies/1", + "views": { + "double" : { + "view": "DifferentFrequencies" + } + } + } + ] +} diff --git a/beat/backend/python/test/prefix/databases/integers_db/2.py b/beat/backend/python/test/prefix/databases/integers_db/2.py new file mode 100644 index 0000000000000000000000000000000000000000..0ebebd6c4debec559f8a853830073e92a8b71bd5 --- /dev/null +++ b/beat/backend/python/test/prefix/databases/integers_db/2.py @@ -0,0 +1,168 @@ +#!/usr/bin/env python +# vim: set fileencoding=utf-8 : + +################################################################################### +# # +# Copyright (c) 2019 Idiap Research Institute, http://www.idiap.ch/ # +# Contact: beat.support@idiap.ch # +# # +# Redistribution and use in source and binary forms, with or without # +# modification, are permitted provided that the following conditions are met: # +# # +# 1. Redistributions of source code must retain the above copyright notice, this # +# list of conditions and the following disclaimer. # +# # +# 2. Redistributions in binary form must reproduce the above copyright notice, # +# this list of conditions and the following disclaimer in the documentation # +# and/or other materials provided with the distribution. # +# # +# 3. Neither the name of the copyright holder nor the names of its contributors # +# may be used to endorse or promote products derived from this software without # +# specific prior written permission. # +# # +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND # +# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED # +# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE # +# DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE # +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL # +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR # +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER # +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, # +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. # +# # +################################################################################### + + +import numpy +from collections import namedtuple +from beat.backend.python.database import View + + +class Double(View): + def index(self, root_folder, parameters): + Entry = namedtuple("Entry", ["a", "b", "sum"]) + + return [ + Entry(1, 10, 11), + Entry(2, 20, 22), + Entry(3, 30, 33), + Entry(4, 40, 44), + Entry(5, 50, 55), + Entry(6, 60, 66), + Entry(7, 70, 77), + Entry(8, 80, 88), + Entry(9, 90, 99), + ] + + def get(self, output, index): + obj = self.objs[index] + + if output == "a": + return {"value": numpy.int32(obj.a)} + + elif output == "b": + return {"value": numpy.int32(obj.b)} + + elif output == "sum": + return {"value": numpy.int32(obj.sum)} + elif output == "class": + return {"value": numpy.int32(obj.cls)} + + +# ---------------------------------------------------------- + + +class Triple(View): + def index(self, root_folder, parameters): + Entry = namedtuple("Entry", ["a", "b", "c", "sum"]) + + return [ + Entry(1, 10, 100, 111), + Entry(2, 20, 200, 222), + Entry(3, 30, 300, 333), + Entry(4, 40, 400, 444), + Entry(5, 50, 500, 555), + Entry(6, 60, 600, 666), + Entry(7, 70, 700, 777), + Entry(8, 80, 800, 888), + Entry(9, 90, 900, 999), + ] + + def get(self, output, index): + obj = self.objs[index] + + if output == "a": + return {"value": numpy.int32(obj.a)} + + elif output == "b": + return {"value": numpy.int32(obj.b)} + + elif output == "c": + return {"value": numpy.int32(obj.c)} + + elif output == "sum": + return {"value": numpy.int32(obj.sum)} + + +# ---------------------------------------------------------- + + +class Labelled(View): + def index(self, root_folder, parameters): + Entry = namedtuple("Entry", ["label", "value"]) + + return [ + Entry("A", 1), + Entry("A", 2), + Entry("A", 3), + Entry("A", 4), + Entry("A", 5), + Entry("B", 10), + Entry("B", 20), + Entry("B", 30), + Entry("B", 40), + Entry("B", 50), + Entry("C", 100), + Entry("C", 200), + Entry("C", 300), + Entry("C", 400), + Entry("C", 500), + ] + + def get(self, output, index): + obj = self.objs[index] + + if output == "label": + return {"value": obj.label} + + elif output == "value": + return {"value": numpy.int32(obj.value)} + + +# ---------------------------------------------------------- + + +class DifferentFrequencies(View): + def index(self, root_folder, parameters): + Entry = namedtuple("Entry", ["a", "b"]) + + return [ + Entry(1, 10), + Entry(1, 20), + Entry(1, 30), + Entry(1, 40), + Entry(2, 50), + Entry(2, 60), + Entry(2, 70), + Entry(2, 80), + ] + + def get(self, output, index): + obj = self.objs[index] + + if output == "a": + return {"value": numpy.int32(obj.a)} + + elif output == "b": + return {"value": numpy.int32(obj.b)} diff --git a/beat/backend/python/test/prefix/protocoltemplates/different_frequencies/1.json b/beat/backend/python/test/prefix/protocoltemplates/different_frequencies/1.json index 
fa1659cc6985c7b6e2c684717fc087ef3221103a..6b7e871992ab86e82e9fe40842460cebdaf56f11 100644 --- a/beat/backend/python/test/prefix/protocoltemplates/different_frequencies/1.json +++ b/beat/backend/python/test/prefix/protocoltemplates/different_frequencies/1.json @@ -3,8 +3,6 @@ "sets": [ { "name": "double", - "template": "double", - "view": "DifferentFrequencies", "outputs": { "a": "user/single_integer/1", "b": "user/single_integer/1" diff --git a/beat/backend/python/test/prefix/protocoltemplates/double/1.json b/beat/backend/python/test/prefix/protocoltemplates/double/1.json index 7cd58ddb2c3a8b10963bb7901b16dd79954799ab..1ad385fa959fcdc3e159d8364c0a3f99c6fc69d6 100644 --- a/beat/backend/python/test/prefix/protocoltemplates/double/1.json +++ b/beat/backend/python/test/prefix/protocoltemplates/double/1.json @@ -3,8 +3,6 @@ "sets": [ { "name": "double", - "template": "double", - "view": "Double", "outputs": { "a": "user/single_integer/1", "b": "user/single_integer/1", diff --git a/beat/backend/python/test/prefix/protocoltemplates/labelled/1.json b/beat/backend/python/test/prefix/protocoltemplates/labelled/1.json index ac3b93693c2d87af9a9342d5fead51892d08f625..ce7dc834988ef894df979f61685a563dedc15924 100644 --- a/beat/backend/python/test/prefix/protocoltemplates/labelled/1.json +++ b/beat/backend/python/test/prefix/protocoltemplates/labelled/1.json @@ -3,8 +3,6 @@ "sets": [ { "name": "labelled", - "template": "labelled", - "view": "Labelled", "outputs": { "value": "user/single_integer/1", "label": "user/single_string/1" diff --git a/beat/backend/python/test/prefix/protocoltemplates/triple/1.json b/beat/backend/python/test/prefix/protocoltemplates/triple/1.json index f5b67ab9e424624ae9cc76f47fbe426cb007eab3..cb6db0360dc213001ec738310594579fb344ff75 100644 --- a/beat/backend/python/test/prefix/protocoltemplates/triple/1.json +++ b/beat/backend/python/test/prefix/protocoltemplates/triple/1.json @@ -3,8 +3,6 @@ "sets": [ { "name": "triple", - "view": "Triple", - "template": "triple", "outputs": { "a": "user/single_integer/1", "b": "user/single_integer/1", diff --git a/beat/backend/python/test/prefix/protocoltemplates/two_sets/1.json b/beat/backend/python/test/prefix/protocoltemplates/two_sets/1.json index 0daf32c2007c05a19756608989076dcefc5e8987..a721319097c23b63491a119e6c7b01592726dc80 100644 --- a/beat/backend/python/test/prefix/protocoltemplates/two_sets/1.json +++ b/beat/backend/python/test/prefix/protocoltemplates/two_sets/1.json @@ -3,8 +3,6 @@ "sets": [ { "name": "double", - "template": "double", - "view": "Double", "outputs": { "a": "user/single_integer/1", "b": "user/single_integer/1", @@ -13,8 +11,6 @@ }, { "name": "triple", - "template": "triple", - "view": "Triple", "outputs": { "a": "user/single_integer/1", "b": "user/single_integer/1", diff --git a/beat/backend/python/test/test_database.py b/beat/backend/python/test/test_database.py index c7c1352594727f0b6f8cf87fa00b5fdd2c01b9e6..3c87f003723850b60fd4f25df7299fcf5954a5f1 100644 --- a/beat/backend/python/test/test_database.py +++ b/beat/backend/python/test/test_database.py @@ -40,74 +40,92 @@ from ..database import Database from . 
import prefix -#---------------------------------------------------------- +INTEGERS_DBS = ["integers_db/{}".format(i) for i in range(1, 3)] + + +# ---------------------------------------------------------- def load(database_name): database = Database(prefix, database_name) - assert database.valid + nose.tools.assert_true(database.valid, "\n * %s" % "\n * ".join(database.errors)) return database -#---------------------------------------------------------- +# ---------------------------------------------------------- def test_load_valid_database(): - database = Database(prefix, 'integers_db/1') - assert database.valid, '\n * %s' % '\n * '.join(database.errors) + for db_name in INTEGERS_DBS: + yield load_valid_database, db_name + +def load_valid_database(db_name): + database = load(db_name) nose.tools.eq_(len(database.sets("double")), 1) nose.tools.eq_(len(database.sets("triple")), 1) nose.tools.eq_(len(database.sets("two_sets")), 2) -#---------------------------------------------------------- +# ---------------------------------------------------------- def test_load_protocol_with_one_set(): - database = Database(prefix, 'integers_db/1') + for db_name in INTEGERS_DBS: + yield load_protocol_with_one_set, db_name + + +def load_protocol_with_one_set(db_name): + + database = load(db_name) protocol = database.protocol("double") - nose.tools.eq_(len(protocol['sets']), 1) + nose.tools.eq_(len(protocol["sets"]), 1) - set = database.set("double", "double") + set_ = database.set("double", "double") - nose.tools.eq_(set['name'], 'double') - nose.tools.eq_(len(set['outputs']), 3) + nose.tools.eq_(set_["name"], "double") + nose.tools.eq_(len(set_["outputs"]), 3) - assert set['outputs']['a'] is not None - assert set['outputs']['b'] is not None - assert set['outputs']['sum'] is not None + nose.tools.assert_is_not_none(set_["outputs"]["a"]) + nose.tools.assert_is_not_none(set_["outputs"]["b"]) + nose.tools.assert_is_not_none(set_["outputs"]["sum"]) -#---------------------------------------------------------- +# ---------------------------------------------------------- def test_load_protocol_with_two_sets(): - database = Database(prefix, 'integers_db/1') + for db_name in INTEGERS_DBS: + yield load_protocol_with_two_sets, db_name + + +def load_protocol_with_two_sets(db_name): + + database = load(db_name) protocol = database.protocol("two_sets") - nose.tools.eq_(len(protocol['sets']), 2) + nose.tools.eq_(len(protocol["sets"]), 2) - set = database.set("two_sets", "double") + set_ = database.set("two_sets", "double") - nose.tools.eq_(set['name'], 'double') - nose.tools.eq_(len(set['outputs']), 3) + nose.tools.eq_(set_["name"], "double") + nose.tools.eq_(len(set_["outputs"]), 3) - assert set['outputs']['a'] is not None - assert set['outputs']['b'] is not None - assert set['outputs']['sum'] is not None + nose.tools.assert_is_not_none(set_["outputs"]["a"]) + nose.tools.assert_is_not_none(set_["outputs"]["b"]) + nose.tools.assert_is_not_none(set_["outputs"]["sum"]) - set = database.set("two_sets", "triple") + set_ = database.set("two_sets", "triple") - nose.tools.eq_(set['name'], 'triple') - nose.tools.eq_(len(set['outputs']), 4) + nose.tools.eq_(set_["name"], "triple") + nose.tools.eq_(len(set_["outputs"]), 4) - assert set['outputs']['a'] is not None - assert set['outputs']['b'] is not None - assert set['outputs']['c'] is not None - assert set['outputs']['sum'] is not None + nose.tools.assert_is_not_none(set_["outputs"]["a"]) + nose.tools.assert_is_not_none(set_["outputs"]["b"]) + 
nose.tools.assert_is_not_none(set_["outputs"]["c"]) + nose.tools.assert_is_not_none(set_["outputs"]["sum"]) diff --git a/beat/backend/python/test/test_database_view.py b/beat/backend/python/test/test_database_view.py index c24f5a93ed8fe98145fc6505f37627cd8c730b40..77c1df7b5b1d4017d428b9ce0fae0b7d43fa30e6 100644 --- a/beat/backend/python/test/test_database_view.py +++ b/beat/backend/python/test/test_database_view.py @@ -39,116 +39,114 @@ import tempfile import shutil import os +from ddt import ddt +from ddt import idata + from ..database import Database +from .test_database import INTEGERS_DBS + from . import prefix -#---------------------------------------------------------- +# ---------------------------------------------------------- class MyExc(Exception): pass -#---------------------------------------------------------- +# ---------------------------------------------------------- +@ddt class TestDatabaseViewRunner(unittest.TestCase): - def setUp(self): self.cache_root = tempfile.mkdtemp(prefix=__name__) - def tearDown(self): shutil.rmtree(self.cache_root) - def test_syntax_error(self): - db = Database(prefix, 'syntax_error/1') + db = Database(prefix, "syntax_error/1") self.assertTrue(db.valid) with self.assertRaises(SyntaxError): - view = db.view('protocol', 'set') - + db.view("protocol", "set") def test_unknown_view(self): - db = Database(prefix, 'integers_db/1') + db = Database(prefix, "integers_db/1") self.assertTrue(db.valid) with self.assertRaises(KeyError): - view = db.view('protocol', 'does_not_exist') + db.view("protocol", "does_not_exist") - - def test_valid_view(self): - db = Database(prefix, 'integers_db/1') + @idata(INTEGERS_DBS) + def test_valid_view(self, db_name): + db = Database(prefix, db_name) self.assertTrue(db.valid) - view = db.view('double', 'double') + view = db.view("double", "double") self.assertTrue(view is not None) - def test_indexing_crash(self): - db = Database(prefix, 'crash/1') + db = Database(prefix, "crash/1") self.assertTrue(db.valid) - view = db.view('protocol', 'index_crashes', MyExc) + view = db.view("protocol", "index_crashes", MyExc) with self.assertRaises(MyExc): - view.index(os.path.join(self.cache_root, 'data.db')) - + view.index(os.path.join(self.cache_root, "data.db")) def test_get_crash(self): - db = Database(prefix, 'crash/1') + db = Database(prefix, "crash/1") self.assertTrue(db.valid) - view = db.view('protocol', 'get_crashes', MyExc) - view.index(os.path.join(self.cache_root, 'data.db')) - view.setup(os.path.join(self.cache_root, 'data.db')) + view = db.view("protocol", "get_crashes", MyExc) + view.index(os.path.join(self.cache_root, "data.db")) + view.setup(os.path.join(self.cache_root, "data.db")) with self.assertRaises(MyExc): - view.get('a', 0) - + view.get("a", 0) def test_not_setup(self): - db = Database(prefix, 'crash/1') + db = Database(prefix, "crash/1") self.assertTrue(db.valid) - view = db.view('protocol', 'get_crashes', MyExc) + view = db.view("protocol", "get_crashes", MyExc) with self.assertRaises(MyExc): - view.get('a', 0) - + view.get("a", 0) - def test_success(self): - db = Database(prefix, 'integers_db/1') + @idata(INTEGERS_DBS) + def test_success(self, db_name): + db = Database(prefix, db_name) self.assertTrue(db.valid) - view = db.view('double', 'double', MyExc) - view.index(os.path.join(self.cache_root, 'data.db')) - view.setup(os.path.join(self.cache_root, 'data.db')) + view = db.view("double", "double", MyExc) + view.index(os.path.join(self.cache_root, "data.db")) + view.setup(os.path.join(self.cache_root, 
"data.db")) self.assertTrue(view.data_sources is not None) self.assertEqual(len(view.data_sources), 3) for i in range(0, 9): - self.assertEqual(view.get('a', i)['value'], i + 1) - self.assertEqual(view.get('b', i)['value'], (i + 1) * 10) - self.assertEqual(view.get('sum', i)['value'], (i + 1) * 10 + i + 1) - + self.assertEqual(view.get("a", i)["value"], i + 1) + self.assertEqual(view.get("b", i)["value"], (i + 1) * 10) + self.assertEqual(view.get("sum", i)["value"], (i + 1) * 10 + i + 1) def test_success_using_keywords(self): - db = Database(prefix, 'python_keyword/1') + db = Database(prefix, "python_keyword/1") self.assertTrue(db.valid) - view = db.view('keyword', 'keyword', MyExc) - view.index(os.path.join(self.cache_root, 'data.db')) - view.setup(os.path.join(self.cache_root, 'data.db')) + view = db.view("keyword", "keyword", MyExc) + view.index(os.path.join(self.cache_root, "data.db")) + view.setup(os.path.join(self.cache_root, "data.db")) self.assertTrue(view.data_sources is not None) self.assertEqual(len(view.data_sources), 3) for i in range(0, 9): - self.assertEqual(view.get('class', i)['value'], i + 1) - self.assertEqual(view.get('def', i)['value'], (i + 1) * 10) - self.assertEqual(view.get('sum', i)['value'], (i + 1) * 10 + i + 1) + self.assertEqual(view.get("class", i)["value"], i + 1) + self.assertEqual(view.get("def", i)["value"], (i + 1) * 10) + self.assertEqual(view.get("sum", i)["value"], (i + 1) * 10 + i + 1) diff --git a/beat/backend/python/test/test_databases_index.py b/beat/backend/python/test/test_databases_index.py index d842874f5a1e9402a3f1eb1634c047d23aff5a6e..1cbaafb7961adfff9f08dbf88d532a8caeea0559 100644 --- a/beat/backend/python/test/test_databases_index.py +++ b/beat/backend/python/test/test_databases_index.py @@ -43,48 +43,49 @@ import multiprocessing import tempfile import shutil +from ddt import ddt +from ddt import idata + from ..scripts import index from ..hash import hashDataset from ..hash import toPath +from .test_database import INTEGERS_DBS + from . 
import prefix -#---------------------------------------------------------- +# ---------------------------------------------------------- class IndexationProcess(multiprocessing.Process): - def __init__(self, queue, arguments): super(IndexationProcess, self).__init__() self.queue = queue self.arguments = arguments - def run(self): - self.queue.put('STARTED') + self.queue.put("STARTED") index.main(self.arguments) -#---------------------------------------------------------- +# ---------------------------------------------------------- +@ddt class TestDatabaseIndexation(unittest.TestCase): - - def __init__(self, methodName='runTest'): + def __init__(self, methodName="runTest"): super(TestDatabaseIndexation, self).__init__(methodName) self.databases_indexation_process = None self.working_dir = None self.cache_root = None - def setUp(self): self.shutdown_everything() # In case another test failed badly during its setUp() self.working_dir = tempfile.mkdtemp(prefix=__name__) self.cache_root = tempfile.mkdtemp(prefix=__name__) - def tearDown(self): self.shutdown_everything() @@ -95,7 +96,6 @@ class TestDatabaseIndexation(unittest.TestCase): self.cache_root = None self.data_source = None - def shutdown_everything(self): if self.databases_indexation_process is not None: self.databases_indexation_process.terminate() @@ -103,13 +103,8 @@ class TestDatabaseIndexation(unittest.TestCase): del self.databases_indexation_process self.databases_indexation_process = None - def process(self, database, protocol_name=None, set_name=None): - args = [ - prefix, - self.cache_root, - database, - ] + args = [prefix, self.cache_root, database] if protocol_name is not None: args.append(protocol_name) @@ -117,7 +112,9 @@ class TestDatabaseIndexation(unittest.TestCase): if set_name is not None: args.append(set_name) - self.databases_indexation_process = IndexationProcess(multiprocessing.Queue(), args) + self.databases_indexation_process = IndexationProcess( + multiprocessing.Queue(), args + ) self.databases_indexation_process.start() self.databases_indexation_process.queue.get() @@ -126,60 +123,63 @@ class TestDatabaseIndexation(unittest.TestCase): del self.databases_indexation_process self.databases_indexation_process = None + @idata(INTEGERS_DBS) + def test_one_set(self, db_name): + self.process(db_name, "double", "double") - def test_one_set(self): - self.process('integers_db/1', 'double', 'double') - - expected_files = [ - hashDataset('integers_db/1', 'double', 'double') - ] - + expected_files = [hashDataset(db_name, "double", "double")] + print(expected_files) for filename in expected_files: - self.assertTrue(os.path.exists(os.path.join(self.cache_root, - toPath(filename, suffix='.db')) - )) + self.assertTrue( + os.path.exists( + os.path.join(self.cache_root, toPath(filename, suffix=".db")) + ) + ) - - def test_one_protocol(self): - self.process('integers_db/1', 'two_sets') + @idata(INTEGERS_DBS) + def test_one_protocol(self, db_name): + self.process(db_name, "two_sets") expected_files = [ - hashDataset('integers_db/1', 'two_sets', 'double'), - hashDataset('integers_db/1', 'two_sets', 'triple') + hashDataset(db_name, "two_sets", "double"), + hashDataset(db_name, "two_sets", "triple"), ] for filename in expected_files: - self.assertTrue(os.path.exists(os.path.join(self.cache_root, - toPath(filename, suffix='.db')) - )) - + self.assertTrue( + os.path.exists( + os.path.join(self.cache_root, toPath(filename, suffix=".db")) + ) + ) - def test_whole_database(self): - self.process('integers_db/1') + @idata(INTEGERS_DBS) + 
def test_whole_database(self, db_name): + self.process(db_name) expected_files = [ - hashDataset('integers_db/1', 'double', 'double'), - hashDataset('integers_db/1', 'triple', 'triple'), - hashDataset('integers_db/1', 'two_sets', 'double'), - hashDataset('integers_db/1', 'two_sets', 'triple'), - hashDataset('integers_db/1', 'labelled', 'labelled'), - hashDataset('integers_db/1', 'different_frequencies', 'double'), + hashDataset(db_name, "double", "double"), + hashDataset(db_name, "triple", "triple"), + hashDataset(db_name, "two_sets", "double"), + hashDataset(db_name, "two_sets", "triple"), + hashDataset(db_name, "labelled", "labelled"), + hashDataset(db_name, "different_frequencies", "double"), ] for filename in expected_files: - self.assertTrue(os.path.exists(os.path.join(self.cache_root, - toPath(filename, suffix='.db')) - )) - + self.assertTrue( + os.path.exists( + os.path.join(self.cache_root, toPath(filename, suffix=".db")) + ) + ) def test_error(self): - self.process('crash/1', 'protocol', 'index_crashes') + self.process("crash/1", "protocol", "index_crashes") - unexpected_files = [ - hashDataset('crash/1', 'protocol', 'index_crashes'), - ] + unexpected_files = [hashDataset("crash/1", "protocol", "index_crashes")] for filename in unexpected_files: - self.assertFalse(os.path.exists(os.path.join(self.cache_root, - toPath(filename, suffix='.db')) - )) + self.assertFalse( + os.path.exists( + os.path.join(self.cache_root, toPath(filename, suffix=".db")) + ) + ) diff --git a/beat/backend/python/test/test_databases_provider.py b/beat/backend/python/test/test_databases_provider.py index f66f0bca1ec6ddaa020145e4ffd043a5af1178e8..64163f5e2c46688b90a7071cfcd423d2fe9e116e 100644 --- a/beat/backend/python/test/test_databases_provider.py +++ b/beat/backend/python/test/test_databases_provider.py @@ -47,8 +47,8 @@ import tempfile import shutil import zmq -from time import time -from time import sleep +from ddt import ddt +from ddt import idata from contextlib import closing @@ -57,107 +57,101 @@ from ..database import Database from ..data import RemoteDataSource from ..data import RemoteException +from .test_database import INTEGERS_DBS + from . 
import prefix logger = logging.getLogger(__name__) -#---------------------------------------------------------- +# ---------------------------------------------------------- CONFIGURATION = { - 'queue': 'queue', - 'inputs': { - 'in_data': { - 'set': 'double', - 'protocol': 'double', - 'database': 'integers_db/1', - 'output': 'a', - 'path': 'ec/89/e5/6e161d2cb012ef6ac8acf59bf453a6328766f90dc9baba9eb14ea23c55', - 'endpoint': 'a', - 'hash': 'ec89e56e161d2cb012ef6ac8acf59bf453a6328766f90dc9baba9eb14ea23c55', - 'channel': 'integers' + "queue": "queue", + "inputs": { + "in_data": { + "set": "double", + "protocol": "double", + "database": "integers_db/1", + "output": "a", + "path": "ec/89/e5/6e161d2cb012ef6ac8acf59bf453a6328766f90dc9baba9eb14ea23c55", + "endpoint": "a", + "hash": "ec89e56e161d2cb012ef6ac8acf59bf453a6328766f90dc9baba9eb14ea23c55", + "channel": "integers", } }, - 'algorithm': 'user/integers_echo/1', - 'parameters': {}, - 'environment': { - 'name': 'Python 2.7', - 'version': '1.2.0' - }, - 'outputs': { - 'out_data': { - 'path': '20/61/b6/2df3c3bedd5366f4a625c5d87ffbf5a26007c46c456e9abf21b46c6681', - 'endpoint': 'out_data', - 'hash': '2061b62df3c3bedd5366f4a625c5d87ffbf5a26007c46c456e9abf21b46c6681', - 'channel': 'integers' + "algorithm": "user/integers_echo/1", + "parameters": {}, + "environment": {"name": "Python 2.7", "version": "1.2.0"}, + "outputs": { + "out_data": { + "path": "20/61/b6/2df3c3bedd5366f4a625c5d87ffbf5a26007c46c456e9abf21b46c6681", + "endpoint": "out_data", + "hash": "2061b62df3c3bedd5366f4a625c5d87ffbf5a26007c46c456e9abf21b46c6681", + "channel": "integers", } }, - 'nb_slots': 1, - 'channel': 'integers' + "nb_slots": 1, + "channel": "integers", } -#---------------------------------------------------------- +# ---------------------------------------------------------- CONFIGURATION_ERROR = { - 'queue': 'queue', - 'inputs': { - 'in_data': { - 'set': 'get_crashes', - 'protocol': 'protocol', - 'database': 'crash/1', - 'output': 'out', - 'path': 'ec/89/e5/6e161d2cb012ef6ac8acf59bf453a6328766f90dc9baba9eb14ea23c55', - 'endpoint': 'in', - 'hash': 'ec89e56e161d2cb012ef6ac8acf59bf453a6328766f90dc9baba9eb14ea23c55', - 'channel': 'set' + "queue": "queue", + "inputs": { + "in_data": { + "set": "get_crashes", + "protocol": "protocol", + "database": "crash/1", + "output": "out", + "path": "ec/89/e5/6e161d2cb012ef6ac8acf59bf453a6328766f90dc9baba9eb14ea23c55", + "endpoint": "in", + "hash": "ec89e56e161d2cb012ef6ac8acf59bf453a6328766f90dc9baba9eb14ea23c55", + "channel": "set", } }, - 'algorithm': 'user/integers_echo/1', - 'parameters': {}, - 'environment': { - 'name': 'Python 2.7', - 'version': '1.2.0' - }, - 'outputs': { - 'out_data': { - 'path': '20/61/b6/2df3c3bedd5366f4a625c5d87ffbf5a26007c46c456e9abf21b46c6681', - 'endpoint': 'out_data', - 'hash': '2061b62df3c3bedd5366f4a625c5d87ffbf5a26007c46c456e9abf21b46c6681', - 'channel': 'set' + "algorithm": "user/integers_echo/1", + "parameters": {}, + "environment": {"name": "Python 2.7", "version": "1.2.0"}, + "outputs": { + "out_data": { + "path": "20/61/b6/2df3c3bedd5366f4a625c5d87ffbf5a26007c46c456e9abf21b46c6681", + "endpoint": "out_data", + "hash": "2061b62df3c3bedd5366f4a625c5d87ffbf5a26007c46c456e9abf21b46c6681", + "channel": "set", } }, - 'nb_slots': 1, - 'channel': 'set' + "nb_slots": 1, + "channel": "set", } -#---------------------------------------------------------- +# ---------------------------------------------------------- class DatabasesProviderProcess(multiprocessing.Process): - def __init__(self, queue, arguments): 
super(DatabasesProviderProcess, self).__init__() self.queue = queue self.arguments = arguments - def run(self): - self.queue.put('STARTED') + self.queue.put("STARTED") databases_provider.main(self.arguments) -#---------------------------------------------------------- +# ---------------------------------------------------------- +@ddt class TestDatabasesProvider(unittest.TestCase): - - def __init__(self, methodName='runTest'): + def __init__(self, methodName="runTest"): super(TestDatabasesProvider, self).__init__(methodName) self.databases_provider_process = None self.working_dir = None @@ -166,13 +160,11 @@ class TestDatabasesProvider(unittest.TestCase): self.client_context = None self.client_socket = None - def setUp(self): self.shutdown_everything() # In case another test failed badly during its setUp() self.working_dir = tempfile.mkdtemp(prefix=__name__) self.cache_root = tempfile.mkdtemp(prefix=__name__) - def tearDown(self): self.shutdown_everything() @@ -188,54 +180,51 @@ class TestDatabasesProvider(unittest.TestCase): self.client_socket.close() self.client_context.destroy() - def shutdown_everything(self): self.stop_databases_provider() - def start_databases_provider(self, configuration): - with open(os.path.join(self.working_dir, 'configuration.json'), 'wb') as f: + with open(os.path.join(self.working_dir, "configuration.json"), "wb") as f: data = json.dumps(configuration, indent=4) - f.write(data.encode('utf-8')) + f.write(data.encode("utf-8")) - working_prefix = os.path.join(self.working_dir, 'prefix') + working_prefix = os.path.join(self.working_dir, "prefix") if not os.path.exists(working_prefix): os.makedirs(working_prefix) - input_name, input_cfg = list(configuration['inputs'].items())[0] + input_name, input_cfg = list(configuration["inputs"].items())[0] - database = Database(prefix, input_cfg['database']) + database = Database(prefix, input_cfg["database"]) database.export(working_prefix) - view = database.view(input_cfg['protocol'], input_cfg['set']) - view.index(os.path.join(self.cache_root, input_cfg['path'])) + view = database.view(input_cfg["protocol"], input_cfg["set"]) + view.index(os.path.join(self.cache_root, input_cfg["path"])) with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s: - s.bind(('', 0)) + s.bind(("", 0)) port = s.getsockname()[1] - address = '127.0.0.1:%i' % port - args = [ - address, - self.working_dir, - self.cache_root, - ] + address = "127.0.0.1:%i" % port + args = [address, self.working_dir, self.cache_root] - self.databases_provider_process = DatabasesProviderProcess(multiprocessing.Queue(), args) + self.databases_provider_process = DatabasesProviderProcess( + multiprocessing.Queue(), args + ) self.databases_provider_process.start() self.databases_provider_process.queue.get() self.client_context = zmq.Context() self.client_socket = self.client_context.socket(zmq.PAIR) - self.client_socket.connect('tcp://%s' % address) + self.client_socket.connect("tcp://%s" % address) - dataformat_name = database.set(input_cfg['protocol'], input_cfg['set'])['outputs'][input_cfg['output']] + dataformat_name = database.set(input_cfg["protocol"], input_cfg["set"])[ + "outputs" + ][input_cfg["output"]] self.data_source = RemoteDataSource() self.data_source.setup(self.client_socket, input_name, dataformat_name, prefix) - def stop_databases_provider(self): if self.databases_provider_process is not None: self.databases_provider_process.terminate() @@ -243,8 +232,9 @@ class TestDatabasesProvider(unittest.TestCase): del self.databases_provider_process 
self.databases_provider_process = None - - def test_success(self): + @idata(INTEGERS_DBS) + def test_success(self, db_name): + CONFIGURATION["inputs"]["in_data"]["database"] = db_name self.start_databases_provider(CONFIGURATION) self.assertEqual(len(self.data_source), 9) @@ -255,7 +245,6 @@ class TestDatabasesProvider(unittest.TestCase): self.assertEqual(end_index, i) self.assertEqual(data.value, i + 1) - def test_error(self): self.start_databases_provider(CONFIGURATION_ERROR) diff --git a/beat/backend/python/test/test_dbexecutor.py b/beat/backend/python/test/test_dbexecutor.py index b046b7e3a8ce8110ad990a11b027f4b61301284d..7b84072def4fad5f436551cea3e9c894d99acae7 100644 --- a/beat/backend/python/test/test_dbexecutor.py +++ b/beat/backend/python/test/test_dbexecutor.py @@ -38,13 +38,14 @@ import os import logging -logger = logging.getLogger(__name__) - import unittest import zmq import tempfile import shutil +from ddt import ddt +from ddt import idata + from ..execution import DBExecutor from ..execution import MessageHandler from ..database import Database @@ -53,80 +54,80 @@ from ..data import RemoteDataSource from ..hash import hashDataset from ..hash import toPath +from .test_database import INTEGERS_DBS + from . import prefix -#---------------------------------------------------------- +# ---------------------------------------------------------- -DB_VIEW_HASH = hashDataset('integers_db/1', 'double', 'double') -DB_INDEX_PATH = toPath(DB_VIEW_HASH, suffix='.db') +logger = logging.getLogger(__name__) + CONFIGURATION = { - 'queue': 'queue', - 'algorithm': 'user/sum/1', - 'nb_slots': 1, - 'channel': 'integers', - 'parameters': { - }, - 'environment': { - 'name': 'Python 2.7', - 'version': '1.2.0' - }, - 'inputs': { - 'a': { - 'database': 'integers_db/1', - 'protocol': 'double', - 'set': 'double', - 'output': 'a', - 'endpoint': 'a', - 'channel': 'integers', - 'path': DB_INDEX_PATH, - 'hash': DB_VIEW_HASH, + "queue": "queue", + "algorithm": "user/sum/1", + "nb_slots": 1, + "channel": "integers", + "parameters": {}, + "environment": {"name": "Python 2.7", "version": "1.2.0"}, + "inputs": { + "a": { + "database": "integers_db/1", + "protocol": "double", + "set": "double", + "output": "a", + "endpoint": "a", + "channel": "integers", + "path": None, + "hash": None, + }, + "b": { + "database": "integers_db/1", + "protocol": "double", + "set": "double", + "output": "b", + "endpoint": "b", + "channel": "integers", + "path": None, + "hash": None, }, - 'b': { - 'database': 'integers_db/1', - 'protocol': 'double', - 'set': 'double', - 'output': 'b', - 'endpoint': 'b', - 'channel': 'integers', - 'path': DB_INDEX_PATH, - 'hash': DB_VIEW_HASH, - } }, - 'outputs': { - 'sum': { - 'endpoint': 'sum', - 'channel': 'integers', - 'path': '20/61/b6/2df3c3bedd5366f4a625c5d87ffbf5a26007c46c456e9abf21b46c6681', - 'hash': '2061b62df3c3bedd5366f4a625c5d87ffbf5a26007c46c456e9abf21b46c6681', + "outputs": { + "sum": { + "endpoint": "sum", + "channel": "integers", + "path": "20/61/b6/2df3c3bedd5366f4a625c5d87ffbf5a26007c46c456e9abf21b46c6681", + "hash": "2061b62df3c3bedd5366f4a625c5d87ffbf5a26007c46c456e9abf21b46c6681", } - } + }, } -#---------------------------------------------------------- +# ---------------------------------------------------------- +@ddt class TestExecution(unittest.TestCase): - def setUp(self): self.cache_root = tempfile.mkdtemp(prefix=__name__) - database = Database(prefix, 'integers_db/1') - view = database.view('double', 'double') + for db_name in INTEGERS_DBS: + database = Database(prefix, 
db_name) + view = database.view("double", "double") + db_view_hash = hashDataset(db_name, "double", "double") + db_index_path = toPath(db_view_hash, suffix=".db") - view.index(os.path.join(self.cache_root, DB_INDEX_PATH)) + view.index(os.path.join(self.cache_root, db_index_path)) self.db_executor = None self.client_context = None self.client_socket = None - def tearDown(self): if self.client_socket is not None: - self.client_socket.send_string('don') + self.client_socket.send_string("don") if self.db_executor is not None: self.db_executor.wait() @@ -138,11 +139,21 @@ class TestExecution(unittest.TestCase): shutil.rmtree(self.cache_root) + @idata(INTEGERS_DBS) + def test_success(self, db_name): + message_handler = MessageHandler("127.0.0.1") + + for input_ in ["a", "b"]: + db_view_hash = hashDataset(db_name, "double", "double") + db_index_path = toPath(db_view_hash, suffix=".db") - def test_success(self): - message_handler = MessageHandler('127.0.0.1') + CONFIGURATION["inputs"][input_]["database"] = db_name + CONFIGURATION["inputs"][input_]["path"] = db_index_path + CONFIGURATION["inputs"][input_]["hash"] = db_view_hash - self.db_executor = DBExecutor(message_handler, prefix, self.cache_root, CONFIGURATION) + self.db_executor = DBExecutor( + message_handler, prefix, self.cache_root, CONFIGURATION + ) self.assertTrue(self.db_executor.valid) @@ -152,17 +163,18 @@ class TestExecution(unittest.TestCase): self.client_socket = self.client_context.socket(zmq.PAIR) self.client_socket.connect(self.db_executor.address) - data_loader = DataLoader(CONFIGURATION['channel']) + data_loader = DataLoader(CONFIGURATION["channel"]) - database = Database(prefix, 'integers_db/1') + database = Database(prefix, db_name) - for input_name, input_conf in CONFIGURATION['inputs'].items(): - dataformat_name = database.set(input_conf['protocol'], input_conf['set'])['outputs'][input_conf['output']] + for input_name, input_conf in CONFIGURATION["inputs"].items(): + dataformat_name = database.set(input_conf["protocol"], input_conf["set"])[ + "outputs" + ][input_conf["output"]] data_source = RemoteDataSource() data_source.setup(self.client_socket, input_name, dataformat_name, prefix) data_loader.add(input_name, data_source) - - self.assertEqual(data_loader.count('a'), 9) - self.assertEqual(data_loader.count('b'), 9) + self.assertEqual(data_loader.count("a"), 9) + self.assertEqual(data_loader.count("b"), 9)
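For reference, a minimal usage sketch of the v2 flow exercised by the tests above. It is not part of the patch; it only relies on the APIs the patch introduces (Database.view_definition and the Runner returned by Database.view) together with the test prefix package, and the /tmp index path is a hypothetical placeholder:

    from beat.backend.python.database import Database
    from beat.backend.python.test import prefix  # test prefix shipped with the package

    db = Database(prefix, "integers_db/2")       # v2 declaration: each protocol points at a template
    assert db.valid, "\n".join(db.errors)

    # the set layout comes from the protocol template, the view class from the protocol itself
    definition = db.view_definition("double", "double")
    assert definition["view"] == "Double"

    view = db.view("double", "double")
    view.index("/tmp/double.db")                 # hypothetical cache location
    view.setup("/tmp/double.db")
    assert view.get("a", 0)["value"] == 1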