diff --git a/beat/core/agent.py b/beat/core/agent.py
index 2c3fafca0ccd1d77002a1971745ee21f20919f41..5f90f8c130612e46fdc0003f0c01568bf9d83c53 100755
--- a/beat/core/agent.py
+++ b/beat/core/agent.py
@@ -44,251 +44,7 @@
 from . import utils
 from . import dock
 from . import baseformat
-
-class MessageHandler(gevent.Greenlet):
-    '''A 0MQ message handler for our communication with other processes
-
-    Support for more messages can be implemented by subclassing this class.
-    This one only support input-related messages.
-    '''
-
-    def __init__(self, input_list, zmq_context, zmq_socket):
-
-        super(MessageHandler, self).__init__()
-
-        # An event unblocking a graceful stop
-        self.stop = gevent.event.Event()
-        self.stop.clear()
-
-        self.must_kill = gevent.event.Event()
-        self.must_kill.clear()
-
-        # Starts our 0MQ server
-        self.context = zmq_context
-        self.socket = zmq_socket
-
-        self.poller = zmq.Poller()
-        self.poller.register(self.socket, zmq.POLLIN)
-
-        self.input_list = input_list
-
-        self.system_error = ''
-        self.user_error = ''
-        self.last_statistics = {}
-        self.process = None
-
-        # implementations
-        self.callbacks = dict(
-            nxt = self.next,
-            hmd = self.has_more_data,
-            idd = self.is_dataunit_done,
-            don = self.done,
-            err = self.error,
-        )
-
-
-    def set_process(self, process):
-        self.process = process
-        self.process.statistics() # initialize internal statistics
-
-
-    def _run(self):
-
-        logger.debug("0MQ server thread started")
-
-        while not self.stop.is_set(): #keep on
-
-            if self.must_kill.is_set():
-                if self.process is not None:
-                    self.process.kill()
-                self.must_kill.clear()
-
-            timeout = 1000 #ms
-            socks = dict(self.poller.poll(timeout)) #yields to the next greenlet
-
-            if self.socket in socks and socks[self.socket] == zmq.POLLIN:
-
-                # incomming
-                more = True
-                parts = []
-                while more:
-                    parts.append(self.socket.recv())
-                    more = self.socket.getsockopt(zmq.RCVMORE)
-                command = parts[0]
-
-                logger.debug("recv: %s", command)
-
-                if command in self.callbacks:
-                    try: #to handle command
-                        self.callbacks[command](*parts[1:])
-                    except:
-                        import traceback
-                        parser = lambda s: s if len(s)<20 else s[:20] + '...'
-                        parsed_parts = ' '.join([parser(k) for k in parts])
-                        message = "A problem occurred while performing command `%s' " \
-                                "killing user process. Exception:\n %s" % \
-                                (parsed_parts, traceback.format_exc())
-                        logger.error(message, exc_info=True)
-                        self.system_error = message
-                        if self.process is not None:
-                            self.process.kill()
-                        self.stop.set()
-                        break
-
-                else:
-                    message = "Command `%s' is not implemented - stopping user process" \
-                            % command
-                    logger.error(message)
-                    self.system_error = message
-                    if self.process is not None:
-                        self.process.kill()
-                    self.stop.set()
-                    break
-
-        self.socket.setsockopt(zmq.LINGER, 0)
-        self.socket.close()
-        logger.debug("0MQ server thread stopped")
-
-
-    def _get_input_candidate(self, channel, name):
-
-        channel_group = self.input_list.group(channel)
-        retval = channel_group[name]
-        if retval is None:
-            raise RuntimeError("Could not find input `%s' at channel `%s'" % \
-                    (name, channel))
-        return retval
-
-
-    def next(self, channel, name=None):
-        """Syntax: nxt channel [name] ..."""
-
-        if name is not None: #single input
-            logger.debug('recv: nxt %s %s', channel, name)
-
-            input_candidate = self._get_input_candidate(channel, name)
-            input_candidate.next()
-            if input_candidate.data is None: #error
-                message = "User algorithm asked for more data for channel " \
-                        "`%s' on input `%s', but it is over (no more data). This " \
-                        "normally indicates a programming error on the user " \
-                        "side." % (channel, name)
-                self.user_error += message + '\n'
-                raise RuntimeError(message)
-            if isinstance(input_candidate.data, baseformat.baseformat):
-                packed = input_candidate.data.pack()
-            else:
-                packed = input_candidate.data
-            logger.debug('send: <bin> (size=%d)', len(packed))
-            self.socket.send(packed)
-
-        else: #whole group data
-            logger.debug('recv: nxt %s', channel)
-
-            channel_group = self.input_list.group(channel)
-
-            # Call next() on the group
-            channel_group.restricted_access = False
-            channel_group.next()
-            channel_group.restricted_access = True
-
-            # Loop over the inputs
-            inputs_to_go = len(channel_group)
-            self.socket.send(str(inputs_to_go), zmq.SNDMORE)
-            for inp in channel_group:
-                logger.debug('send: %s', inp.name)
-                self.socket.send(str(inp.name), zmq.SNDMORE)
-                if inp.data is None:
-                    message = "User algorithm process asked for more data on channel " \
-                            "`%s' (all inputs), but input `%s' has nothing. This " \
-                            "normally indicates a programming error on the user " \
-                            "side." % (channel, inp.name)
-                    self.user_error += message + '\n'
-                    raise RuntimeError(message)
-                elif isinstance(inp.data, baseformat.baseformat):
-                    packed = inp.data.pack()
-                else:
-                    packed = inp.data
-                logger.debug('send: <bin> (size=%d)', len(packed))
-                inputs_to_go -= 1
-                if inputs_to_go > 0:
-                    self.socket.send(packed, zmq.SNDMORE)
-                else:
-                    self.socket.send(packed)
-
-
-    def has_more_data(self, channel, name=None):
-        """Syntax: hmd channel [name]"""
-
-        if name: #single input
-            logger.debug('recv: hmd %s %s', channel, name)
-            input_candidate = self._get_input_candidate(channel, name)
-            what = 'tru' if input_candidate.hasMoreData() else 'fal'
-
-        else: #for all channel names
-            logger.debug('recv: hmd %s', channel)
-            channel_group = self.input_list.group(channel)
-            what = 'tru' if channel_group.hasMoreData() else 'fal'
-
-        logger.debug('send: %s', what)
-        self.socket.send(what)
-
-
-    def is_dataunit_done(self, channel, name):
-        """Syntax: idd channel name"""
-
-        logger.debug('recv: idd %s %s', channel, name)
-        input_candidate = self._get_input_candidate(channel, name)
-        what = 'tru' if input_candidate.isDataUnitDone() else 'fal'
-        logger.debug('send: %s', what)
-        self.socket.send(what)
-
-
-    def _collect_statistics(self):
-
-        logger.debug('collecting user process statistics...')
-        if self.process is not None:
-            self.last_statistics = self.process.statistics()
-
-
-    def _acknowledge(self):
-
-        logger.debug('send: ack')
-        self.socket.send('ack')
-        logger.debug('setting stop condition for 0MQ server thread')
-        self.stop.set()
-
-
-    def done(self, wait_time=None):
-        """Syntax: don"""
-
-        logger.debug('recv: don %s', wait_time)
-
-        if wait_time is not None:
-            self._collect_statistics()
-
-            # collect I/O stats from client
-            wait_time = float(wait_time)
-            self.last_statistics['data'] = dict(network=dict(wait_time=wait_time))
-
-        self._acknowledge()
-
-
-    def error(self, t, msg):
-        """Syntax: err type message"""
-
-        logger.debug('recv: err %s <msg> (size=%d)', t, len(msg))
-        if t == 'usr': self.user_error = msg
-        else: self.system_error = msg
-
-        self._collect_statistics()
-        self.last_statistics['data'] = dict(network=dict(wait_time=0.))
-        self._acknowledge()
-
-
-    def kill(self):
-        self.must_kill.set()
-
+from beat.backend.python.message_handler import MessageHandler
 
 class Server(MessageHandler):
