class RemoteException(Exception):
    """Exception transporting an error reported by a remote process.

    Parameters:
      kind (str): ``'sys'`` for a system (platform) error; anything else is
        treated as a user (user-code) error.
      message (str): The error message reported by the remote end.
    """

    def __init__(self, kind, message):
        super(RemoteException, self).__init__()

        # A system error comes from the platform itself, a user error from
        # user-supplied code (algorithm, database view, ...); only one of
        # the two attributes carries the message.
        if kind == 'sys':
            self.system_error = message
            self.user_error = ''
        else:
            self.system_error = ''
            self.user_error = message


#----------------------------------------------------------


class DataSource(object):
    """Base class to load data from some source"""

    def __init__(self):
        self.infos = []           # list of Infos(start_index, end_index) tuples
        self.read_duration = 0    # total time (seconds) spent reading data
        self.nb_bytes_read = 0    # total number of bytes read so far


    def close(self):
        self.infos = []


    def __del__(self):
        """Makes sure all resources are released when the object is deleted"""
        self.close()


    def __len__(self):
        return len(self.infos)


    def __iter__(self):
        for i in range(0, len(self.infos)):
            yield self[i]


    def __getitem__(self, index):
        # Subclasses must override this to return (data, start_index, end_index).
        # Fix: the original did `raise NotImplemented()`, which actually raises
        # a TypeError ('NotImplemented' is a sentinel value, not an exception
        # class, and is not callable).
        raise NotImplementedError


    def first_data_index(self):
        return self.infos[0].start_index


    def last_data_index(self):
        return self.infos[-1].end_index


    def data_indices(self):
        return [ (x.start_index, x.end_index) for x in self.infos ]


    def getAtDataIndex(self, data_index):
        # Linear scan for the block whose [start_index, end_index] range
        # contains the requested data index
        for index, infos in enumerate(self.infos):
            if (infos.start_index <= data_index) and (data_index <= infos.end_index):
                return self[index]

        return (None, None, None)


    def statistics(self):
        """Return the statistics about the number of bytes read"""
        return (self.nb_bytes_read, self.read_duration)
class DatabaseOutputDataSource(DataSource):
    """Utility class to load data from an output of a database view"""

    def __init__(self):
        super(DatabaseOutputDataSource, self).__init__()

        self.prefix = None        # path to the prefix with the dataformats
        self.dataformat = None    # DataFormat instance of the output
        self.view = None          # database view providing the data
        self.output_name = None   # name of the view output to read
        self.pack = True          # whether __getitem__ returns packed data


    def setup(self, view, output_name, dataformat_name, prefix, start_index=None,
              end_index=None, pack=False):
        """Configures the data source


        Parameters:

          view: The database view providing the data.

          output_name (str): Name of the view output to read.

          dataformat_name (str): Name of the data format.

          prefix (str, path): Path to the prefix where the dataformats are stored.

          start_index (int): The starting index (if not set or set to
            ``None``, the default, read data from the begin of file)

          end_index (int): The end index (if not set or set to ``None``, the
            default, reads the data until the end)

          pack (bool): Indicates if the data must be packed or not


        Returns:

          ``True``, if successful, or ``False`` otherwise.

        """

        self.prefix = prefix
        self.view = view
        self.output_name = output_name
        self.pack = pack

        self.dataformat = DataFormat(self.prefix, dataformat_name)

        if not self.dataformat.valid:
            raise RuntimeError("the dataformat `%s' is not valid" % dataformat_name)


        # Scan the objects produced by the view and group consecutive objects
        # sharing the same value of the output into (start, end) index ranges
        Infos = namedtuple('Infos', ['start_index', 'end_index'])

        def _in_range(start, end):
            # Keep only the ranges fitting in the requested index window
            return ((start_index is None) or (start >= start_index)) and \
                   ((end_index is None) or (end <= end_index))

        objects = self.view.objects()

        start = None
        previous_value = None
        last_index = -1

        for index, obj in enumerate(objects):
            last_index = index
            current_value = getattr(obj, output_name)

            if start is None:
                start = index
                previous_value = current_value
            elif current_value != previous_value:
                # Value changed: close the current range at the previous index
                if _in_range(start, index - 1):
                    self.infos.append(Infos(start_index=start, end_index=index - 1))

                start = index
                previous_value = current_value

        # Close the last open range. Fix: the original unconditionally did
        # `end = index` after the loop, raising a NameError when the view
        # produced no objects at all (the loop variable was never bound).
        if start is not None:
            if _in_range(start, last_index):
                self.infos.append(Infos(start_index=start, end_index=last_index))

        return True


    def __getitem__(self, index):
        """Retrieve a block of data

        Returns:

          A tuple (data, start_index, end_index)

        """

        if (index < 0) or (index >= len(self.infos)):
            return (None, None, None)

        infos = self.infos[index]

        t1 = time.time()
        data = self.view.get(self.output_name, infos.start_index)
        t2 = time.time()

        self.read_duration += t2 - t1

        # The view may return a plain dictionary; convert it to the dataformat
        if isinstance(data, dict):
            d = self.dataformat.type()
            d.from_dict(data, casting='safe', add_defaults=False)
            data = d

        if self.pack:
            data = data.pack()
            self.nb_bytes_read += len(data)

        return (data, infos.start_index, infos.end_index)
class RemoteDataSource(DataSource):
    """Utility class to load data from a data source accessible via a socket"""

    def __init__(self):
        super(RemoteDataSource, self).__init__()

        self.socket = None       # 0MQ socket connected to the remote process
        self.input_name = None   # name of the input served by the remote end
        self.dataformat = None   # DataFormat instance of the transported data
        self.unpack = True       # whether __getitem__ unpacks the raw data


    def setup(self, socket, input_name, dataformat_name, prefix, unpack=True):
        """Configures the data source


        Parameters:

          socket (socket): The socket to use to access the data.

          input_name (str): Name of the input corresponding to the data source.

          dataformat_name (str): Name of the data format.

          prefix (str, path): Path to the prefix where the dataformats are stored.

          unpack (bool): Indicates if the data must be unpacked or not


        Returns:

          ``True``, if successful, or ``False`` otherwise.

        """

        self.socket = socket
        self.input_name = input_name
        self.unpack = unpack

        self.dataformat = DataFormat(prefix, dataformat_name)

        if not self.dataformat.valid:
            raise RuntimeError("the dataformat `%s' is not valid" % dataformat_name)


        # Ask the remote end for the list of (start, end) data indices
        Infos = namedtuple('Infos', ['start_index', 'end_index'])

        logger.debug('send: (ifo) infos %s', self.input_name)

        self.socket.send('ifo', zmq.SNDMORE)
        self.socket.send(self.input_name)

        answer = self.socket.recv()
        logger.debug('recv: %s', answer)

        if answer == 'err':
            kind = self.socket.recv()
            message = self.socket.recv()
            raise RemoteException(kind, message)

        nb = int(answer)
        for i in range(nb):
            start = int(self.socket.recv())
            end = int(self.socket.recv())
            self.infos.append(Infos(start_index=start, end_index=end))

        return True


    def __getitem__(self, index):
        """Retrieve a block of data

        Returns:

          A tuple (data, start_index, end_index)

        """

        if (index < 0) or (index >= len(self.infos)):
            return (None, None, None)

        infos = self.infos[index]

        logger.debug('send: (get) get %s %d', self.input_name, index)

        t1 = time.time()

        self.socket.send('get', zmq.SNDMORE)
        self.socket.send(self.input_name, zmq.SNDMORE)
        self.socket.send('%d' % index)

        answer = self.socket.recv()

        if answer == 'err':
            # Fix: the original referenced an undefined variable `_start`
            # here, so the error path raised a NameError instead of the
            # intended RemoteException; the elapsed time belongs to `t1`.
            self.read_duration += time.time() - t1
            kind = self.socket.recv()
            message = self.socket.recv()
            raise RemoteException(kind, message)

        start = int(answer)
        end = int(self.socket.recv())

        packed = self.socket.recv()

        t2 = time.time()

        logger.debug('recv: <bin> (size=%d), indexes=(%d, %d)', len(packed), start, end)

        self.nb_bytes_read += len(packed)

        if self.unpack:
            data = self.dataformat.type()
            data.unpack(packed)
        else:
            data = packed

        self.read_duration += t2 - t1

        return (data, infos.start_index, infos.end_index)