@@ -398,7 +154,9 @@ class Agent(object):
 
         self.virtual_memory_in_megabytes = virtual_memory_in_megabytes
         self.max_cpu_percent = max_cpu_percent
self.tempdir = None + self.db_tempdir = None self.process = None + self.db_process = None self.server = None @@ -410,7 +168,10 @@ class Agent(object): # Creates a temporary directory for the user process self.tempdir = utils.temporary_directory() logger.debug("Created temporary directory `%s'", self.tempdir) + self.db_tempdir = utils.temporary_directory() + logger.debug("Created temporary directory `%s'", self.db_tempdir) self.process = None + self.db_process = None return self @@ -420,11 +181,16 @@ class Agent(object): shutil.rmtree(self.tempdir) self.tempdir = None + if self.db_tempdir is not None and os.path.exists(self.db_tempdir): + shutil.rmtree(self.db_tempdir) + self.db_tempdir = None + self.process = None + self.db_process = None logger.debug("Exiting processing context...") - def run(self, configuration, host, timeout_in_minutes=0, daemon=0): + def run(self, configuration, host, timeout_in_minutes=0, daemon=0, db_address=None): """Runs the algorithm code @@ -452,6 +218,9 @@ class Agent(object): # Recursively copies configuration data to <tempdir>/prefix configuration.dump_runner_configuration(self.tempdir) + if db_address is not None: + configuration.dump_databases_provider_configuration(self.db_tempdir) + # Server for our single client self.server = Server(configuration.input_list, configuration.output_list, host.ip) @@ -475,14 +244,27 @@ class Agent(object): cmd = ['sleep', str(daemon)] logger.debug("Daemon mode: sleeping for %d seconds", daemon) else: + if db_address is not None: + tmp_dir = os.path.join('/tmp', os.path.basename(self.db_tempdir)) + # db_cmd = ['bash', '-c', 'source activate beat_env; databases_provider %s %s' % (db_address, tmp_dir)] + db_cmd = ['bash', '-c', 'databases_provider %s %s' % (db_address, tmp_dir)] + + self.db_process = dock.Popen( + host, + 'docker.idiap.ch/beat/beat.env.db.examples:1.0.0', + # 'docker.idiap.ch/beat/beat.env.db:1.0.0', + command=db_cmd, + tmp_archive=self.db_tempdir, + ) + self.process = dock.Popen( - host, - envkey, - command=cmd, - tmp_archive=self.tempdir, - virtual_memory_in_megabytes=self.virtual_memory_in_megabytes, - max_cpu_percent=self.max_cpu_percent, - ) + host, + envkey, + command=cmd, + tmp_archive=self.tempdir, + virtual_memory_in_megabytes=self.virtual_memory_in_megabytes, + max_cpu_percent=self.max_cpu_percent, + ) # provide a tip on how to stop the test if daemon > 0: @@ -504,6 +286,11 @@ class Agent(object): timeout_in_minutes) self.process.kill() status = self.process.wait() + + if self.db_process is not None: + self.db_process.kill() + self.db_process.wait() + timed_out = True except KeyboardInterrupt: #developer pushed CTRL-C @@ -511,6 +298,10 @@ class Agent(object): self.process.kill() status = self.process.wait() + if self.db_process is not None: + self.db_process.kill() + self.db_process.wait() + finally: self.server.stop.set() @@ -527,6 +318,13 @@ class Agent(object): user_error = self.server.user_error, ) process.rm() + + if self.db_process is not None: + retval['stdout'] += '\n' + self.db_process.stdout + retval['stderr'] += '\n' + self.db_process.stderr + self.db_process.rm() + self.db_process = None + self.server = None return retval diff --git a/beat/core/database.py b/beat/core/database.py index 550ed8cace765cc3e9e58c60c369e4b19238c381..9b52b5f3b0f8d86889269570a25ef1568067b6c3 100644 --- a/beat/core/database.py +++ b/beat/core/database.py @@ -43,162 +43,13 @@ from . import hash from . import utils from . 
import prototypes -class Storage(utils.CodeStorage): - """Resolves paths for databases - - Parameters: - - prefix (str): Establishes the prefix of your installation. - - name (str): The name of the database object in the format - ``<name>/<version>``. - - """ - - def __init__(self, prefix, name): - - if name.count('/') != 1: - raise RuntimeError("invalid database name: `%s'" % name) - - self.name, self.version = name.split('/') - self.fullname = name - - path = os.path.join(prefix, 'databases', name) - super(Storage, self).__init__(path, 'python') #views are coded in Python - - -class View(object): - '''A special loader class for database views, with specialized methods - - Parameters: - - db_name (str): The full name of the database object for this view - - module (module): The preloaded module containing the database views as - returned by :py:func:`beat.core.loader.load_module`. - - prefix (str, path): The prefix path for the current installation - - root_folder (str, path): The path pointing to the root folder of this - database - - exc (class): The class to use as base exception when translating the - exception from the user code. Read the documention of :py:func:`run` - for more details. - - *args: Constructor parameters for the database view. Normally, none. - - **kwargs: Constructor parameters for the database view. Normally, none. - - ''' - - - def __init__(self, module, definition, prefix, root_folder, exc=None, - *args, **kwargs): - - try: - class_ = getattr(module, definition['view']) - except Exception as e: - if exc is not None: - type, value, traceback = sys.exc_info() - six.reraise(exc, exc(value), traceback) - else: - raise #just re-raise the user exception - - self.obj = loader.run(class_, '__new__', exc, *args, **kwargs) - self.ready = False - self.prefix = prefix - self.root_folder = root_folder - self.definition = definition - self.exc = exc or RuntimeError - self.outputs = None - - - def prepare_outputs(self): - '''Prepares the outputs of the dataset''' - - from .outputs import Output, OutputList - from .data import MemoryDataSink - from .dataformat import DataFormat - - # create the stock outputs for this dataset, so data is dumped - # on a in-memory sink - self.outputs = OutputList() - for out_name, out_format in self.definition.get('outputs', {}).items(): - data_sink = MemoryDataSink() - data_sink.dataformat = DataFormat(self.prefix, out_format) - data_sink.setup([]) - self.outputs.add(Output(out_name, data_sink, dataset_output=True)) - - - def setup(self, *args, **kwargs): - '''Sets up the view''' - - kwargs.setdefault('root_folder', self.root_folder) - kwargs.setdefault('parameters', self.definition.get('parameters', {})) - - if 'outputs' not in kwargs: - kwargs['outputs'] = self.outputs - else: - self.outputs = kwargs['outputs'] #record outputs nevertheless - - self.ready = loader.run(self.obj, 'setup', self.exc, *args, **kwargs) - - if not self.ready: - raise self.exc("unknow setup failure") - - return self.ready - - - def input_group(self, name='default', exclude_outputs=[]): - '''A memory-source input group matching the outputs from the view''' - - if not self.ready: - raise self.exc("database view not yet setup") - - from .data import MemoryDataSource - from .outputs import SynchronizationListener - from .inputs import Input, InputGroup - - # Setup the inputs - synchronization_listener = SynchronizationListener() - input_group = InputGroup(name, - synchronization_listener=synchronization_listener, - restricted_access=False) - - for output in self.outputs: - 
if output.name in exclude_outputs: continue - data_source = MemoryDataSource(self.done, next_callback=self.next) - output.data_sink.data_sources.append(data_source) - input_group.add(Input(output.name, - output.data_sink.dataformat, data_source)) - - return input_group - - - def done(self, *args, **kwargs): - '''Checks if the view is done''' - - if not self.ready: - raise self.exc("database view not yet setup") - - return loader.run(self.obj, 'done', self.exc, *args, **kwargs) +from beat.backend.python.database import Storage +from beat.backend.python.database import View +from beat.backend.python.database import Database as BackendDatabase - def next(self, *args, **kwargs): - '''Runs through the next data chunk''' - if not self.ready: - raise self.exc("database view not yet setup") - return loader.run(self.obj, 'next', self.exc, *args, **kwargs) - - - def __getattr__(self, key): - '''Returns an attribute of the view - only called at last resort''' - return getattr(self.obj, key) - - -class Database(object): +class Database(BackendDatabase): """Databases define the start point of the dataflow in an experiment. @@ -240,20 +91,8 @@ class Database(object): """ def __init__(self, prefix, data, dataformat_cache=None): + super(Database, self).__init__(prefix, data, dataformat_cache) - self._name = None - self.storage = None - self.prefix = prefix - self.dataformats = {} # preloaded dataformats - - self.errors = [] - self.data = None - self.code = None - - # if the user has not provided a cache, still use one for performance - dataformat_cache = dataformat_cache if dataformat_cache is not None else {} - - self._load(data, dataformat_cache) def _load(self, data, dataformat_cache): """Loads the database""" @@ -353,111 +192,20 @@ class Database(object): "unsupported by this version" % (_set['view'],) ) + @property def name(self): """Returns the name of this object """ return self._name or '__unnamed_database__' + @name.setter def name(self, value): self._name = value self.storage = Storage(self.prefix, value) - @property - def schema_version(self): - """Returns the schema version""" - return self.data.get('schema_version', 1) - - - @property - def protocols(self): - """The declaration of all the protocols of the database""" - - data = self.data['protocols'] - return dict(zip([k['name'] for k in data], data)) - - def protocol(self, name): - """The declaration of a specific protocol in the database""" - - return self.protocols[name] - - @property - def protocol_names(self): - """Names of protocols declared for this database""" - - data = self.data['protocols'] - return [k['name'] for k in data] - - def sets(self, protocol): - """The declaration of a specific set in the database protocol""" - - data = self.protocol(protocol)['sets'] - return dict(zip([k['name'] for k in data], data)) - - def set(self, protocol, name): - """The declaration of all the protocols of the database""" - - return self.sets(protocol)[name] - - def set_names(self, protocol): - """The names of sets in a given protocol for this database""" - - data = self.protocol(protocol)['sets'] - return [k['name'] for k in data] - - @property - def valid(self): - return not bool(self.errors) - - - def view(self, protocol, name, exc=None): - """Returns the database view, given the protocol and the set name - - Parameters: - - protocol (str): The name of the protocol where to retrieve the view from - - name (str): The name of the set in the protocol where to retrieve the - view from - - exc (class): If passed, must be a valid exception class 
that will be - used to report errors in the read-out of this database's view. - - Returns: - - The database view, which will be constructed, but not setup. You - **must** set it up before using methods ``done`` or ``next``. - - """ - - if not self._name: - exc = exc or RuntimeError - raise exc("database has no name") - - if not self.valid: - message = "cannot load view for set `%s' of protocol `%s' " \ - "from invalid database (%s)" % (protocol, name, self.name) - if exc: raise exc(message) - raise RuntimeError(message) - - # loads the module only once through the lifetime of the database object - try: - if not hasattr(self, '_module'): - self._module = loader.load_module(self.name.replace(os.sep, '_'), - self.storage.code.path, {}) - except Exception as e: - if exc is not None: - type, value, traceback = sys.exc_info() - six.reraise(exc, exc(value), traceback) - else: - raise #just re-raise the user exception - - return View(self._module, self.set(protocol, name), self.prefix, - self.data['root_folder'], exc) - - def hash_output(self, protocol, set, output): """Creates a unique hash the represents the output from the dataset @@ -497,6 +245,7 @@ class Database(object): """The short description for this object""" return self.data.get('description', None) + @description.setter def description(self, value): """Sets the short description for this object""" diff --git a/beat/core/execution.py b/beat/core/execution.py index 9636f2b1ac578d7a01ae31140f5d3d59a63fd780..26690b8eb3013d62cd6265a91d7525f69e357f1b 100644 --- a/beat/core/execution.py +++ b/beat/core/execution.py @@ -34,6 +34,7 @@ import glob import errno import tempfile import subprocess +import zmq.green as zmq import logging logger = logging.getLogger(__name__) @@ -48,6 +49,7 @@ from . import outputs from . import data from . import stats from . import agent +from . import dock class Executor(object): @@ -181,6 +183,7 @@ class Executor(object): self.output_list = None self.data_sinks = [] self.data_sources = [] + self.db_address = None if not isinstance(data, dict): #user has passed a file pointer if not os.path.exists(data): @@ -221,24 +224,6 @@ class Executor(object): if not db.valid: self.errors += db.errors - continue - - if not db.valid: - # do not add errors again - continue - - # create and load the required views - key = (details['database'], details['protocol'], details['set']) - if key not in self.views: - view = self.databases[details['database']].view(details['protocol'], - details['set']) - - if details['channel'] == self.data['channel']: #synchronized - start_index, end_index = self.data.get('range', (None, None)) - else: - start_index, end_index = (None, None) - view.prepare_outputs() - self.views[key] = (view, start_index, end_index) def __enter__(self): @@ -250,23 +235,17 @@ class Executor(object): """ + if len(self.databases) > 0: + host = dock.Host() + self.context = zmq.Context() + self.db_socket = self.context.socket(zmq.PAIR) + self.db_address = 'tcp://' + host.ip + port = self.db_socket.bind_to_random_port(self.db_address) + self.db_address += ':%d' % port + self._prepare_inputs() self._prepare_outputs() - # The setup() of a database view may call isConnected() on an input - # to set the index at the right location when parallelization is enabled. - # This is why setup() should be called after initialized the inputs. 
-        for key, (view, start_index, end_index) in self.views.items():
-
-            if (start_index is None) and (end_index is None):
-                status = view.setup()
-            else:
-                status = view.setup(force_start_index=start_index,
-                        force_end_index=end_index)
-
-            if not status:
-                raise RuntimeError("Could not setup database view `%s'" % key)
-
         self.agent = None
 
         return self
@@ -302,19 +281,23 @@ class Executor(object):
 
         if 'database' in details: #it is a dataset input
 
-            view_key = (details['database'], details['protocol'], details['set'])
-            view = self.views[view_key][0]
+            # create the remote input
+            db = self.databases[details['database']]
 
-            data_source = data.MemoryDataSource(view.done, next_callback=view.next)
-            self.data_sources.append(data_source)
-            output = view.outputs[details['output']]
+            dataformat_name = db.set(details['protocol'], details['set'])['outputs'][details['output']]
+            input = inputs.RemoteInput(name, db.dataformats[dataformat_name], self.db_socket)
 
-            # if it's a synchronized channel, makes the output start at the right
-            # index, otherwise, it gets lost
-            if start_index is not None and \
-                details['channel'] == self.data['channel']:
-                output.last_written_data_index = start_index - 1
-            output.data_sink.data_sources.append(data_source)
+            # Synchronization bits
+            group = self.input_list.group(details['channel'])
+            if group is None:
+                group = inputs.RemoteInputGroup(
+                    details['channel'],
+                    restricted_access=(details['channel'] == self.data['channel']),
+                    socket=self.db_socket
+                )
+                self.input_list.add(group)
+
+            group.add(input)
 
         else:
 
@@ -322,30 +305,33 @@ class Executor(object):
             self.data_sources.append(data_source)
             if details['channel'] == self.data['channel']: #synchronized
                 status = data_source.setup(
-                    filename=os.path.join(self.cache, details['path'] + '.data'),
-                    prefix=self.prefix,
-                    force_start_index=start_index,
-                    force_end_index=end_index,
+                        filename=os.path.join(self.cache, details['path'] + '.data'),
+                        prefix=self.prefix,
+                        force_start_index=start_index,
+                        force_end_index=end_index,
                 )
             else:
                 status = data_source.setup(
-                    filename=os.path.join(self.cache, details['path'] + '.data'),
-                    prefix=self.prefix,
+                        filename=os.path.join(self.cache, details['path'] + '.data'),
+                        prefix=self.prefix,
                 )
 
             if not status:
                 raise IOError("cannot load cache file `%s'" % details['path'])
 