""" @@ -874,7 +1153,7 @@ class MemoryDataSink(DataSink): def setup(self, data_sources): """Configure the data sink - :param list data_sources: The MemoryDataSource objects to use + :param list data_sources: The MemoryLegacyDataSource objects to use """ self.data_sources = data_sources diff --git a/beat/backend/python/database.py b/beat/backend/python/database.py index ff81a936626fe6d3c95e5bd41e2898087e0fde08..ad881fe41609256ebbbed8eced7956e2a1f08654 100755 --- a/beat/backend/python/database.py +++ b/beat/backend/python/database.py @@ -35,6 +35,7 @@ import six import simplejson import itertools import numpy as np +from collections import namedtuple from . import loader from . import utils @@ -73,7 +74,7 @@ class Storage(utils.CodeStorage): #---------------------------------------------------------- -class View(object): +class Runner(object): '''A special loader class for database views, with specialized methods Parameters: @@ -98,9 +99,7 @@ class View(object): ''' - - def __init__(self, module, definition, prefix, root_folder, exc=None, - *args, **kwargs): + def __init__(self, module, definition, prefix, root_folder, exc=None): try: class_ = getattr(module, definition['view']) @@ -109,94 +108,73 @@ class View(object): type, value, traceback = sys.exc_info() six.reraise(exc, exc(value), traceback) else: - raise #just re-raise the user exception + raise # just re-raise the user exception - self.obj = loader.run(class_, '__new__', exc, *args, **kwargs) - self.ready = False - self.prefix = prefix - self.root_folder = root_folder - self.definition = definition - self.exc = exc or RuntimeError - self.outputs = None + self.obj = loader.run(class_, '__new__', exc) + self.ready = False + self.prefix = prefix + self.root_folder = root_folder + self.definition = definition + self.exc = exc or RuntimeError + self.data_sources = None - def prepare_outputs(self): - '''Prepares the outputs of the dataset''' + def index(self, filename): + '''Index the content of the view''' - from 
.outputs import Output, OutputList - from .data import MemoryDataSink - from .dataformat import DataFormat + parameters = self.definition.get('parameters', {}) - # create the stock outputs for this dataset, so data is dumped - # on a in-memory sink - self.outputs = OutputList() - for out_name, out_format in self.definition.get('outputs', {}).items(): - data_sink = MemoryDataSink() - data_sink.dataformat = DataFormat(self.prefix, out_format) - data_sink.setup([]) - self.outputs.add(Output(out_name, data_sink, dataset_output=True)) + objs = loader.run(self.obj, 'index', self.exc, self.root_folder, parameters) + if not isinstance(objs, list): + raise self.exc("index() didn't return a list") - def setup(self, *args, **kwargs): - '''Sets up the view''' + if not os.path.exists(os.path.dirname(filename)): + os.makedirs(os.path.dirname(filename)) - kwargs.setdefault('root_folder', self.root_folder) - kwargs.setdefault('parameters', self.definition.get('parameters', {})) + with open(filename, 'wb') as f: + simplejson.dump(objs, f) - if 'outputs' not in kwargs: - kwargs['outputs'] = self.outputs - else: - self.outputs = kwargs['outputs'] #record outputs nevertheless - self.ready = loader.run(self.obj, 'setup', self.exc, *args, **kwargs) + def setup(self, filename, start_index=None, end_index=None, pack=True): + '''Sets up the view''' - if not self.ready: - raise self.exc("unknow setup failure") + if self.ready: + return - return self.ready + with open(filename, 'rb') as f: + objs = simplejson.load(f) + Entry = namedtuple('Entry', sorted(objs[0].keys())) + objs = [ Entry(**x) for x in objs ] - def input_group(self, name='default', exclude_outputs=[]): - '''A memory-source input group matching the outputs from the view''' + parameters = self.definition.get('parameters', {}) - if not self.ready: - raise self.exc("database view not yet setup") + loader.run(self.obj, 'setup', self.exc, self.root_folder, parameters, + objs, start_index=start_index, end_index=end_index) - from 
.data import MemoryDataSource - from .outputs import SynchronizationListener - from .inputs import Input, InputGroup - # Setup the inputs - synchronization_listener = SynchronizationListener() - input_group = InputGroup(name, - synchronization_listener=synchronization_listener, - restricted_access=False) + # Create data sources for the outputs + from .data import DatabaseOutputDataSource + from .dataformat import DataFormat - for output in self.outputs: - if output.name in exclude_outputs: continue - data_source = MemoryDataSource(self.done, next_callback=self.next) - output.data_sink.data_sources.append(data_source) - input_group.add(Input(output.name, - output.data_sink.dataformat, data_source)) + self.data_sources = {} + for output_name, output_format in self.definition.get('outputs', {}).items(): + data_source = DatabaseOutputDataSource() + data_source.setup(self, output_name, output_format, self.prefix, + start_index=start_index, end_index=end_index, pack=pack) + self.data_sources[output_name] = data_source - return input_group + self.ready = True - def done(self, *args, **kwargs): - '''Checks if the view is done''' + def get(self, output, index): + '''Returns the data of the provided output at the provided index''' if not self.ready: - raise self.exc("database view not yet setup") - - return loader.run(self.obj, 'done', self.exc, *args, **kwargs) - + raise self.exc("Database view not yet setup") - def next(self, *args, **kwargs): - '''Runs through the next data chunk''' - - if not self.ready: - raise self.exc("database view not yet setup") - return loader.run(self.obj, 'next', self.exc, *args, **kwargs) + return loader.run(self.obj, 'get', self.exc, output, index) def __getattr__(self, key): @@ -204,6 +182,10 @@ class View(object): return getattr(self.obj, key) + def objects(self): + return self.obj.objs + + #---------------------------------------------------------- @@ -368,7 +350,9 @@ class Database(object): if not self.valid: message = "cannot load view 
class View(object):
    """Base class for user-defined database views."""

    def index(self, root_folder, parameters):
        """Returns a list of (named) tuples describing the data provided by the view.

        The ordering of values inside the tuples is free, but the list itself
        must be ordered consistently (e.g. all train images of person A, then
        all train images of person B, ...).

        For instance, for a view providing an ``image``, a ``file_id`` and a
        ``client_id`` output (where several images share one client), the
        expected result looks like:

        [
            (client_id=1, file_id=1, image=filename1),
            (client_id=1, file_id=2, image=filename2),
            (client_id=1, file_id=3, image=filename3),
            (client_id=2, file_id=4, image=filename4),
            (client_id=2, file_id=5, image=filename5),
            (client_id=2, file_id=6, image=filename6),
            ...
        ]

        DO NOT store images, sound files or data loadable from a file in the
        list! Store the path of the file to load instead.
        """

        raise NotImplementedError


    def setup(self, root_folder, parameters, objs, start_index=None, end_index=None):
        """Records the configuration and restricts the objects to the
        requested [start_index, end_index] window (inclusive on both ends)."""

        self.root_folder = root_folder
        self.parameters = parameters

        # Resolve the inclusive window of indices this view must serve;
        # unset bounds default to the full extent of the object list
        first = 0 if start_index is None else start_index
        last = (len(objs) - 1) if end_index is None else end_index

        self.start_index = first
        self.end_index = last
        self.objs = objs[first:last + 1]


    def get(self, output, index):
        """Returns the data of the provided output at the provided index in the
        list of (named) tuples describing the data provided by the view
        (accessible at self.objs)"""

        raise NotImplementedError
import message_handler +from .database import Database +from .message_handler import MessageHandler class DBExecutor(object): @@ -102,38 +94,24 @@ class DBExecutor(object): """ - def __init__(self, prefix, data, dataformat_cache=None, database_cache=None): + def __init__(self, address, prefix, cache_root, data, dataformat_cache=None, + database_cache=None): + # Initialisations self.prefix = prefix - - # some attributes self.databases = {} self.views = {} - self.input_list = None - self.data_sources = [] - self.handler = None self.errors = [] self.data = None + self.message_handler = None + self.data_sources = {} - # temporary caches, if the user has not set them, for performance + # Temporary caches, if the user has not set them, for performance database_cache = database_cache if database_cache is not None else {} self.dataformat_cache = dataformat_cache if dataformat_cache is not None else {} - self._load(data, database_cache) - - - def _load(self, data, database_cache): - """Loads the block execution information""" - - # reset - self.data = None - self.errors = [] - self.databases = {} - self.views = {} - self.input_list = None - self.data_sources = [] - - if not isinstance(data, dict): #user has passed a file name + # Load the data + if not isinstance(data, dict): # User has passed a file name if not os.path.exists(data): self.errors.append('File not found: %s' % data) return @@ -147,135 +125,79 @@ class DBExecutor(object): # self.data, self.errors = schema.validate('execution', data) # if self.errors: return #don't proceed with the rest of validation - # load databases + # Load the databases for name, details in self.data['inputs'].items(): - if 'database' in details: - - if details['database'] not in self.databases: + if 'database' not in details: + continue - if details['database'] in database_cache: #reuse - db = database_cache[details['database']] - else: #load it - db = database.Database(self.prefix, details['database'], - self.dataformat_cache) - 
database_cache[db.name] = db + # Load the database + if details['database'] not in self.databases: - self.databases[details['database']] = db + if details['database'] in database_cache: #reuse + db = database_cache[details['database']] + else: #load it + db = Database(self.prefix, details['database'], + self.dataformat_cache) + database_cache[db.name] = db - if not db.valid: - self.errors += db.errors - continue + self.databases[details['database']] = db if not db.valid: - # do not add errors again - continue - - # create and load the required views - key = (details['database'], details['protocol'], details['set']) - if key not in self.views: - view = self.databases[details['database']].view(details['protocol'], - details['set']) - - if details['channel'] == self.data['channel']: #synchronized - start_index, end_index = self.data.get('range', (None, None)) - else: - start_index, end_index = (None, None) - view.prepare_outputs() - self.views[key] = (view, start_index, end_index) - - - def __enter__(self): - """Prepares inputs and outputs for the processing task - - Raises: - - IOError: in case something cannot be properly setup - - """ - - self._prepare_inputs() - - # The setup() of a database view may call isConnected() on an input - # to set the index at the right location when parallelization is enabled. - # This is why setup() should be called after initialized the inputs. 
- for key, (view, start_index, end_index) in self.views.items(): - - if (start_index is None) and (end_index is None): - status = view.setup() + self.errors += db.errors else: - status = view.setup(force_start_index=start_index, - force_end_index=end_index) + db = self.databases[details['database']] - if not status: - raise RuntimeError("Could not setup database view `%s'" % key) + if not db.valid: + continue - return self + # Create and load the required views + key = (details['database'], details['protocol'], details['set']) + if key not in self.views: + view = db.view(details['protocol'], details['set']) + if details['channel'] == self.data['channel']: #synchronized + start_index, end_index = self.data.get('range', (None, None)) + else: + start_index, end_index = (None, None) - def __exit__(self, exc_type, exc_value, traceback): - """Closes all sinks and disconnects inputs and outputs - """ - self.input_list = None - self.data_sources = [] + view.setup(os.path.join(cache_root, details['path']), + start_index=start_index, end_index=end_index) + self.views[key] = view - def _prepare_inputs(self): - """Prepares all input required by the execution.""" - - self.input_list = inputs.InputList() - - # This is used for parallelization purposes - start_index, end_index = self.data.get('range', (None, None)) - + # Create the data sources for name, details in self.data['inputs'].items(): + if 'database' not in details: + continue - if 'database' in details: #it is a dataset input + view_key = (details['database'], details['protocol'], details['set']) + view = self.views[view_key] - view_key = (details['database'], details['protocol'], details['set']) - view = self.views[view_key][0] + self.data_sources[name] = view.data_sources[details['output']] - data_source = data.MemoryDataSource(view.done, next_callback=view.next) - self.data_sources.append(data_source) - output = view.outputs[details['output']] + # Create the message handler + self.message_handler = 
MessageHandler(address, data_sources=self.data_sources) - # if it's a synchronized channel, makes the output start at the right - # index, otherwise, it gets lost - if start_index is not None and \ - details['channel'] == self.data['channel']: - output.last_written_data_index = start_index - 1 - output.data_sink.data_sources.append(data_source) - # Synchronization bits - group = self.input_list.group(details['channel']) - if group is None: - group = inputs.InputGroup( - details['channel'], - synchronization_listener=outputs.SynchronizationListener(), - restricted_access=(details['channel'] == self.data['channel']) - ) - self.input_list.add(group) + def process(self): + self.message_handler.start() - input_db = self.databases[details['database']] - input_dataformat_name = input_db.set(details['protocol'], details['set'])['outputs'][details['output']] - group.add(inputs.Input(name, self.dataformat_cache[input_dataformat_name], data_source)) - - def process(self, address): - self.handler = message_handler.MessageHandler(address, inputs=self.input_list) - self.handler.start() + @property + def address(self): + return self.message_handler.address @property def valid(self): """A boolean that indicates if this executor is valid or not""" - return not bool(self.errors) def wait(self): - self.handler.join() - self.handler.destroy() - self.handler = None + self.message_handler.join() + self.message_handler.destroy() + self.message_handler = None def __str__(self): diff --git a/beat/backend/python/hash.py b/beat/backend/python/hash.py index ddea49a7ad8fc2edaa0fbfef48869225d972c227..3e8c020f25f24816a70cef1ddab84ab3cba7debc 100755 --- a/beat/backend/python/hash.py +++ b/beat/backend/python/hash.py @@ -3,7 +3,7 @@ ############################################################################### # # -# Copyright (c) 2016 Idiap Research Institute, http://www.idiap.ch/ # +# Copyright (c) 2017 Idiap Research Institute, http://www.idiap.ch/ # # Contact: beat.support@idiap.ch # # # # 
This file is part of the beat.backend.python module of the BEAT platform. # @@ -37,6 +37,9 @@ import six import os +#---------------------------------------------------------- + + def _sha256(s): """A python2/3 replacement for :py:func:`haslib.sha256`""" @@ -47,6 +50,8 @@ def _sha256(s): return hashlib.sha256(s).hexdigest() +#---------------------------------------------------------- + def _stringify(dictionary): names = sorted(dictionary.keys()) @@ -63,12 +68,30 @@ def _stringify(dictionary): return converted_dictionary +#---------------------------------------------------------- + + +def _compact(text): + return text.replace(' ', '').replace('\n', '') + + +#---------------------------------------------------------- + + +def toPath(hash, suffix='.data'): + return os.path.join(hash[0:2], hash[2:4], hash[4:6], hash[6:] + suffix) + + +#---------------------------------------------------------- + def toUserPath(username): hash = _sha256(username) return os.path.join(hash[0:2], hash[2:4], username) +#---------------------------------------------------------- + def hash(dictionary_or_string): if isinstance(dictionary_or_string, dict): @@ -77,6 +100,8 @@ def hash(dictionary_or_string): return _sha256(dictionary_or_string) +#---------------------------------------------------------- + def hashJSON(contents, description): """Hashes the pre-loaded JSON object using :py:func:`hashlib.sha256` @@ -91,6 +116,8 @@ def hashJSON(contents, description): return hashlib.sha256(contents).hexdigest() +#---------------------------------------------------------- + def hashJSONFile(path, description): """Hashes the JSON file contents using :py:func:`hashlib.sha256` @@ -107,6 +134,8 @@ def hashJSONFile(path, description): return hashFileContents(path) +#---------------------------------------------------------- + def hashFileContents(path): """Hashes the file contents using :py:func:`hashlib.sha256`.""" @@ -117,3 +146,15 @@ def hashFileContents(path): sha256.update(chunk) return 
sha256.hexdigest() + + +#---------------------------------------------------------- + + +def hashDataset(database_name, protocol_name, set_name): + s = _compact("""{ + "database": "%s", + "protocol": "%s", + "set": "%s" +}""") % (database_name, protocol_name, set_name) + return hash(s) diff --git a/beat/backend/python/helpers.py b/beat/backend/python/helpers.py index e82bf19fb23ecb9a999cf86081a654a6115c81e5..e40d595d33f23cafbc49edf77870e72df63ed089 100755 --- a/beat/backend/python/helpers.py +++ b/beat/backend/python/helpers.py @@ -32,9 +32,9 @@ import errno import logging logger = logging.getLogger(__name__) -from .data import MemoryDataSource +from .data import MemoryLegacyDataSource +from .data import CachedLegacyDataSource from .data import CachedDataSource -from .data import CachedFileLoader from .data import CachedDataSink from .data import getAllFilenames from .data_loaders import DataLoaderList @@ -103,7 +103,7 @@ def create_inputs_from_configuration(config, algorithm, prefix, cache_root, def _create_local_input(details): - data_source = CachedDataSource() + data_source = CachedLegacyDataSource() data_sources.append(data_source) filename = os.path.join(cache_root, details['path'] + '.data') @@ -144,7 +144,7 @@ def create_inputs_from_configuration(config, algorithm, prefix, cache_root, logger.debug("Data loader created: group='%s'" % details['channel']) - cached_file = CachedFileLoader() + cached_file = CachedDataSource() result = cached_file.setup( filename=filename, prefix=prefix, @@ -193,7 +193,7 @@ def create_inputs_from_configuration(config, algorithm, prefix, cache_root, view = views[channel] # Creation of the input - data_source = MemoryDataSource(view.done, next_callback=view.next) + data_source = MemoryLegacyDataSource(view.done, next_callback=view.next) output = view.outputs[details['output']] output.data_sink.data_sources.append(data_source) diff --git a/beat/backend/python/inputs.py b/beat/backend/python/inputs.py index 
40bbcb302cb117de1db8079870a43ef14524440b..fd0d3e4f74f06a47c5b93bf43ba1532ee67df5c8 100755 --- a/beat/backend/python/inputs.py +++ b/beat/backend/python/inputs.py @@ -36,6 +36,7 @@ import six import zmq from .data import mixDataIndices +from .data import RemoteException #---------------------------------------------------------- @@ -136,13 +137,13 @@ class Input(BaseInput): data_format (str): Data format accepted by the input - data_source (beat.core.platform.data.DataSource): Source of data to be used + data_source (beat.core.platform.data.LegacyDataSource): Source of data to be used by the input Attributes: - data_source (beat.core.data.DataSource): Source of data used by the output + data_source (beat.core.data.LegacyDataSource): Source of data used by the output """ @@ -180,22 +181,6 @@ class Input(BaseInput): #---------------------------------------------------------- -class RemoteException(Exception): - - def __init__(self, kind, message): - super(RemoteException, self).__init__() - - if kind == 'sys': - self.system_error = message - self.user_error = '' - else: - self.system_error = '' - self.user_error = message - - -#---------------------------------------------------------- - - def process_error(socket): kind = socket.recv() message = socket.recv() diff --git a/beat/backend/python/message_handler.py b/beat/backend/python/message_handler.py index 368fb28724b2293593d30e52b4dc8dc39c7ffde1..68e69301d80b4e5e0b5b5a1323c4c3e7b39bf2cf 100755 --- a/beat/backend/python/message_handler.py +++ b/beat/backend/python/message_handler.py @@ -35,13 +35,14 @@ import requests import threading from . 
import baseformat -from .inputs import RemoteException +from .data import RemoteException class MessageHandler(threading.Thread): '''A 0MQ message handler for our communication with other processes''' - def __init__(self, host_address, inputs=None, outputs=None, kill_callback=None): + def __init__(self, host_address, inputs=None, outputs=None, data_sources=None, + kill_callback=None): super(MessageHandler, self).__init__() @@ -75,6 +76,7 @@ class MessageHandler(threading.Thread): # Initialisations self.input_list = inputs self.output_list = outputs + self.data_sources = data_sources self.system_error = '' self.user_error = '' @@ -100,6 +102,12 @@ class MessageHandler(threading.Thread): oic = self.output_is_connected, )) + if self.data_sources is not None: + self.callbacks.update(dict( + ifo = self.infos, + get = self.get_data, + )) + def destroy(self): self.socket.setsockopt(zmq.LINGER, 0) @@ -334,6 +342,70 @@ class MessageHandler(threading.Thread): self.socket.send(what) + def infos(self, name): + """Syntax: ifo name""" + + logger.debug('recv: ifo %s', name) + + if self.data_sources is None: + message = 'Unexpected message received: ifo %s' % name + raise RemoteException('sys', message) + + try: + data_source = self.data_sources[name] + except: + raise RemoteException('sys', 'Unknown input: %s' % name) + + logger.debug('send: %d infos', len(data_source)) + + self.socket.send('%d' % len(data_source), zmq.SNDMORE) + + for start, end in data_source.data_indices(): + self.socket.send('%d' % start, zmq.SNDMORE) + + if end < data_source.last_data_index(): + self.socket.send('%d' % end, zmq.SNDMORE) + else: + self.socket.send('%d' % end) + + + def get_data(self, name, index): + """Syntax: get name index""" + + logger.debug('recv: get %s %s', name, index) + + if self.data_sources is None: + message = 'Unexpected message received: get %s %s' % (name, index) + raise RemoteException('sys', message) + + try: + data_source = self.data_sources[name] + except: + raise 
RemoteException('sys', 'Unknown input: %s' % name) + + try: + index = int(index) + except: + raise RemoteException('sys', 'Invalid index: %s' % index) + + (data, start_index, end_index) = data_source[index] + + if data is None: + raise RemoteException('sys', 'Invalid index: %s' % index) + + if isinstance(data, baseformat.baseformat): + packed = data.pack() + else: + packed = data + + logger.debug('send: <bin> (size=%d), indexes=(%d, %d)', len(packed), + start_index, end_index) + + self.socket.send('%d' % start_index, zmq.SNDMORE) + self.socket.send('%d' % end_index, zmq.SNDMORE) + self.socket.send(packed) + + def kill(self): self.must_kill.set() diff --git a/beat/backend/python/test/mocks.py b/beat/backend/python/test/mocks.py index 66bc79673e0dae40a1b4ad565e78270d74cd6467..37545ee89e888504b750a7bee0c743c97765de33 100644 --- a/beat/backend/python/test/mocks.py +++ b/beat/backend/python/test/mocks.py @@ -26,11 +26,11 @@ ############################################################################### -from ..data import DataSource +from ..data import LegacyDataSource from ..data import DataSink -class MockDataSource(DataSource): +class MockLegacyDataSource(LegacyDataSource): def __init__(self, data, indexes): self.data = list(data) @@ -75,7 +75,7 @@ class MockDataSink(DataSink): #---------------------------------------------------------- -class MockDataSource_Crash(DataSource): +class MockLegacyDataSource_Crash(LegacyDataSource): def next(self): a = b diff --git a/beat/backend/python/test/prefix/databases/crash/1.json b/beat/backend/python/test/prefix/databases/crash/1.json index f23b732c05b8df813c8ad9b9bc240fd25a5904f4..e2c7f0ccc0455bff61c4b22b64838fc48b96faa5 100644 --- a/beat/backend/python/test/prefix/databases/crash/1.json +++ b/beat/backend/python/test/prefix/databases/crash/1.json @@ -6,41 +6,17 @@ "template": "template", "sets": [ { - "name": "done_crashes", + "name": "index_crashes", "template": "set", - "view": "DoneCrashes", + "view": "IndexCrashes", 
"outputs": { "out": "user/single_integer/1" } }, { - "name": "next_crashes", + "name": "get_crashes", "template": "set", - "view": "NextCrashes", - "outputs": { - "out": "user/single_integer/1" - } - }, - { - "name": "setup_crashes", - "template": "set", - "view": "SetupCrashes", - "outputs": { - "out": "user/single_integer/1" - } - }, - { - "name": "setup_fails", - "template": "set", - "view": "SetupFails", - "outputs": { - "out": "user/single_integer/1" - } - }, - { - "name": "does_not_exist", - "template": "set", - "view": "DoesNotExist", + "view": "GetCrashes", "outputs": { "out": "user/single_integer/1" } diff --git a/beat/backend/python/test/prefix/databases/crash/1.py b/beat/backend/python/test/prefix/databases/crash/1.py index a6ac05f0754402ddb67ce999acea905f43ce0727..081cf27ca75fcff06c66d5c08437daa484424522 100755 --- a/beat/backend/python/test/prefix/databases/crash/1.py +++ b/beat/backend/python/test/prefix/databases/crash/1.py @@ -25,30 +25,29 @@ # # ############################################################################### -class DoneCrashes: - def setup(self, *args, **kwargs): return True - def next(self): return True - def done(self, last_data_index): - a = b - return True +from beat.backend.python.database import View +from collections import namedtuple + -class NextCrashes: +class IndexCrashes(View): - def setup(self, *args, **kwargs): return True - def done(self, last_data_index): return False - def next(self): + def index(self, *args, **kwargs): a = b - return True + return [] -class SetupCrashes: + def get(self, *args, **kwargs): + return 0 - def done(self, last_data_index): return True - def next(self): return True - def setup(self, *args, **kwargs): - a = b - return True -class SetupFails: - def setup(self, *args, **kwargs): return False - def done(self, last_data_index): return True - def next(self): return True +#---------------------------------------------------------- + + +class GetCrashes(View): + + def index(self, *args, **kwargs): 
+ Entry = namedtuple('Entry', ['out']) + return [ Entry(1) ] + + def get(self, *args, **kwargs): + a = b + return 0 diff --git a/beat/backend/python/test/prefix/databases/integers_db/1.py b/beat/backend/python/test/prefix/databases/integers_db/1.py index 10388993e117e15fc9f509aedf4f305d9fd95312..0cdafc984460c89a3e27de69ff14d7a5595d5534 100755 --- a/beat/backend/python/test/prefix/databases/integers_db/1.py +++ b/beat/backend/python/test/prefix/databases/integers_db/1.py @@ -27,163 +27,162 @@ import random import numpy +from collections import namedtuple +from beat.backend.python.database import View -class Double: +class Double(View): - def setup(self, root_folder, outputs, parameters): - self.outputs = outputs - random.seed(0) #so it is kept reproducible - return True - - - def done(self, last_data_index): - return (last_data_index == 9) - - - def next(self): - val1 = numpy.int32(random.randint(0, 1000)) - self.outputs['a'].write({ - 'value': val1, - }) - - val2 = numpy.int32(random.randint(0, 1000)) - self.outputs['b'].write({ - 'value': val2, - }) - - self.outputs['sum'].write({ - 'value': val1 + val2, - }) - - return True - - - -class Triple: - - def setup(self, root_folder, outputs, parameters): - random.seed(0) #so it is kept reproducible - self.outputs = outputs - return True + def index(self, root_folder, parameters): + Entry = namedtuple('Entry', ['a', 'b', 'sum']) + return [ + Entry(1, 10, 11), + Entry(2, 20, 22), + Entry(3, 30, 33), + Entry(4, 40, 44), + Entry(5, 50, 55), + Entry(6, 60, 66), + Entry(7, 70, 77), + Entry(8, 80, 88), + Entry(9, 90, 99), + ] - def done(self, last_data_index): - return (last_data_index == 9) + def get(self, output, index): + obj = self.objs[index] - def next(self): - val1 = numpy.int32(random.randint(0, 1000)) - self.outputs['a'].write({ - 'value': val1, - }) + if output == 'a': + return { + 'value': numpy.int32(obj.a) + } - val2 = numpy.int32(random.randint(0, 1000)) - self.outputs['b'].write({ - 'value': val2, - }) + elif 
output == 'b': + return { + 'value': numpy.int32(obj.b) + } - val3 = numpy.int32(random.randint(0, 1000)) - self.outputs['c'].write({ - 'value': val3, - }) + elif output == 'sum': + return { + 'value': numpy.int32(obj.sum) + } - self.outputs['sum'].write({ - 'value': val1 + val2 + val3, - }) - return True +#---------------------------------------------------------- +class Triple(View): -class Labelled: + def index(self, root_folder, parameters): + Entry = namedtuple('Entry', ['a', 'b', 'c', 'sum']) - def setup(self, root_folder, outputs, parameters): - self.outputs = outputs - self.remaining = [ - ['A', [1, 2, 3, 4, 5]], - ['B', [10, 20, 30, 40, 50]], - ['C', [100, 200, 300, 400, 500]], + return [ + Entry(1, 10, 100, 111), + Entry(2, 20, 200, 222), + Entry(3, 30, 300, 333), + Entry(4, 40, 400, 444), + Entry(5, 50, 500, 555), + Entry(6, 60, 600, 666), + Entry(7, 70, 700, 777), + Entry(8, 80, 800, 888), + Entry(9, 90, 900, 999), ] - self.current_label = None - return True - def done(self, last_data_index): - return (last_data_index == 14) + def get(self, output, index): + obj = self.objs[index] + if output == 'a': + return { + 'value': numpy.int32(obj.a) + } - def next(self): - # Ensure that we are not done - if len(self.remaining) == 0: - return False + elif output == 'b': + return { + 'value': numpy.int32(obj.b) + } - # Retrieve the next label and value - label = self.remaining[0][0] - value = self.remaining[0][1][0] + elif output == 'c': + return { + 'value': numpy.int32(obj.c) + } - # Only write each label once on the output, with the correct range of indexes - if self.current_label != label: - self.outputs['label'].write({ - 'value': label, - }, self.outputs['label'].last_written_data_index + len(self.remaining[0][1])) - self.current_label = label + elif output == 'sum': + return { + 'value': numpy.int32(obj.sum) + } - # Write the value - self.outputs['value'].write({ - 'value': numpy.int32(value), - }) - # Remove the value (and if needed the label) from the 
list of remaining data - self.remaining[0][1] = self.remaining[0][1][1:] - if len(self.remaining[0][1]) == 0: - self.remaining = self.remaining[1:] +#---------------------------------------------------------- - return True +class Labelled(View): + def index(self, root_folder, parameters): + Entry = namedtuple('Entry', ['label', 'value']) -class DifferentFrequencies: + return [ + Entry('A', 1), + Entry('A', 2), + Entry('A', 3), + Entry('A', 4), + Entry('A', 5), + Entry('B', 10), + Entry('B', 20), + Entry('B', 30), + Entry('B', 40), + Entry('B', 50), + Entry('C', 100), + Entry('C', 200), + Entry('C', 300), + Entry('C', 400), + Entry('C', 500), + ] - def setup(self, root_folder, outputs, parameters): - self.outputs = outputs - self.values_a = [(1, 0, 3), (2, 4, 7)] - self.values_b = [(10, 0, 0), (20, 1, 1), (30, 2, 2), (40, 3, 3), - (50, 4, 4), (60, 5, 5), (70, 6, 6), (80, 7, 7)] - self.next_index = 0 + def get(self, output, index): + obj = self.objs[index] - return True + if output == 'label': + return { + 'value': obj.label + } + elif output == 'value': + return { + 'value': numpy.int32(obj.value) + } - def done(self, last_data_index): - return (last_data_index == 7) +#---------------------------------------------------------- - def next(self): - if self.outputs['b'].isConnected() and \ - (self.outputs['b'].last_written_data_index < self.next_index): - self.outputs['b'].write({ - 'value': numpy.int32(self.values_b[0][0]), - }, - end_data_index=self.values_b[0][2] - ) +class DifferentFrequencies(View): - self.values_b = self.values_b[1:] + def index(self, root_folder, parameters): + Entry = namedtuple('Entry', ['a', 'b']) - if self.outputs['a'].isConnected() and \ - (self.outputs['a'].last_written_data_index < self.next_index): + return [ + Entry(1, 10), + Entry(1, 20), + Entry(1, 30), + Entry(1, 40), + Entry(2, 50), + Entry(2, 60), + Entry(2, 70), + Entry(2, 80), + ] - self.outputs['a'].write({ - 'value': numpy.int32(self.values_a[0][0]), - }, - 
end_data_index=self.values_a[0][2] - ) - self.values_a = self.values_a[1:] + def get(self, output, index): + obj = self.objs[index] - self.next_index = 1 + min([ x.last_written_data_index for x in self.outputs - if x.isConnected() ]) + if output == 'a': + return { + 'value': numpy.int32(obj.a) + } - return True + elif output == 'b': + return { + 'value': numpy.int32(obj.b) + } diff --git a/beat/backend/python/test/prefix/databases/syntax_error/1.py b/beat/backend/python/test/prefix/databases/syntax_error/1.py old mode 100644 new mode 100755 index 6b16be0da66926be4ba0fe9a02f88baa5d7d307c..c5af2e0ac35593ead68ee8b47622c62a675bf811 --- a/beat/backend/python/test/prefix/databases/syntax_error/1.py +++ b/beat/backend/python/test/prefix/databases/syntax_error/1.py @@ -25,5 +25,10 @@ # # ############################################################################### -class View; # <-- syntax error! - def next(self): return True +from beat.backend.python.database import View + + +class SyntaxError(View); # <-- syntax error! 
+ + def get(self, output, index): + return True diff --git a/beat/backend/python/test/prefix/databases/valid/1.json b/beat/backend/python/test/prefix/databases/valid/1.json deleted file mode 100644 index 606a4a34ff90dd4a446b93f6c8041c0a6e8649da..0000000000000000000000000000000000000000 --- a/beat/backend/python/test/prefix/databases/valid/1.json +++ /dev/null @@ -1,19 +0,0 @@ -{ - "root_folder": "/path/not/set", - "protocols": [ - { - "name": "valid", - "template": "valid", - "sets": [ - { - "name": "valid", - "template": "valid", - "view": "View", - "outputs": { - "a": "user/single_integer/1" - } - } - ] - } - ] -} diff --git a/beat/backend/python/test/prefix/databases/valid/1.py b/beat/backend/python/test/prefix/databases/valid/1.py deleted file mode 100755 index a604facb7cffad005697e17c6caae1e2b15e5edd..0000000000000000000000000000000000000000 --- a/beat/backend/python/test/prefix/databases/valid/1.py +++ /dev/null @@ -1,49 +0,0 @@ -#!/usr/bin/env python -# vim: set fileencoding=utf-8 : - -############################################################################### -# # -# Copyright (c) 2016 Idiap Research Institute, http://www.idiap.ch/ # -# Contact: beat.support@idiap.ch # -# # -# This file is part of the beat.core module of the BEAT platform. # -# # -# Commercial License Usage # -# Licensees holding valid commercial BEAT licenses may use this file in # -# accordance with the terms contained in a written agreement between you # -# and Idiap. For further information contact tto@idiap.ch # -# # -# Alternatively, this file may be used under the terms of the GNU Affero # -# Public License version 3 as published by the Free Software and appearing # -# in the file LICENSE.AGPL included in the packaging of this file. # -# The BEAT platform is distributed in the hope that it will be useful, but # -# WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY # -# or FITNESS FOR A PARTICULAR PURPOSE. 
# -# # -# You should have received a copy of the GNU Affero Public License along # -# with the BEAT platform. If not, see http://www.gnu.org/licenses/. # -# # -############################################################################### - -class View: - - def setup(self, root_folder, outputs, parameters, force_start_index=None, force_end_index=None): - self.root_folder = root_folder - self.outputs = outputs - - self.must_return_done = False - self.must_return_error = False - - self.force_start_index = force_start_index - self.force_end_index = force_end_index - - return True - - - def done(self, last_data_index): - return self.must_return_done - - - def next(self): - self.must_return_done = True - return not(self.must_return_error) diff --git a/beat/backend/python/test/test_algorithm.py b/beat/backend/python/test/test_algorithm.py index f326f530b24d47d2c99a8613ae8ca112b377e403..8cb1243b0c68315af934b4e91b4dae6f12766a83 100644 --- a/beat/backend/python/test/test_algorithm.py +++ b/beat/backend/python/test/test_algorithm.py @@ -38,8 +38,8 @@ from ..data_loaders import DataLoaderList from ..data_loaders import DataLoader from ..dataformat import DataFormat from ..data import CachedDataSink +from ..data import CachedLegacyDataSource from ..data import CachedDataSource -from ..data import CachedFileLoader from ..inputs import Input from ..inputs import InputGroup from ..inputs import InputList @@ -575,7 +575,7 @@ class TestLegacyAPI_Process(TestExecutionBase): inputs.add(group) for input_name in group_inputs: - data_source = CachedDataSource() + data_source = CachedLegacyDataSource() data_source.setup(self.filenames[input_name], prefix) group.add(Input(input_name, dataformat, data_source)) @@ -770,7 +770,7 @@ class TestSequentialAPI_Process(TestExecutionBase): inputs = InputGroup(group_name, synchronization_listener, True) for input_name in group_inputs: - data_source = CachedDataSource() + data_source = CachedLegacyDataSource() 
data_source.setup(self.filenames[input_name], prefix) inputs.add(Input(input_name, dataformat, data_source)) @@ -779,7 +779,7 @@ class TestSequentialAPI_Process(TestExecutionBase): data_loaders.add(data_loader) for input_name in group_inputs: - cached_file = CachedFileLoader() + cached_file = CachedDataSource() cached_file.setup(self.filenames[input_name], prefix) data_loader.add(input_name, cached_file) @@ -1020,7 +1020,7 @@ class TestAutonomousAPI_Process(TestExecutionBase): data_loaders.add(data_loader) for input_name in group_inputs: - cached_file = CachedFileLoader() + cached_file = CachedDataSource() cached_file.setup(self.filenames[input_name], prefix) data_loader.add(input_name, cached_file) diff --git a/beat/backend/python/test/test_data.py b/beat/backend/python/test/test_data.py index 286c0b5fcddf2bbd5be22d2e8c54b8cbd2ead1a0..c8f075fa46809e8bda2a4d41c31cf0394db31454 100644 --- a/beat/backend/python/test/test_data.py +++ b/beat/backend/python/test/test_data.py @@ -30,15 +30,17 @@ import unittest import os import glob import tempfile +import shutil from ..data import mixDataIndices -from ..data import CachedFileLoader -from ..data import CachedDataSink from ..data import CachedDataSource +from ..data import CachedDataSink +from ..data import CachedLegacyDataSource from ..data import getAllFilenames from ..data import foundSplitRanges from ..hash import hashFileContents from ..dataformat import DataFormat +from ..database import Database from . 
import prefix @@ -222,7 +224,7 @@ class TestGetAllFilenames(TestCachedDataBase): #---------------------------------------------------------- -class TestCachedFileLoader(TestCachedDataBase): +class TestCachedDataSource(TestCachedDataBase): def check_valid_indices(self, cached_file): for i in range(0, len(cached_file)): @@ -259,7 +261,7 @@ class TestCachedFileLoader(TestCachedDataBase): def test_one_complete_data_file(self): self.writeData('user/single_integer/1', 0, 9) - cached_file = CachedFileLoader() + cached_file = CachedDataSource() cached_file.setup(self.filename, prefix) self.assertEqual(10, len(cached_file)) @@ -274,7 +276,7 @@ class TestCachedFileLoader(TestCachedDataBase): self.writeData('user/single_integer/1', 10, 19) self.writeData('user/single_integer/1', 20, 29) - cached_file = CachedFileLoader() + cached_file = CachedDataSource() cached_file.setup(self.filename, prefix) self.assertEqual(30, len(cached_file)) @@ -287,7 +289,7 @@ class TestCachedFileLoader(TestCachedDataBase): def test_one_partial_data_file(self): self.writeData('user/single_integer/1', 0, 9) - cached_file = CachedFileLoader() + cached_file = CachedDataSource() cached_file.setup(self.filename, prefix, 2, 6) self.assertEqual(5, len(cached_file)) @@ -302,7 +304,7 @@ class TestCachedFileLoader(TestCachedDataBase): self.writeData('user/single_integer/1', 10, 19) self.writeData('user/single_integer/1', 20, 29) - cached_file = CachedFileLoader() + cached_file = CachedDataSource() cached_file.setup(self.filename, prefix, 14, 18) self.assertEqual(5, len(cached_file)) @@ -317,7 +319,7 @@ class TestCachedFileLoader(TestCachedDataBase): self.writeData('user/single_integer/1', 10, 19) self.writeData('user/single_integer/1', 20, 29) - cached_file = CachedFileLoader() + cached_file = CachedDataSource() cached_file.setup(self.filename, prefix, 4, 18) self.assertEqual(15, len(cached_file)) @@ -332,7 +334,7 @@ class TestCachedFileLoader(TestCachedDataBase): self.writeData('user/single_integer/1', 10, 
19) self.writeData('user/single_integer/1', 20, 29) - cached_file = CachedFileLoader() + cached_file = CachedDataSource() cached_file.setup(self.filename, prefix, 4, 28) self.assertEqual(25, len(cached_file)) @@ -345,6 +347,70 @@ class TestCachedFileLoader(TestCachedDataBase): #---------------------------------------------------------- +class TestDatabaseOutputDataSource(unittest.TestCase): + + def setUp(self): + self.cache_root = tempfile.mkdtemp(prefix=__name__) + + + def tearDown(self): + shutil.rmtree(self.cache_root) + + + def check_valid_indices(self, data_source): + for i in range(0, len(data_source)): + (data, start_index, end_index) = data_source[i] + self.assertTrue(data is not None) + self.assertEqual(i + data_source.first_data_index(), start_index) + self.assertEqual(i + data_source.first_data_index(), end_index) + + + def check_valid_data_indices(self, data_source): + for i in range(0, len(data_source)): + (data, start_index, end_index) = data_source.getAtDataIndex(i + data_source.first_data_index()) + self.assertTrue(data is not None) + self.assertEqual(i + data_source.first_data_index(), start_index) + self.assertEqual(i + data_source.first_data_index(), end_index) + + + def check_invalid_indices(self, data_source): + # Invalid indices + (data, start_index, end_index) = data_source[-1] + self.assertTrue(data is None) + + (data, start_index, end_index) = data_source[len(data_source)] + self.assertTrue(data is None) + + # Invalid data indices + (data, start_index, end_index) = data_source.getAtDataIndex(data_source.first_data_index() - 1) + self.assertTrue(data is None) + + (data, start_index, end_index) = data_source.getAtDataIndex(data_source.last_data_index() + 1) + self.assertTrue(data is None) + + + def test(self): + db = Database(prefix, 'integers_db/1') + self.assertTrue(db.valid) + + view = db.view('double', 'double') + view.index(os.path.join(self.cache_root, 'data.db')) + view.setup(os.path.join(self.cache_root, 'data.db'), pack=False) + + 
self.assertTrue(view.data_sources is not None) + self.assertEqual(len(view.data_sources), 3) + + for output_name, data_source in view.data_sources.items(): + self.assertEqual(9, len(data_source)) + + self.check_valid_indices(data_source) + self.check_valid_data_indices(data_source) + self.check_invalid_indices(data_source) + + +#---------------------------------------------------------- + + class TestDataSink(TestCachedDataBase): def test_creation(self): @@ -358,12 +424,12 @@ class TestDataSink(TestCachedDataBase): #---------------------------------------------------------- -class TestDataSource(TestCachedDataBase): +class TestLegacyDataSource(TestCachedDataBase): def test_creation(self): self.writeData('user/single_integer/1') - data_source = CachedDataSource() + data_source = CachedLegacyDataSource() self.assertTrue(data_source.setup(self.filename, prefix)) self.assertTrue(data_source.dataformat.valid) @@ -375,7 +441,7 @@ class TestDataSource(TestCachedDataBase): def perform_deserialization(self, dataformat_name, start_index=0, end_index=10): reference = self.writeData(dataformat_name) # Always generate 10 data units - data_source = CachedDataSource() + data_source = CachedLegacyDataSource() self.assertTrue(data_source.setup(self.filename, prefix, force_start_index=start_index, force_end_index=end_index)) diff --git a/beat/backend/python/test/test_data_loaders.py b/beat/backend/python/test/test_data_loaders.py index 4fc011f07c476f27fa0a4b6473c4deb0751d4129..37dbfe2ea995351e4d7b691f6c9024ec93a7f54b 100644 --- a/beat/backend/python/test/test_data_loaders.py +++ b/beat/backend/python/test/test_data_loaders.py @@ -36,7 +36,7 @@ from ..data_loaders import DataLoader from ..data_loaders import DataLoaderList from ..dataformat import DataFormat from ..data import CachedDataSink -from ..data import CachedFileLoader +from ..data import CachedDataSource from . 
import prefix @@ -115,7 +115,7 @@ class DataLoaderTest(DataLoaderBaseTest): data_loader = DataLoader('channel1') - cached_file = CachedFileLoader() + cached_file = CachedDataSource() cached_file.setup(self.filenames['input1'], prefix) data_loader.add('input1', cached_file) @@ -238,11 +238,11 @@ class DataLoaderTest(DataLoaderBaseTest): data_loader = DataLoader('channel1') - cached_file = CachedFileLoader() + cached_file = CachedDataSource() cached_file.setup(self.filenames['input1'], prefix) data_loader.add('input1', cached_file) - cached_file = CachedFileLoader() + cached_file = CachedDataSource() cached_file.setup(self.filenames['input2'], prefix) data_loader.add('input2', cached_file) @@ -478,11 +478,11 @@ class DataLoaderTest(DataLoaderBaseTest): data_loader = DataLoader('channel1') - cached_file = CachedFileLoader() + cached_file = CachedDataSource() cached_file.setup(self.filenames['input1'], prefix) data_loader.add('input1', cached_file) - cached_file = CachedFileLoader() + cached_file = CachedDataSource() cached_file.setup(self.filenames['input2'], prefix) data_loader.add('input2', cached_file) @@ -687,7 +687,7 @@ class DataLoaderListTest(DataLoaderBaseTest): data_loader = DataLoader('channel1') - cached_file = CachedFileLoader() + cached_file = CachedDataSource() cached_file.setup(self.filenames['input1'], prefix) data_loader.add('input1', cached_file) @@ -709,11 +709,11 @@ class DataLoaderListTest(DataLoaderBaseTest): data_loader = DataLoader('channel1') - cached_file = CachedFileLoader() + cached_file = CachedDataSource() cached_file.setup(self.filenames['input1'], prefix) data_loader.add('input1', cached_file) - cached_file = CachedFileLoader() + cached_file = CachedDataSource() cached_file.setup(self.filenames['input2'], prefix) data_loader.add('input2', cached_file) @@ -737,17 +737,17 @@ class DataLoaderListTest(DataLoaderBaseTest): data_loader1 = DataLoader('channel1') - cached_file = CachedFileLoader() + cached_file = CachedDataSource() 
cached_file.setup(self.filenames['input1'], prefix) data_loader1.add('input1', cached_file) - cached_file = CachedFileLoader() + cached_file = CachedDataSource() cached_file.setup(self.filenames['input2'], prefix) data_loader1.add('input2', cached_file) data_loader2 = DataLoader('channel2') - cached_file = CachedFileLoader() + cached_file = CachedDataSource() cached_file.setup(self.filenames['input3'], prefix) data_loader2.add('input3', cached_file) diff --git a/beat/backend/python/test/test_database_view.py b/beat/backend/python/test/test_database_view.py index 11e2a37d1d125b4479cab0458c0c0ec317aeb622..8f0bcfa14e158073ee5f8201d843d60f3e427d18 100644 --- a/beat/backend/python/test/test_database_view.py +++ b/beat/backend/python/test/test_database_view.py @@ -26,7 +26,10 @@ ############################################################################### -import nose.tools +import unittest +import tempfile +import shutil +import os from ..database import Database @@ -43,167 +46,84 @@ class MyExc(Exception): #---------------------------------------------------------- -@nose.tools.raises(MyExc) -def test_done_crashing_view(): +class TestDatabaseViewRunner(unittest.TestCase): - db = Database(prefix, 'crash/1') - assert db.valid, '\n * %s' % '\n * '.join(db.errors) - view = db.view('protocol', 'done_crashes', MyExc) - view.done() + def setUp(self): + self.cache_root = tempfile.mkdtemp(prefix=__name__) -#---------------------------------------------------------- - - -@nose.tools.raises(MyExc) -def test_not_setup(): - - db = Database(prefix, 'valid/1') - assert db.valid, '\n * %s' % '\n * '.join(db.errors) - view = db.view('valid', 'valid', MyExc) - view.done() - - -#---------------------------------------------------------- - - -def test_not_done(): - - db = Database(prefix, 'valid/1') - assert db.valid, '\n * %s' % '\n * '.join(db.errors) - view = db.view('valid', 'valid', MyExc) - assert view.setup() - assert view.done(-1) is False - - 
-#---------------------------------------------------------- - - -def test_done(): - - db = Database(prefix, 'valid/1') - assert db.valid, '\n * %s' % '\n * '.join(db.errors) - view = db.view('valid', 'valid', MyExc) - assert view.setup() - assert view.done(-1) is False - - # manually setting property on obj implementing the view for testing purposes - view.obj.must_return_done = True - - assert view.done(-1) - - -#---------------------------------------------------------- - - -@nose.tools.raises(SyntaxError) -def test_load_syntax_error_view(): - - db = Database(prefix, 'syntax_error/1') - assert db.valid, '\n * %s' % '\n * '.join(db.errors) - view = db.view('protocol', 'set') - - -#---------------------------------------------------------- - - -@nose.tools.raises(AttributeError) -def test_load_unknown_view(): - - db = Database(prefix, 'crash/1') - assert db.valid, '\n * %s' % '\n * '.join(db.errors) - view = db.view('protocol', 'does_not_exist') - - -#---------------------------------------------------------- - - -def test_load_valid_view(): - - db = Database(prefix, 'integers_db/1') - assert db.valid - view = db.view('double', 'double') - - -#---------------------------------------------------------- + def tearDown(self): + shutil.rmtree(self.cache_root) -@nose.tools.raises(MyExc) -def test_next_crashing_view(): + def test_syntax_error(self): + db = Database(prefix, 'syntax_error/1') + self.assertTrue(db.valid) - db = Database(prefix, 'crash/1') - assert db.valid, '\n * %s' % '\n * '.join(db.errors) - view = db.view('protocol', 'next_crashes', MyExc) - view.next() + with self.assertRaises(SyntaxError): + view = db.view('protocol', 'set') -#---------------------------------------------------------- + def test_unknown_view(self): + db = Database(prefix, 'integers_db/1') + self.assertTrue(db.valid) + with self.assertRaises(KeyError): + view = db.view('protocol', 'does_not_exist') -@nose.tools.raises(MyExc) -def test_not_setup(): - db = Database(prefix, 'valid/1') - 
assert db.valid, '\n * %s' % '\n * '.join(db.errors) - view = db.view('valid', 'valid', MyExc) - view.next() + def test_valid_view(self): + db = Database(prefix, 'integers_db/1') + self.assertTrue(db.valid) + view = db.view('double', 'double') + self.assertTrue(view is not None) -#---------------------------------------------------------- + def test_indexing_crash(self): + db = Database(prefix, 'crash/1') + self.assertTrue(db.valid) -def test_failure(): + view = db.view('protocol', 'index_crashes', MyExc) - db = Database(prefix, 'valid/1') - assert db.valid, '\n * %s' % '\n * '.join(db.errors) - view = db.view('valid', 'valid', MyExc) - assert view.setup() - view.obj.must_return_error = True - assert view.next() is False + with self.assertRaises(MyExc): + view.index(os.path.join(self.cache_root, 'data.db')) -#---------------------------------------------------------- + def test_get_crash(self): + db = Database(prefix, 'crash/1') + self.assertTrue(db.valid) + view = db.view('protocol', 'get_crashes', MyExc) + view.index(os.path.join(self.cache_root, 'data.db')) + view.setup(os.path.join(self.cache_root, 'data.db')) -def test_success(): + with self.assertRaises(MyExc): + view.get('a', 0) - db = Database(prefix, 'valid/1') - assert db.valid, '\n * %s' % '\n * '.join(db.errors) - view = db.view('valid', 'valid', MyExc) - assert view.setup() - assert view.next() + def test_not_setup(self): + db = Database(prefix, 'crash/1') + self.assertTrue(db.valid) -#---------------------------------------------------------- + view = db.view('protocol', 'get_crashes', MyExc) + with self.assertRaises(MyExc): + view.get('a', 0) -@nose.tools.raises(MyExc) -def test_setup_crashing_view(): - db = Database(prefix, 'crash/1') - view = db.view('protocol', 'setup_crashes', MyExc) - view.setup() - - -#---------------------------------------------------------- - - -@nose.tools.raises(MyExc) -def test_setup_failing_view(): - - db = Database(prefix, 'crash/1') - view = db.view('protocol', 
'setup_fails', MyExc) - assert view.ready is False - assert view.setup() - - -#---------------------------------------------------------- + def test_success(self): + db = Database(prefix, 'integers_db/1') + self.assertTrue(db.valid) + view = db.view('double', 'double', MyExc) + view.index(os.path.join(self.cache_root, 'data.db')) + view.setup(os.path.join(self.cache_root, 'data.db')) -def test_setup_successful_view(): + self.assertTrue(view.data_sources is not None) + self.assertEqual(len(view.data_sources), 3) - db = Database(prefix, 'valid/1') - assert db.valid, '\n * %s' % '\n * '.join(db.errors) - view = db.view('valid', 'valid', MyExc) - assert view.setup() - assert view.ready + for i in range(0, 9): + self.assertEqual(view.get('a', i)['value'], i + 1) + self.assertEqual(view.get('b', i)['value'], (i + 1) * 10) + self.assertEqual(view.get('sum', i)['value'], (i + 1) * 10 + i + 1) diff --git a/beat/backend/python/test/test_dbexecution.py b/beat/backend/python/test/test_dbexecution.py deleted file mode 100644 index f1aca8e3f53fb63a8cd297e54e89de864592115f..0000000000000000000000000000000000000000 --- a/beat/backend/python/test/test_dbexecution.py +++ /dev/null @@ -1,162 +0,0 @@ -#!/usr/bin/env python -# vim: set fileencoding=utf-8 : - -############################################################################### -# # -# Copyright (c) 2017 Idiap Research Institute, http://www.idiap.ch/ # -# Contact: beat.support@idiap.ch # -# # -# This file is part of the beat.backend.python module of the BEAT platform. # -# # -# Commercial License Usage # -# Licensees holding valid commercial BEAT licenses may use this file in # -# accordance with the terms contained in a written agreement between you # -# and Idiap. 
For further information contact tto@idiap.ch # -# # -# Alternatively, this file may be used under the terms of the GNU Affero # -# Public License version 3 as published by the Free Software and appearing # -# in the file LICENSE.AGPL included in the packaging of this file. # -# The BEAT platform is distributed in the hope that it will be useful, but # -# WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY # -# or FITNESS FOR A PARTICULAR PURPOSE. # -# # -# You should have received a copy of the GNU Affero Public License along # -# with the BEAT platform. If not, see http://www.gnu.org/licenses/. # -# # -############################################################################### - - -# Tests for experiment execution - -import os -import logging -logger = logging.getLogger(__name__) - -import unittest -import zmq - -from ..dbexecution import DBExecutor -from ..inputs import RemoteInput -from ..inputs import InputGroup -from ..database import Database - -from . import prefix - - -#---------------------------------------------------------- - - -CONFIGURATION = { - 'queue': 'queue', - 'inputs': { - 'a': { - 'set': 'double', - 'protocol': 'double', - 'database': 'integers_db/1', - 'output': 'a', - 'path': 'ec/89/e5/6e161d2cb012ef6ac8acf59bf453a6328766f90dc9baba9eb14ea23c55', - 'endpoint': 'a', - 'hash': 'ec89e56e161d2cb012ef6ac8acf59bf453a6328766f90dc9baba9eb14ea23c55', - 'channel': 'integers' - }, - 'b': { - 'set': 'double', - 'protocol': 'double', - 'database': 'integers_db/1', - 'output': 'b', - 'path': '6f/b6/66/68e68476cb24be80fc3cb99f6cc8daa822cd86fb8108ce7476bc261fb8', - 'endpoint': 'b', - 'hash': '6fb66668e68476cb24be80fc3cb99f6cc8daa822cd86fb8108ce7476bc261fb8', - 'channel': 'integers' - } - }, - 'algorithm': 'user/sum/1', - 'parameters': {}, - 'environment': { - 'name': 'Python 2.7', - 'version': '1.2.0' - }, - 'outputs': { - 'sum': { - 'path': '20/61/b6/2df3c3bedd5366f4a625c5d87ffbf5a26007c46c456e9abf21b46c6681', - 'endpoint': 'sum', 
- 'hash': '2061b62df3c3bedd5366f4a625c5d87ffbf5a26007c46c456e9abf21b46c6681', - 'channel': 'integers' - } - }, - 'nb_slots': 1, - 'channel': 'integers' -} - - -#---------------------------------------------------------- - - -class HostSide(object): - - def __init__(self, zmq_context): - - # 0MQ server - self.socket = zmq_context.socket(zmq.PAIR) - self.address = 'tcp://127.0.0.1' - port = self.socket.bind_to_random_port(self.address, min_port=50000) - self.address += ':%d' % port - - database = Database(prefix, 'integers_db/1') - - # Creation of the inputs - input_a_conf = CONFIGURATION['inputs']['a'] - dataformat_name_a = database.set(input_a_conf['protocol'], input_a_conf['set'])['outputs']['a'] - self.input_a = RemoteInput('a', database.dataformats[dataformat_name_a], self.socket) - - input_b_conf = CONFIGURATION['inputs']['b'] - dataformat_name_b = database.set(input_b_conf['protocol'], input_b_conf['set'])['outputs']['b'] - self.input_b = RemoteInput('b', database.dataformats[dataformat_name_b], self.socket) - - self.group = InputGroup('integers', restricted_access=False) - self.group.add(self.input_a) - self.group.add(self.input_b) - - -#---------------------------------------------------------- - - -class ContainerSide(object): - - def __init__(self, address): - - dataformat_cache = {} - database_cache = {} - - self.dbexecutor = DBExecutor(prefix, CONFIGURATION, - dataformat_cache, database_cache) - - assert self.dbexecutor.valid, '\n * %s' % '\n * '.join(self.dbexecutor.errors) - - with self.dbexecutor: - self.dbexecutor.process(address) - - - def wait(self): - self.dbexecutor.wait() - - -#---------------------------------------------------------- - - -class TestExecution(unittest.TestCase): - - def test_success(self): - - context = zmq.Context() - - host = HostSide(context) - container = ContainerSide(host.address) - - while host.group.hasMoreData(): - host.group.next() - - host.socket.send('don') - - container.wait() - diff --git 
a/beat/backend/python/test/test_dbexecutor.py b/beat/backend/python/test/test_dbexecutor.py new file mode 100644 index 0000000000000000000000000000000000000000..027fd34cb772621b9272ab24f56e9e3744d8f696 --- /dev/null +++ b/beat/backend/python/test/test_dbexecutor.py @@ -0,0 +1,157 @@ +#!/usr/bin/env python +# vim: set fileencoding=utf-8 : + +############################################################################### +# # +# Copyright (c) 2017 Idiap Research Institute, http://www.idiap.ch/ # +# Contact: beat.support@idiap.ch # +# # +# This file is part of the beat.backend.python module of the BEAT platform. # +# # +# Commercial License Usage # +# Licensees holding valid commercial BEAT licenses may use this file in # +# accordance with the terms contained in a written agreement between you # +# and Idiap. For further information contact tto@idiap.ch # +# # +# Alternatively, this file may be used under the terms of the GNU Affero # +# Public License version 3 as published by the Free Software and appearing # +# in the file LICENSE.AGPL included in the packaging of this file. # +# The BEAT platform is distributed in the hope that it will be useful, but # +# WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY # +# or FITNESS FOR A PARTICULAR PURPOSE. # +# # +# You should have received a copy of the GNU Affero Public License along # +# with the BEAT platform. If not, see http://www.gnu.org/licenses/. # +# # +############################################################################### + + +# Tests for experiment execution + +import os +import logging +logger = logging.getLogger(__name__) + +import unittest +import zmq +import tempfile +import shutil + +from ..dbexecution import DBExecutor +from ..database import Database +from ..data_loaders import DataLoader +from ..data import RemoteDataSource +from ..hash import hashDataset +from ..hash import toPath + +from . 
import prefix + + +#---------------------------------------------------------- + + +DB_VIEW_HASH = hashDataset('integers_db/1', 'double', 'double') +DB_INDEX_PATH = toPath(DB_VIEW_HASH, suffix='.db') + +CONFIGURATION = { + 'queue': 'queue', + 'algorithm': 'user/sum/1', + 'nb_slots': 1, + 'channel': 'integers', + 'parameters': { + }, + 'environment': { + 'name': 'Python 2.7', + 'version': '1.2.0' + }, + 'inputs': { + 'a': { + 'database': 'integers_db/1', + 'protocol': 'double', + 'set': 'double', + 'output': 'a', + 'endpoint': 'a', + 'channel': 'integers', + 'path': DB_INDEX_PATH, + 'hash': DB_VIEW_HASH, + }, + 'b': { + 'database': 'integers_db/1', + 'protocol': 'double', + 'set': 'double', + 'output': 'b', + 'endpoint': 'b', + 'channel': 'integers', + 'path': DB_INDEX_PATH, + 'hash': DB_VIEW_HASH, + } + }, + 'outputs': { + 'sum': { + 'endpoint': 'sum', + 'channel': 'integers', + 'path': '20/61/b6/2df3c3bedd5366f4a625c5d87ffbf5a26007c46c456e9abf21b46c6681', + 'hash': '2061b62df3c3bedd5366f4a625c5d87ffbf5a26007c46c456e9abf21b46c6681', + } + } +} + + +#---------------------------------------------------------- + + +class TestExecution(unittest.TestCase): + + def setUp(self): + self.cache_root = tempfile.mkdtemp(prefix=__name__) + + database = Database(prefix, 'integers_db/1') + view = database.view('double', 'double') + + view.index(os.path.join(self.cache_root, DB_INDEX_PATH)) + + self.db_executor = None + self.client_context = None + self.client_socket = None + + + def tearDown(self): + if self.client_socket is not None: + self.client_socket.send('don') + + if self.db_executor is not None: + self.db_executor.wait() + + if self.client_socket is not None: + self.client_socket.setsockopt(zmq.LINGER, 0) + self.client_socket.close() + self.client_context.destroy() + + shutil.rmtree(self.cache_root) + + + def test_success(self): + self.db_executor = DBExecutor('127.0.0.1', prefix, self.cache_root, CONFIGURATION) + + self.assertTrue(self.db_executor.valid) + + 
self.db_executor.process() + + self.client_context = zmq.Context() + self.client_socket = self.client_context.socket(zmq.PAIR) + self.client_socket.connect(self.db_executor.address) + + data_loader = DataLoader(CONFIGURATION['channel']) + + database = Database(prefix, 'integers_db/1') + + for input_name, input_conf in CONFIGURATION['inputs'].items(): + dataformat_name = database.set(input_conf['protocol'], input_conf['set'])['outputs'][input_conf['output']] + + data_source = RemoteDataSource() + data_source.setup(self.client_socket, input_name, dataformat_name, prefix) + data_loader.add(input_name, data_source) + + + self.assertEqual(data_loader.count('a'), 9) + self.assertEqual(data_loader.count('b'), 9) diff --git a/beat/backend/python/test/test_executor.py b/beat/backend/python/test/test_executor.py index a155c4a022de24a648a1b2de1ce9aa0ed96d3a15..0584deea8558b73cef27cf79a47b950d095fb2e4 100644 --- a/beat/backend/python/test/test_executor.py +++ b/beat/backend/python/test/test_executor.py @@ -41,7 +41,7 @@ from ..inputs import InputList from ..algorithm import Algorithm from ..dataformat import DataFormat from ..data import CachedDataSink -from ..data import CachedFileLoader +from ..data import CachedDataSource from ..helpers import convert_experiment_configuration_to_container from ..helpers import create_inputs_from_configuration from ..helpers import create_outputs_from_configuration @@ -174,7 +174,7 @@ class TestExecutor(unittest.TestCase): for output in output_list: output.close() - cached_file = CachedFileLoader() + cached_file = CachedDataSource() self.assertTrue(cached_file.setup(os.path.join(self.cache_root, CONFIGURATION['outputs']['out']['path'] + '.data'), prefix)) for i in range(len(cached_file)): diff --git a/beat/backend/python/test/test_in_memory_data.py b/beat/backend/python/test/test_in_memory_data.py index e5caa1b3717a4ca85acd55c14430a96d915c4504..31a67d2651c233e95e7ae8bd789fce7fc084c12a 100644 --- 
a/beat/backend/python/test/test_in_memory_data.py +++ b/beat/backend/python/test/test_in_memory_data.py @@ -31,7 +31,7 @@ import numpy import nose.tools from ..data import MemoryDataSink -from ..data import MemoryDataSource +from ..data import MemoryLegacyDataSource from ..dataformat import DataFormat from . import prefix @@ -50,7 +50,7 @@ def test_one_sink_to_one_source(): def _done_callback(last_data_index): return (last_data_index == 4) - data_source = MemoryDataSource(_done_callback) + data_source = MemoryLegacyDataSource(_done_callback) nose.tools.eq_(len(data_source.data), 0) assert data_source.hasMoreData() @@ -108,7 +108,7 @@ def test_source_callback(): data = dataformat.type() data_sink.write(data, 0, 0) - data_source = MemoryDataSource(_done_callback, next_callback=_next_callback) + data_source = MemoryLegacyDataSource(_done_callback, next_callback=_next_callback) nose.tools.eq_(len(data_source.data), 0) assert data_source.hasMoreData() diff --git a/beat/backend/python/test/test_inputs.py b/beat/backend/python/test/test_inputs.py index 04dc89f5ecc4c0233f99feac1437a380c2ba8702..64bb91c25611692c7ab5d1aa6a1114ab8299f3f9 100644 --- a/beat/backend/python/test/test_inputs.py +++ b/beat/backend/python/test/test_inputs.py @@ -28,7 +28,7 @@ import unittest -from .mocks import MockDataSource +from .mocks import MockLegacyDataSource from ..outputs import SynchronizationListener from ..inputs import Input @@ -46,7 +46,7 @@ from . 
import prefix class InputTest(unittest.TestCase): def test_creation(self): - data_source = MockDataSource([], []) + data_source = MockLegacyDataSource([], []) input = Input('test', 'mock', data_source) @@ -62,7 +62,7 @@ class InputTest(unittest.TestCase): def test_no_more_data(self): - data_source = MockDataSource([], []) + data_source = MockLegacyDataSource([], []) input = Input('test', 'mock', data_source) @@ -72,7 +72,7 @@ class InputTest(unittest.TestCase): def test_has_more_data(self): dataformat = DataFormat(prefix, 'user/single_integer/1') - data_source = MockDataSource([ dataformat.type(value=10) ], [ (0, 0) ]) + data_source = MockLegacyDataSource([ dataformat.type(value=10) ], [ (0, 0) ]) input = Input('test', 'mock', data_source) @@ -84,7 +84,7 @@ class InputTest(unittest.TestCase): dataformat = DataFormat(prefix, 'user/single_integer/1') - data_source = MockDataSource([ dataformat.type(value=10), dataformat.type(value=20) ], + data_source = MockLegacyDataSource([ dataformat.type(value=10), dataformat.type(value=20) ], [ (0, 0), (1, 1) ]) input = Input('test', 'mock', data_source) @@ -114,7 +114,7 @@ class InputTest(unittest.TestCase): dataformat = DataFormat(prefix, 'user/single_integer/1') - data_source = MockDataSource([ dataformat.type(value=10) ], [ (0, 0) ]) + data_source = MockLegacyDataSource([ dataformat.type(value=10) ], [ (0, 0) ]) input = Input('test', 'mock', data_source) group.add(input) @@ -142,7 +142,7 @@ class InputTest(unittest.TestCase): dataformat = DataFormat(prefix, 'user/single_integer/1') - data_source = MockDataSource([ dataformat.type(value=10) ], [ (0, 0) ]) + data_source = MockLegacyDataSource([ dataformat.type(value=10) ], [ (0, 0) ]) input = Input('test', 'mock', data_source) group.add(input) @@ -174,7 +174,7 @@ class RestrictedInputGroupTest(unittest.TestCase): def test_add_one_input(self): - data_source = MockDataSource([], []) + data_source = MockLegacyDataSource([], []) input = Input('input1', 'mock', data_source) @@ 
-188,10 +188,10 @@ class RestrictedInputGroupTest(unittest.TestCase): def test_add_two_inputs(self): - data_source1 = MockDataSource([], []) + data_source1 = MockLegacyDataSource([], []) input1 = Input('input1', 'mock', data_source1) - data_source2 = MockDataSource([], []) + data_source2 = MockLegacyDataSource([], []) input2 = Input('input2', 'mock', data_source2) group = InputGroup('channel1') @@ -207,10 +207,10 @@ class RestrictedInputGroupTest(unittest.TestCase): def test_no_more_data(self): - data_source1 = MockDataSource([], []) + data_source1 = MockLegacyDataSource([], []) input1 = Input('input1', 'mock', data_source1) - data_source2 = MockDataSource([], []) + data_source2 = MockLegacyDataSource([], []) input2 = Input('input2', 'mock', data_source2) group = InputGroup('channel1') @@ -224,10 +224,10 @@ class RestrictedInputGroupTest(unittest.TestCase): def test_has_more_data(self): dataformat = DataFormat(prefix, 'user/single_integer/1') - data_source1 = MockDataSource([ dataformat.type(value=10) ], [ (0, 0) ]) + data_source1 = MockLegacyDataSource([ dataformat.type(value=10) ], [ (0, 0) ]) input1 = Input('input1', 'mock', data_source1) - data_source2 = MockDataSource([ dataformat.type(value=10) ], [ (0, 0) ]) + data_source2 = MockLegacyDataSource([ dataformat.type(value=10) ], [ (0, 0) ]) input2 = Input('input2', 'mock', data_source2) group = InputGroup('channel1') @@ -282,7 +282,7 @@ class InputGroupTest(unittest.TestCase): def test_add_one_input(self): - data_source = MockDataSource([], []) + data_source = MockLegacyDataSource([], []) input = Input('input1', 'mock', data_source) @@ -296,10 +296,10 @@ class InputGroupTest(unittest.TestCase): def test_add_two_inputs(self): - data_source1 = MockDataSource([], []) + data_source1 = MockLegacyDataSource([], []) input1 = Input('input1', 'mock', data_source1) - data_source2 = MockDataSource([], []) + data_source2 = MockLegacyDataSource([], []) input2 = Input('input2', 'mock', data_source2) group = 
InputGroup('channel1', restricted_access=False) @@ -315,10 +315,10 @@ class InputGroupTest(unittest.TestCase): def test_no_more_data(self): - data_source1 = MockDataSource([], []) + data_source1 = MockLegacyDataSource([], []) input1 = Input('input1', 'mock', data_source1) - data_source2 = MockDataSource([], []) + data_source2 = MockLegacyDataSource([], []) input2 = Input('input2', 'mock', data_source2) group = InputGroup('channel1', restricted_access=False) @@ -332,10 +332,10 @@ class InputGroupTest(unittest.TestCase): def test_has_more_data(self): dataformat = DataFormat(prefix, 'user/single_integer/1') - data_source1 = MockDataSource([ dataformat.type(value=10) ], [ (0, 0) ]) + data_source1 = MockLegacyDataSource([ dataformat.type(value=10) ], [ (0, 0) ]) input1 = Input('input1', 'mock', data_source1) - data_source2 = MockDataSource([ dataformat.type(value=10) ], [ (0, 0) ]) + data_source2 = MockLegacyDataSource([ dataformat.type(value=10) ], [ (0, 0) ]) input2 = Input('input2', 'mock', data_source2) group = InputGroup('channel1', restricted_access=False) @@ -356,7 +356,7 @@ class InputGroupTest(unittest.TestCase): dataformat = DataFormat(prefix, 'user/single_integer/1') - data_source = MockDataSource([ dataformat.type(value=x) for x in range(0, len(indices)) ], + data_source = MockLegacyDataSource([ dataformat.type(value=x) for x in range(0, len(indices)) ], indices) input = Input('input', 'mock', data_source) @@ -402,15 +402,15 @@ class InputGroupTest(unittest.TestCase): dataformat = DataFormat(prefix, 'user/single_integer/1') - data_source1 = MockDataSource([ dataformat.type(value=x) for x in range(0, len(indices)) ], + data_source1 = MockLegacyDataSource([ dataformat.type(value=x) for x in range(0, len(indices)) ], indices) input1 = Input('input1', 'mock', data_source1) - data_source2 = MockDataSource([ dataformat.type(value=x + 10) for x in range(0, len(indices)) ], + data_source2 = MockLegacyDataSource([ dataformat.type(value=x + 10) for x in range(0, 
len(indices)) ], indices) input2 = Input('input2', 'mock', data_source2) - data_source3 = MockDataSource([ dataformat.type(value=x + 20) for x in range(0, len(indices)) ], + data_source3 = MockLegacyDataSource([ dataformat.type(value=x + 20) for x in range(0, len(indices)) ], indices) input3 = Input('input3', 'mock', data_source3) @@ -498,15 +498,15 @@ class InputGroupTest(unittest.TestCase): dataformat = DataFormat(prefix, 'user/single_integer/1') - data_source1 = MockDataSource([ dataformat.type(value=x) for x in range(0, len(indices1)) ], + data_source1 = MockLegacyDataSource([ dataformat.type(value=x) for x in range(0, len(indices1)) ], indices1) input1 = Input('input1', 'mock', data_source1) - data_source2 = MockDataSource([ dataformat.type(value=x + 10) for x in range(0, len(indices2)) ], + data_source2 = MockLegacyDataSource([ dataformat.type(value=x + 10) for x in range(0, len(indices2)) ], indices2) input2 = Input('input2', 'mock', data_source2) - data_source3 = MockDataSource([ dataformat.type(value=x + 20) for x in range(0, len(indices3)) ], + data_source3 = MockLegacyDataSource([ dataformat.type(value=x + 20) for x in range(0, len(indices3)) ], indices3) input3 = Input('input3', 'mock', data_source3) @@ -596,11 +596,11 @@ class InputGroupTest(unittest.TestCase): dataformat = DataFormat(prefix, 'user/single_integer/1') - data_source1 = MockDataSource([ dataformat.type(value=x) for x in range(0, len(indices1)) ], + data_source1 = MockLegacyDataSource([ dataformat.type(value=x) for x in range(0, len(indices1)) ], indices1) input1 = Input('input1', 'mock', data_source1) - data_source2 = MockDataSource([ dataformat.type(value=x) for x in range(0, len(indices2)) ], + data_source2 = MockLegacyDataSource([ dataformat.type(value=x) for x in range(0, len(indices2)) ], indices2) input2 = Input('input2', 'mock', data_source2) @@ -668,11 +668,11 @@ class InputGroupTest(unittest.TestCase): dataformat = DataFormat(prefix, 'user/single_integer/1') - data_source1 = 
MockDataSource([ dataformat.type(value=x) for x in range(0, len(indices1)) ], + data_source1 = MockLegacyDataSource([ dataformat.type(value=x) for x in range(0, len(indices1)) ], indices1) input1 = Input('input1', 'mock', data_source1) - data_source2 = MockDataSource([ dataformat.type(value=x) for x in range(0, len(indices2)) ], + data_source2 = MockLegacyDataSource([ dataformat.type(value=x) for x in range(0, len(indices2)) ], indices2) input2 = Input('input2', 'mock', data_source2) @@ -750,7 +750,7 @@ class InputListTest(unittest.TestCase): def test_list_one_group_one_input(self): inputs = InputList() - data_source = MockDataSource([], []) + data_source = MockLegacyDataSource([], []) input = Input('input1', 'mock', data_source) group = InputGroup('channel1') @@ -771,10 +771,10 @@ class InputListTest(unittest.TestCase): def test_list_one_group_two_inputs(self): inputs = InputList() - data_source1 = MockDataSource([], []) + data_source1 = MockLegacyDataSource([], []) input1 = Input('input1', 'mock', data_source1) - data_source2 = MockDataSource([], []) + data_source2 = MockLegacyDataSource([], []) input2 = Input('input2', 'mock', data_source2) group = InputGroup('channel1') @@ -803,11 +803,11 @@ class InputListTest(unittest.TestCase): group1 = InputGroup('channel1') - data_source1 = MockDataSource([], []) + data_source1 = MockLegacyDataSource([], []) input1 = Input('input1', 'mock', data_source1) group1.add(input1) - data_source2 = MockDataSource([], []) + data_source2 = MockLegacyDataSource([], []) input2 = Input('input2', 'mock', data_source2) group1.add(input2) @@ -815,7 +815,7 @@ class InputListTest(unittest.TestCase): group2 = InputGroup('channel2', restricted_access=False) - data_source3 = MockDataSource([], []) + data_source3 = MockLegacyDataSource([], []) input3 = Input('input3', 'mock', data_source3) group2.add(input3) diff --git a/beat/backend/python/test/test_message_handler.py b/beat/backend/python/test/test_message_handler.py index 
2cfb3d2b464d87bfde913dda83f7688f2ab7ca91..bf03cbff955a2eb00684fccedee798a70c9cb47f 100644 --- a/beat/backend/python/test/test_message_handler.py +++ b/beat/backend/python/test/test_message_handler.py @@ -31,17 +31,25 @@ logger = logging.getLogger(__name__) import unittest import zmq +import os +import glob +import tempfile +import numpy as np from ..message_handler import MessageHandler from ..dataformat import DataFormat from ..inputs import RemoteInput -from ..inputs import RemoteException from ..inputs import Input from ..inputs import InputGroup from ..inputs import InputList +from ..data import RemoteException +from ..data import CachedDataSource +from ..data import RemoteDataSource +from ..data import CachedDataSink +from ..data_loaders import DataLoader -from .mocks import MockDataSource -from .mocks import MockDataSource_Crash +from .mocks import MockLegacyDataSource +from .mocks import MockLegacyDataSource_Crash from . import prefix @@ -51,6 +59,31 @@ from . import prefix class TestMessageHandlerBase(unittest.TestCase): + def setUp(self): + self.filenames = [] + self.data_loader = None + + + def tearDown(self): + for filename in self.filenames: + basename, ext = os.path.splitext(filename) + filenames = [filename] + filenames += glob.glob(basename + '*') + for filename in filenames: + if os.path.exists(filename): + os.unlink(filename) + + self.message_handler.kill() + self.message_handler.join() + self.message_handler = None + + self.client_socket.setsockopt(zmq.LINGER, 0) + self.client_socket.close() + self.client_context.destroy() + + self.data_loader = None + + def create_remote_inputs(self, dataformat, data_sources): group = InputGroup('channel', restricted_access=False) @@ -65,13 +98,13 @@ class TestMessageHandlerBase(unittest.TestCase): self.client_context = zmq.Context() - client_socket = self.client_context.socket(zmq.PAIR) - client_socket.connect(self.message_handler.address) + self.client_socket = self.client_context.socket(zmq.PAIR) + 
self.client_socket.connect(self.message_handler.address) self.remote_group = InputGroup('channel', restricted_access=False) for name in data_sources.keys(): - remote_input = RemoteInput(name, dataformat, client_socket) + remote_input = RemoteInput(name, dataformat, self.client_socket) self.remote_group.add(remote_input) self.remote_input_list = InputList() @@ -80,10 +113,56 @@ class TestMessageHandlerBase(unittest.TestCase): self.message_handler.start() - def tearDown(self): - self.message_handler.kill() - self.message_handler.join() - self.message_handler = None + def create_data_loader(self, data_sources): + self.message_handler = MessageHandler('127.0.0.1', data_sources=data_sources) + + self.client_context = zmq.Context() + self.client_socket = self.client_context.socket(zmq.PAIR) + self.client_socket.connect(self.message_handler.address) + + self.message_handler.start() + + self.data_loader = DataLoader('channel') + + for input_name in data_sources.keys(): + data_source = RemoteDataSource() + data_source.setup(self.client_socket, input_name, 'user/single_integer/1', prefix) + self.data_loader.add(input_name, data_source) + + + def writeData(self, start_index=0, end_index=10, step=1, base=0): + testfile = tempfile.NamedTemporaryFile(prefix=__name__, suffix='.data') + testfile.close() # preserve only the name + filename = testfile.name + + self.filenames.append(filename) + + dataformat = DataFormat(prefix, 'user/single_integer/1') + self.assertTrue(dataformat.valid) + + data_sink = CachedDataSink() + self.assertTrue(data_sink.setup(filename, dataformat, start_index, end_index)) + + index = start_index + while index + step - 1 <= end_index: + data = dataformat.type() + data.value = np.int32(index + base) + data_sink.write(data, index, index + step - 1) + index += step + + (nb_bytes, duration) = data_sink.statistics() + self.assertTrue(nb_bytes > 0) + self.assertTrue(duration > 0) + + data_sink.close() + del data_sink + + cached_file = CachedDataSource() + 
cached_file.setup(filename, prefix) + + self.assertTrue(len(cached_file.infos) > 0) + + return cached_file #---------------------------------------------------------- @@ -92,12 +171,14 @@ class TestMessageHandlerBase(unittest.TestCase): class TestOneInput(TestMessageHandlerBase): def setUp(self): + super(TestOneInput, self).setUp() + dataformat = DataFormat(prefix, 'user/single_integer/1') self.create_remote_inputs( DataFormat(prefix, 'user/single_integer/1'), dict( - a = MockDataSource([ + a = MockLegacyDataSource([ dataformat.type(value=10), dataformat.type(value=20), ], @@ -173,15 +254,41 @@ class TestOneInput(TestMessageHandlerBase): #---------------------------------------------------------- +class TestOneDataSource(TestMessageHandlerBase): + + def setUp(self): + super(TestOneDataSource, self).setUp() + + data_sources = {} + data_sources['a'] = self.writeData(start_index=0, end_index=9) + + self.create_data_loader(data_sources) + + + def test_iteration(self): + self.assertEqual(self.data_loader.count('a'), 10) + + for i in range(10): + (result, start, end) = self.data_loader[i] + self.assertEqual(start, i) + self.assertEqual(end, i) + self.assertEqual(result['a'].value, i) + + +#---------------------------------------------------------- + + class TestSameFrequencyInputs(TestMessageHandlerBase): def setUp(self): + super(TestSameFrequencyInputs, self).setUp() + dataformat = DataFormat(prefix, 'user/single_integer/1') self.create_remote_inputs( DataFormat(prefix, 'user/single_integer/1'), dict( - a = MockDataSource([ + a = MockLegacyDataSource([ dataformat.type(value=10), dataformat.type(value=20), ], @@ -190,7 +297,7 @@ class TestSameFrequencyInputs(TestMessageHandlerBase): (1, 1), ] ), - b = MockDataSource([ + b = MockLegacyDataSource([ dataformat.type(value=100), dataformat.type(value=200), ], @@ -258,15 +365,44 @@ class TestSameFrequencyInputs(TestMessageHandlerBase): #---------------------------------------------------------- +class 
TestSameFrequencyDataSources(TestMessageHandlerBase): + + def setUp(self): + super(TestSameFrequencyDataSources, self).setUp() + + data_sources = {} + data_sources['a'] = self.writeData(start_index=0, end_index=9) + data_sources['b'] = self.writeData(start_index=0, end_index=9, base=10) + + self.create_data_loader(data_sources) + + + def test_iteration(self): + self.assertEqual(self.data_loader.count('a'), 10) + self.assertEqual(self.data_loader.count('b'), 10) + + for i in range(10): + (result, start, end) = self.data_loader[i] + self.assertEqual(start, i) + self.assertEqual(end, i) + self.assertEqual(result['a'].value, i) + self.assertEqual(result['b'].value, 10 + i) + + +#---------------------------------------------------------- + + class TestDifferentFrequenciesInputs(TestMessageHandlerBase): def setUp(self): + super(TestDifferentFrequenciesInputs, self).setUp() + dataformat = DataFormat(prefix, 'user/single_integer/1') self.create_remote_inputs( DataFormat(prefix, 'user/single_integer/1'), dict( - a = MockDataSource([ + a = MockLegacyDataSource([ dataformat.type(value=10), dataformat.type(value=20), ], @@ -275,7 +411,7 @@ class TestDifferentFrequenciesInputs(TestMessageHandlerBase): (4, 7), ] ), - b = MockDataSource([ + b = MockLegacyDataSource([ dataformat.type(value=100), dataformat.type(value=200), dataformat.type(value=300), @@ -347,12 +483,45 @@ class TestDifferentFrequenciesInputs(TestMessageHandlerBase): #---------------------------------------------------------- +class TestDifferentFrequenciesDataSources(TestMessageHandlerBase): + + def setUp(self): + super(TestDifferentFrequenciesDataSources, self).setUp() + + data_sources = {} + data_sources['a'] = self.writeData(start_index=0, end_index=9) + data_sources['b'] = self.writeData(start_index=0, end_index=9, base=10, step=5) + + self.create_data_loader(data_sources) + + + def test_iteration(self): + self.assertEqual(self.data_loader.count('a'), 10) + self.assertEqual(self.data_loader.count('b'), 2) + + 
for i in range(10): + (result, start, end) = self.data_loader[i] + self.assertEqual(start, i) + self.assertEqual(end, i) + self.assertEqual(result['a'].value, i) + + if i < 5: + self.assertEqual(result['b'].value, 10) + else: + self.assertEqual(result['b'].value, 15) + + +#---------------------------------------------------------- + + class TestMessageHandlerErrorHandling(unittest.TestCase): def setUp(self): + super(TestMessageHandlerErrorHandling, self).setUp() + dataformat = DataFormat(prefix, 'user/single_integer/1') - data_source = MockDataSource_Crash() + data_source = MockLegacyDataSource_Crash() input = Input('in', 'user/single_integer/1', data_source)