-            # Synchronization bits
-            group = self.input_list.group(details['channel'])
-            if group is None:
-                group = inputs.InputGroup(
-                    details['channel'],
-                    synchronization_listener=outputs.SynchronizationListener(),
-                    restricted_access=(details['channel'] == self.data['channel'])
-                )
-                self.input_list.add(group)
-
-            group.add(inputs.Input(name, self.algorithm.input_map[name], data_source))
+            input = inputs.Input(name, self.algorithm.input_map[name], data_source)
+
+            # Synchronization bits
+            group = self.input_list.group(details['channel'])
+            if group is None:
+                group = inputs.InputGroup(
+                    details['channel'],
+                    synchronization_listener=outputs.SynchronizationListener(),
+                    restricted_access=(details['channel'] == self.data['channel'])
+                )
+                self.input_list.add(group)
+
+            group.add(input)
 
 
     def _prepare_outputs(self):
@@ -383,7 +369,7 @@ class Executor(object):
                 raise IOError("cannot create cache sink `%s'" % details['path'])
 
             input_group = self.input_list.group(details['channel'])
-            if input_group is None:
+            if (input_group is None) or not hasattr(input_group, 'synchronization_listener'):
                 synchronization_listener = None
             else:
                 synchronization_listener = input_group.synchronization_listener
@@ -490,7 +476,7 @@ class Executor(object):
             #synchronous call - always returns after a certain timeout
             retval = runner.run(self, host, timeout_in_minutes=timeout_in_minutes,
-                    daemon=daemon)
+                    daemon=daemon, db_address=self.db_address)
 
             #adds I/O statistics from the current executor, if its complete already
             #otherwise, it means the running process went bananas, ignore it ;-)
@@ -579,7 +565,7 @@ class Executor(object):
         data['channel'] = self.data['channel']
 
         with open(os.path.join(directory, 'configuration.json'), 'wb') as f:
-            simplejson.dump(data, f, indent=2)
+            simplejson.dump(data, f, indent=2)
 
         tmp_prefix = os.path.join(directory, 'prefix')
         if not os.path.exists(tmp_prefix): os.makedirs(tmp_prefix)
@@ -587,6 +573,19 @@ class Executor(object):
 
         self.algorithm.export(tmp_prefix)
 
 
+    def dump_databases_provider_configuration(self, directory):
+        """Exports contents useful for a backend runner to run the algorithm"""
+
+        with open(os.path.join(directory, 'configuration.json'), 'wb') as f:
+            simplejson.dump(self.data, f, indent=2)
+
+        tmp_prefix = os.path.join(directory, 'prefix')
+        if not os.path.exists(tmp_prefix): os.makedirs(tmp_prefix)
+
+        for db in self.databases.values():
+            db.export(tmp_prefix)
+
+
     def kill(self):
         """Stops the user process by force - to be called from signal handlers"""
diff --git a/beat/core/hash.py b/beat/core/hash.py
index 40b074f302334280d13c83aeaede030d88db587e..6d06a866142de22e6a086b577634f7cf2e4a3b03 100644
--- a/beat/core/hash.py
+++ b/beat/core/hash.py
@@ -31,48 +31,20 @@
 
 import os
 import six
-import copy
 import hashlib
 import collections
 
 import simplejson
 
-
-def _sha256(s):
-    """A python2/3 replacement for :py:func:`haslib.sha256`"""
-
-    try:
-        if isinstance(s, str): s = six.u(s)
-        return hashlib.sha256(s.encode('utf8')).hexdigest()
-    except:
-        return hashlib.sha256(s).hexdigest()
+from beat.backend.python.hash import *
+from beat.backend.python.hash import _sha256
+from beat.backend.python.hash import _stringify
 
 
 def _compact(text):
     return text.replace(' ', '').replace('\n', '')
 
 
-def _stringify(dictionary):
-    names = sorted(dictionary.keys())
-
-    converted_dictionary = '{'
-    for name in names:
-        converted_dictionary += '"%s":%s,' % (name, str(dictionary[name]))
-
-    if len(converted_dictionary) > 1:
-        converted_dictionary = converted_dictionary[:-1]
-
-    converted_dictionary += '}'
-
-    return converted_dictionary
-
-
-def hash(dictionary_or_string):
-    if isinstance(dictionary_or_string, dict):
-        return _sha256(_stringify(dictionary_or_string))
-    else:
-        return _sha256(dictionary_or_string)
-
 
 def hashDatasetOutput(database_hash, protocol_name, set_name, output_name):
 
     s = _compact("""{
@@ -84,6 +56,7 @@ def hashDatasetOutput(database_hash, protocol_name, set_name, output_name):
 
     return hash(s)
 
+
 def hashBlockOutput(block_name, algorithm_name, algorithm_hash,
         parameters, environment, input_hashes, output_name):
     # Note: 'block_name' and 'algorithm_name' aren't used to compute the hash,
@@ -100,6 +73,7 @@ def hashBlockOutput(block_name, algorithm_name, algorithm_hash,
 
     return hash(s)
 
+
 def hashAnalyzer(analyzer_name, algorithm_name, algorithm_hash,
         parameters, environment, input_hashes):
     # Note: 'analyzer_name' isn't used to compute the hash, but are useful when
@@ -115,26 +89,17 @@ def hashAnalyzer(analyzer_name, algorithm_name, algorithm_hash,
 
     return hash(s)
 
+
 def toPath(hash, suffix='.data'):
     return os.path.join(hash[0:2], hash[2:4], hash[4:6], hash[6:] + suffix)
 
+
 def toUserPath(username):
     hash = _sha256(username)
     return os.path.join(hash[0:2], hash[2:4], username)
 
-def hashJSON(contents, description):
-    """Hashes the pre-loaded JSON object using :py:func:`hashlib.sha256`
-
-    Excludes description changes
-    """
-
-    if description in contents:
-        contents = copy.deepcopy(contents) #temporary copy
-        del contents[description]
-    contents = simplejson.dumps(contents, sort_keys=True)
-    return hashlib.sha256(contents).hexdigest()
 
 def hashJSONStr(contents, description):
     """Hashes the JSON string contents using :py:func:`hashlib.sha256`
@@ -148,23 +113,3 @@ def hashJSONStr(contents, description):
     except simplejson.JSONDecodeError:
         # falls back to normal file content hashing
         return hash(contents)
-
-def hashJSONFile(path, description):
-    """Hashes the JSON file contents using :py:func:`hashlib.sha256`
-
-    Excludes description changes
-    """
-
-    try:
-        with open(path, 'rb') as f:
-            return hashJSON(simplejson.load(f,
-                object_pairs_hook=collections.OrderedDict), description) #preserve order
-    except simplejson.JSONDecodeError:
-        # falls back to normal file content hashing
-        return hashFileContents(path)
-
-def hashFileContents(path):
-    """Hashes the file contents using :py:func:`hashlib.sha256`."""
-
-    with open(path, 'rb') as f:
-        return hashlib.sha256(f.read()).hexdigest()
diff --git a/beat/core/inputs.py b/beat/core/inputs.py
index 24437370312ee859c5da22c99837873c5d99e415..50b0d118c93c5834eb4dda0fe64a4f5d2bbb1791 100644
--- a/beat/core/inputs.py
+++ b/beat/core/inputs.py
@@ -26,212 +26,8 @@
 ###############################################################################
 
 
-from functools import reduce
-
-import six
-
 from beat.backend.python.inputs import InputList
-from beat.backend.python.inputs import Input as RemoteInput
-from beat.backend.python.inputs import InputGroup as RemoteInputGroup
-
-
-class Input:
-    """Represents the input of a processing block
-
-    A list of those inputs must be provided to the algorithms (see
-    :py:class:`beat.backend.python.inputs.InputList`)
-
-
-    Parameters:
-
-      name (str): Name of the input
-
-      data_format (str): Data format accepted by the input
-
-      data_source (beat.core.platform.data.DataSource): Source of data to be used
-        by the input
-
-
-    Attributes:
-
-      group (beat.core.inputs.InputGroup): Group containing this input
-
-      name (str): Name of the input (algorithm-specific)
-
-      data (beat.core.baseformat.baseformat): The last block of data received on
-        the input
-
-      data_index (int): Index of the last block of data received on the input
-        (see the section *Inputs synchronization* of the User's Guide)
-
-      data_index_end (int): End index of the last block of data received on the
-        input (see the section *Inputs synchronization* of the User's Guide)
-
-      data_format (str): Data format accepted by the input
-
-      data_source (beat.core.data.DataSource): Source of data used by the output
-
-      nb_data_blocks_read (int): Number of data blocks read so far
-
-    """
-
-    def __init__(self, name, data_format, data_source):
-
-        self.group = None
-        self.name = name
-        self.data = None
-        self.data_index = -1
-        self.data_index_end = -1
-        self.data_same_as_previous = False
-        self.data_format = data_format
-        self.data_source = data_source
-        self.nb_data_blocks_read = 0
-
-    def isDataUnitDone(self):
-        """Indicates if the current data unit will change at the next iteration"""
-
-        return (self.data_index_end == self.group.data_index_end)
-
-    def hasMoreData(self):
-        """Indicates if there is more data to process on the input"""
-
-        return self.data_source.hasMoreData()
-
-    def next(self):
-        """Retrieves the next block of data"""
-
-        (self.data, self.data_index, self.data_index_end) = self.data_source.next()
-        self.data_same_as_previous = False
-        self.nb_data_blocks_read += 1
-
-
-class InputGroup:
-    """Represents a group of inputs synchronized together
-
-    A group implementing this interface is provided to the algorithms (see
-    :py:class:`beat.backend.python.inputs.InputList`).
-
-    See :py:class:`beat.core.inputs.Input`
-
-    Example:
-
-      .. code-block:: python
-
-         inputs = InputList()
-
-         print inputs['labels'].data_format
-
-         for index in range(0, len(inputs)):
-             print inputs[index].data_format
-
-         for input in inputs:
-             print input.data_format
-
-         for input in inputs[0:2]:
-             print input.data_format
-
-
-    Parameters:
-
-      channel (str): Name of the data channel of the group
-
-      synchronization_listener (beat.core.outputs.SynchronizationListener):
-        Synchronization listener to use
-
-      restricted_access (bool): Indicates if the algorithm can freely use the
-        inputs
-
-
-    Atttributes:
-
-      data_index (int): Index of the last block of data received on the inputs
-        (see the section *Inputs synchronization* of the User's Guide)
-
-      data_index_end (int): End index of the last block of data received on the
-        inputs (see the section *Inputs synchronization* of the User's Guide)
-
-      channel (str): Name of the data channel of the group
-
-      synchronization_listener (beat.core.outputs.SynchronizationListener):
-        Synchronization listener used
-
-    """
-
-    def __init__(self, channel, synchronization_listener=None,
-            restricted_access=True):
-
-        self._inputs = []
-        self.data_index = -1
-        self.data_index_end = -1
-        self.channel = channel
-        self.synchronization_listener = synchronization_listener
-        self.restricted_access = restricted_access
-
-    def __getitem__(self, index):
-
-        if isinstance(index, six.string_types):
-            try:
-                return [x for x in self._inputs if x.name == index][0]
-            except:
-                pass
-        elif isinstance(index, int):
-            if index < len(self._inputs):
-                return self._inputs[index]
-        return None
-
-    def __iter__(self):
-
-        for k in self._inputs: yield k
-
-    def __len__(self):
-
-        return len(self._inputs)
-
-    def add(self, input):
-        """Add an input to the group
-
-        Parameters:
-
-          input (beat.core.inputs.Input): The input to add
-
-        """
-
-        input.group = self
-        self._inputs.append(input)
-
-    def hasMoreData(self):
-        """Indicates if there is more data to process in the group"""
-
-        return bool([x for x in self._inputs if x.hasMoreData()])
-
-    def next(self):
-        """Retrieve the next block of data on all the inputs"""
-
-        # Only for groups not managed by the platform
-        if self.restricted_access: raise RuntimeError('Not authorized')
-
-        # Only retrieve new data on the inputs where the current data expire first
-        lower_end_index = reduce(lambda x, y: min(x, y.data_index_end),
-                self._inputs[1:], self._inputs[0].data_index_end)
-        inputs_to_update = [x for x in self._inputs \
-                if x.data_index_end == lower_end_index]
-        inputs_up_to_date = [x for x in self._inputs if x not in inputs_to_update]
-
-        for input in inputs_to_update:
-            input.next()
-            input.data_same_as_previous = False
-
-        for input in inputs_up_to_date:
-            input.data_same_as_previous = True
-
-
-        # Compute the group's start and end indices
-        self.data_index = reduce(lambda x, y: max(x, y.data_index),
-                self._inputs[1:], self._inputs[0].data_index)
-        self.data_index_end = reduce(lambda x, y: min(x, y.data_index_end),
-                self._inputs[1:], self._inputs[0].data_index_end)
-
-        # Inform the synchronisation listener
-        if self.synchronization_listener is not None:
-            self.synchronization_listener.onIntervalChanged(self.data_index,
-                    self.data_index_end)
+from beat.backend.python.inputs import Input
+from beat.backend.python.inputs import InputGroup
+from beat.backend.python.inputs import RemoteInput
+from beat.backend.python.inputs import RemoteInputGroup
diff --git a/beat/core/outputs.py b/beat/core/outputs.py
index adfb4416b1750b373986dab568f860e889733418..3f4762edff0f3441330c4c2e5dbb51811ff776bb 100644
--- a/beat/core/outputs.py
+++ b/beat/core/outputs.py
@@ -26,130 +26,7 @@
 ###############################################################################
 
 
+from beat.backend.python.outputs import SynchronizationListener
+from beat.backend.python.outputs import Output
+from beat.backend.python.outputs import RemoteOutput
 from beat.backend.python.outputs import OutputList
-
-
-class SynchronizationListener:
-    """A callback mechanism to keep Inputs and Outputs in groups and lists
-    synchronized together."""
-
-    def __init__(self):
-        self.data_index_start = -1
-        self.data_index_end = -1
-
-    def onIntervalChanged(self, data_index_start, data_index_end):
-        self.data_index_start = data_index_start
-        self.data_index_end = data_index_end
-
-
-class Output:
-    """Represents one output of a processing block
-
-    A list of outputs implementing this interface is provided to the algorithms
-    (see :py:class:`beat.core.outputs.OutputList`).
-
-
-    Parameters:
-
-      name (str): Name of the output
-
-      data_sink (beat.core.data.DataSink): Sink of data to be used by the output,
-        pre-configured with the correct data format.
-
-
-    Attributes:
-
-      name (str): Name of the output (algorithm-specific)
-
-      data_sink (beat.core.data.DataSink): Sink of data used by the output
-
-      last_written_data_index (int): Index of the last block of data written by
-        the output
-
-      nb_data_blocks_written (int): Number of data blocks written so far
-
-
-    """
-
-    def __init__(self, name, data_sink, synchronization_listener=None,
-            dataset_output=False, force_start_index=0):
-
-        self.name = name
-        self.data_sink = data_sink
-        self._synchronization_listener = synchronization_listener
-        self._dataset_output = dataset_output
-        self.last_written_data_index = force_start_index-1
-        self.nb_data_blocks_written = 0
-
-
-    def _createData(self):
-        """Retrieves an uninitialized block of data corresponding to the data
-        format of the output
-
-        This method must be called to correctly create a new block of data
-        """
-
-        if hasattr(self.data_sink, 'dataformat'):
-            return self.data_sink.dataformat.type()
-        else:
-            raise RuntimeError("The currently used data sink is not bound to " \
-                    "a dataformat - you cannot create uninitialized data under " \
-                    "these circumstances")
-
-
-    def write(self, data, end_data_index=None):
-        """Write a block of data on the output
-
-        Parameters:
-
-          data (beat.core.baseformat.baseformat): The block of data to write, or
-            None (if the algorithm doesn't want to write any data)
-
-          end_data_index (int): Last index of the written data (see the section
-            *Inputs synchronization* of the User's Guide). If not specified, the
-            *current end data index* of the Inputs List is used
-
-        """
-
-        if self._dataset_output:
-            if end_data_index is None:
-                end_data_index = self.last_written_data_index + 1
-            elif end_data_index < self.last_written_data_index + 1:
-                raise KeyError("Database wants to write an `end_data_index' (%d) " \
-                        "which is smaller than the last written index (%d) " \
-                        "+1 - this is a database bug - Fix it!" % \
-                        (end_data_index, self.last_written_data_index))
-
-        elif end_data_index is not None:
-            if (end_data_index < self.last_written_data_index + 1) or \
-                ((self._synchronization_listener is not None) and \
-                (end_data_index > self._synchronization_listener.data_index_end)):
-                raise KeyError("Algorithm logic error on write(): `end_data_index' " \
-                        "is not consistent with last written index")
-
-        elif self._synchronization_listener is not None:
-            end_data_index = self._synchronization_listener.data_index_end
-
-        else:
-            end_data_index = self.last_written_data_index + 1
-
-        # if the user passes a dictionary, converts to the proper baseformat type
-        if isinstance(data, dict):
-            d = self.data_sink.dataformat.type()
-            d.from_dict(data, casting='safe', add_defaults=False)
-            data = d
-
-        self.data_sink.write(data, self.last_written_data_index + 1, end_data_index)
-
-        self.last_written_data_index = end_data_index
-        self.nb_data_blocks_written += 1
-
-    def isDataMissing(self):
-
-        return not(self._dataset_output) and \
-            (self._synchronization_listener is not None) and \
-            (self._synchronization_listener.data_index_end != self.last_written_data_index)
-
-    def isConnected(self):
-
-        return self.data_sink.isConnected()
diff --git a/beat/core/test/prefix/experiments/user/user/double/1/cxx_double.json b/beat/core/test/prefix/experiments/user/user/double/1/cxx_double.json
index 1f3750601973152ab3c47a5f3a68be67d83efad1..9f9c2c8b37047b90f97e16ec4fc050cc07ff6f85 100644
--- a/beat/core/test/prefix/experiments/user/user/double/1/cxx_double.json
+++ b/beat/core/test/prefix/experiments/user/user/double/1/cxx_double.json
@@ -17,8 +17,8 @@
       "out_data": "out"
     },
     "environment": {
-      "name": "cxx_environment",
-      "version": "1"
+      "name": "Cxx backend",
+      "version": "1.0.0"
     }
   },
   "echo2": {
@@ -30,8 +30,8 @@
       "out_data": "out"
     },
     "environment": {
-      "name": "cxx_environment",
-      "version": "1"
+      "name": "Cxx backend",
+      "version": "1.0.0"
     }
   }
 },
diff --git a/beat/core/test/test_message_handler.py b/beat/core/test/test_message_handler.py
new file mode 100644
index 0000000000000000000000000000000000000000..babeae4fcfeb95f75c0f112c737de3a9de565a86
--- /dev/null
+++ b/beat/core/test/test_message_handler.py
@@ -0,0 +1,167 @@
+#!/usr/bin/env python
+# vim: set fileencoding=utf-8 :
+
+###############################################################################
+#                                                                             #
+# Copyright (c) 2016 Idiap Research Institute, http://www.idiap.ch/           #
+# Contact: beat.support@idiap.ch                                              #
+#                                                                             #
+# This file is part of the beat.core module of the BEAT platform.             #
+#                                                                             #
+# Commercial License Usage                                                    #
+# Licensees holding valid commercial BEAT licenses may use this file in       #
+# accordance with the terms contained in a written agreement between you      #
+# and Idiap. For further information contact tto@idiap.ch                     #
+#                                                                             #
+# Alternatively, this file may be used under the terms of the GNU Affero      #
+# Public License version 3 as published by the Free Software and appearing    #
+# in the file LICENSE.AGPL included in the packaging of this file.            #
+# The BEAT platform is distributed in the hope that it will be useful, but    #
+# WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY  #
+# or FITNESS FOR A PARTICULAR PURPOSE.                                        #
+#                                                                             #
+# You should have received a copy of the GNU Affero Public License along      #
+# with the BEAT platform. If not, see http://www.gnu.org/licenses/.           #
+#                                                                             #
+###############################################################################
+
+
+# Tests for experiment execution
+
+import os
+import logging
+logger = logging.getLogger(__name__)
+
+# in case you want to see the printouts dynamically, set to ``True``
+if False:
+    logger = logging.getLogger() #root logger
+    logger.setLevel(logging.DEBUG)
+    ch = logging.StreamHandler()
+    ch.setLevel(logging.DEBUG)
+    ch.setFormatter(logging.Formatter('%(levelname)s: %(message)s'))
+    logger.addHandler(ch)
+
+import unittest
+import zmq.green as zmq
+import nose.tools
+
+from ..agent import MessageHandler
+from ..dataformat import DataFormat
+from ..inputs import RemoteInput
+from ..inputs import RemoteInputGroup
+from ..inputs import Input
+from ..inputs import InputGroup
+from ..inputs import InputList
+
+from .mocks import MockDataSource
+
+from . import prefix
+
+
+
+class TestMessageHandler(unittest.TestCase):
+
+    def setUp(self):
+        dataformat = DataFormat(prefix, 'user/single_integer/1')
+
+        data_source_a = MockDataSource([
+                dataformat.type(value=10),
+                dataformat.type(value=20),
+            ],
+            [
+                (0, 0),
+                (1, 1),
+            ]
+        )
+
+        input_a = Input('a', 'user/single_integer/1', data_source_a)
+
+        data_source_b = MockDataSource([
+                dataformat.type(value=100),
+                dataformat.type(value=200),
+            ],
+            [
+                (0, 0),
+                (1, 1),
+            ]
+        )
+
+        input_b = Input('b', 'user/single_integer/1', data_source_b)
+
+        group = InputGroup('channel')
+        group.add(input_a)
+        group.add(input_b)
+
+        self.input_list = InputList()
+        self.input_list.add(group)
+
+
+        self.server_context = zmq.Context()
+        server_socket = self.server_context.socket(zmq.PAIR)
+        address = 'tcp://127.0.0.1'
+        port = server_socket.bind_to_random_port(address)
+        address += ':%d' % port
+
+        self.message_handler = MessageHandler(self.input_list, self.server_context, server_socket)
+
+
+        self.client_context = zmq.Context()
+        client_socket = self.client_context.socket(zmq.PAIR)
+        client_socket.connect(address)
+
+        self.remote_input_a = RemoteInput('a', dataformat, client_socket)
+        self.remote_input_b = RemoteInput('b', dataformat, client_socket)
+
+        self.remote_group = RemoteInputGroup('channel', False, client_socket)
+        self.remote_group.add(self.remote_input_a)
+        self.remote_group.add(self.remote_input_b)
+
+        self.remote_input_list = InputList()
+        self.remote_input_list.add(self.remote_group)
+
+        self.message_handler.start()
+
+
+    def test_input_has_more_data(self):
+        assert self.remote_input_a.hasMoreData()
+
+
+    def test_input_next(self):
+        self.remote_input_a.next()
+        nose.tools.eq_(self.remote_input_a.data.value, 10)
+
+
+    def test_input_full_cycle(self):
+        assert self.remote_input_a.hasMoreData()
+        self.remote_input_a.next()
+        nose.tools.eq_(self.remote_input_a.data.value, 10)
+
+        assert self.remote_input_a.hasMoreData()
+        self.remote_input_a.next()
+        nose.tools.eq_(self.remote_input_a.data.value, 20)
+
+        assert not self.remote_input_a.hasMoreData()
+
+
+    def test_group_has_more_data(self):
+        assert self.remote_group.hasMoreData()
+
+
+    def test_group_next(self):
+        self.remote_group.next()
+        nose.tools.eq_(self.remote_input_a.data.value, 10)
+        nose.tools.eq_(self.remote_input_b.data.value, 100)
+
+
+    def test_group_full_cycle(self):
+        assert self.remote_group.hasMoreData()
+        self.remote_group.next()
+        nose.tools.eq_(self.remote_input_a.data.value, 10)
+        nose.tools.eq_(self.remote_input_b.data.value, 100)
+
+        assert self.remote_group.hasMoreData()
+        self.remote_group.next()
+        nose.tools.eq_(self.remote_input_a.data.value, 20)
+        nose.tools.eq_(self.remote_input_b.data.value, 200)
+
+        assert not self.remote_group.hasMoreData()
diff --git a/beat/core/utils.py b/beat/core/utils.py
index 0811e739aceb009bdd83f08f34dff39e7b124c02..11889e5827b76cafa2fdaac907f8a08147509040 100644
--- a/beat/core/utils.py
+++ b/beat/core/utils.py
@@ -27,14 +27,12 @@
 
 import os
-import shutil
 import tempfile
-import collections
 
 import numpy
 import simplejson
-import six
 
+from beat.backend.python.utils import *
 
 from . import hash
 
@@ -45,6 +43,7 @@ def temporary_directory(prefix='beat_'):
 
     return tempfile.mkdtemp(prefix=prefix)
 
 
+
 def hashed_or_simple(prefix, what, path, suffix='.json'):
     """Returns a hashed path or simple path depending on where the resource is"""
 
@@ -55,50 +54,6 @@ def hashed_or_simple(prefix, what, path, suffix='.json'):
 
     return os.path.join(prefix, what, path)
 
 
-def safe_rmfile(f):
-    """Safely removes a file from the disk"""
-
-    if os.path.exists(f): os.unlink(f)
-
-
-def safe_rmdir(f):
-    """Safely removes the directory containg a given file from the disk"""
-
-    d = os.path.dirname(f)
-    if not os.path.exists(d): return
-    if not os.listdir(d): os.rmdir(d)
-
-
-def extension_for_language(language):
-    """Returns the preferred extension for a given programming language
-
-    The set of languages supported must match those declared in our
-    ``common.json`` schema.
-
-    Parameters:
-
-      language (str) The language for which you'd like to get the extension for.
-
-
-    Returns:
-
-      str: The extension for the given language, including a leading ``.`` (dot)
-
-
-    Raises:
-
-      KeyError: If the language is not defined in our internal dictionary.
-
-    """
-
-    return dict(
-        unknown = '',
-        cxx = '.so',
-        matlab = '.m',
-        python = '.py',
-        r = '.r',
-    )[language]
-
 
 class NumpyJSONEncoder(simplejson.JSONEncoder):
     """Encodes numpy arrays and scalars
@@ -118,59 +73,6 @@ class NumpyJSONEncoder(simplejson.JSONEncoder):
 
         return simplejson.JSONEncoder.default(self, obj)
 
 
-class File(object):
-    """User helper to read and write file objects"""
-
-
-    def __init__(self, path, binary=False):
-
-        self.path = path
-        self.binary = binary
-
-
-    def exists(self):
-
-        return os.path.exists(self.path)
-
-
-    def load(self):
-
-        mode = 'rb' if self.binary else 'rt'
-        with open(self.path, mode) as f: return f.read()
-
-
-    def try_load(self):
-
-        if os.path.exists(self.path):
-            return self.load()
-        return None
-
-
-    def backup(self):
-
-        if not os.path.exists(self.path): return #no point in backing-up
-        backup = self.path + '~'
-        if os.path.exists(backup): os.remove(backup)
-        shutil.copy(self.path, backup)
-
-
-    def save(self, contents):
-
-        d = os.path.dirname(self.path)
-        if not os.path.exists(d): os.makedirs(d)
-
-        if os.path.exists(self.path): self.backup()
-
-        mode = 'wb' if self.binary else 'wt'
-        with open(self.path, mode) as f: f.write(contents)
-
-
-    def remove(self):
-
-        safe_rmfile(self.path)
-        safe_rmfile(self.path + '~') #backup
-        safe_rmdir(self.path) #remove containing directory
-
 
 def uniq(seq):
     '''Order preserving (very fast) uniq function for sequences'''
@@ -183,129 +85,3 @@ def uniq(seq):
         result.append(item)
 
     return result
-
-
-class Storage(object):
-    """Resolves paths for objects that provide only a description
-
-    Parameters:
-
-      prefix (str): Establishes the prefix of your installation.
-
-      name (str): The name of the database object in the format
-        ``<name>/<version>``.
-
-    """
-
-    def __init__(self, path):
-
-        self.path = path
-        self.json = File(self.path + '.json')
-        self.doc = File(self.path + '.rst')
-
-    def hash(self, description='description'):
-        """The 64-character hash of the database declaration JSON"""
-        return hash.hashJSONFile(self.json.path, description)
-
-    def exists(self):
-        """If the database declaration file exists"""
-        return self.json.exists()
-
-    def load(self):
-        """Loads the JSON declaration as a file"""
-        tp = collections.namedtuple('Storage', ['declaration', 'description'])
-        return tp(self.json.load(), self.doc.try_load())
-
-    def save(self, declaration, description=None):
-        """Saves the JSON declaration as files"""
-        if description: self.doc.save(description.encode('utf8'))
-        if not isinstance(declaration, six.string_types):
-            declaration = simplejson.dumps(declaration, indent=4)
-        self.json.save(declaration)
-
-    def remove(self):
-        """Removes the object from the disk"""
-        self.json.remove()
-        self.doc.remove()
-
-
-class CodeStorage(object):
-    """Resolves paths for objects that provide a description and code
-
-    Parameters:
-
-      prefix (str): Establishes the prefix of your installation.
-
-      name (str): The name of the database object in the format
-        ``<name>/<version>``.
-
-      language (str): One of the valdid programming languages
-
-    """
-
-    def __init__(self, path, language=None):
-
-        self.path = path
-        self.json = File(self.path + '.json')
-        self.doc = File(self.path + '.rst')
-
-        self._language = language or self.__auto_discover_language()
-        self.code = File(self.path + \
-                extension_for_language(self._language), binary=True)
-
-    def __auto_discover_language(self, json=None):
-        """Discovers and sets the language from its own JSON descriptor"""
-        try:
-            text = json or self.json.load()
-            json = simplejson.loads(text)
-            return json['language']
-        except IOError:
-            return 'unknown'
-
-    @property
-    def language(self):
-        return self._language
-
-    @language.setter
-    def language(self, value):
-        self._language = value
-        self.code = File(self.path + extension_for_language(self._language),
-                binary=True)
-
-    def hash(self):
-        """The 64-character hash of the database declaration JSON"""
-
-        if self.code.exists():
-            return hash.hash(dict(
-                json=hash.hashJSONFile(self.json.path, 'description'),
-                code=hash.hashFileContents(self.code.path),
-            ))
-        else:
-            return hash.hash(dict(
-                json=hash.hashJSONFile(self.json.path, 'description'),
-            ))
-
-    def exists(self):
-        """If the database declaration file exists"""
-        return self.json.exists() and self.code.exists()
-
-    def load(self):
-        """Loads the JSON declaration as a file"""
-        tp = collections.namedtuple('CodeStorage',
-                ['declaration', 'code', 'description'])
-        return tp(self.json.load(), self.code.try_load(), self.doc.try_load())
-
-    def save(self, declaration, code=None, description=None):
-        """Saves the JSON declaration and the code as files"""
-        if description: self.doc.save(description.encode('utf8'))
-        self.json.save(declaration)
-        if code:
-            if self._language == 'unknown':
-                self.language = self.__auto_discover_language(declaration)
-            self.code.save(code)
-
-    def remove(self):
-        """Removes the object from the disk"""
-        self.json.remove()
-        self.code.remove()
-        self.doc.remove()