diff --git a/beat/backend/python/algorithm.py b/beat/backend/python/algorithm.py
index ae577b27aaafe8c5cdba096340740c9f8a7b4280..70d477e33d073690c3eccff96d0a9d473bd91087 100755
--- a/beat/backend/python/algorithm.py
+++ b/beat/backend/python/algorithm.py
@@ -355,6 +355,7 @@ class Algorithm(object):
       self.data = simplejson.load(f)
 
     self.code_path = self.storage.code.path
+    self.code = self.storage.code.load()
 
     self.groups = self.data['groups']
@@ -772,3 +773,88 @@ class Algorithm(object):
         raise #just re-raise the user exception
 
     return Runner(self.__module, klass, self, exc)
+
+
+  def json_dumps(self, indent=4):
+    """Dumps the JSON declaration of this object in a string
+
+
+    Parameters:
+
+      indent (int): The number of indentation spaces at every indentation level
+
+
+    Returns:
+
+      str: The JSON representation for this object
+
+    """
+
+    return simplejson.dumps(self.data, indent=indent,
+                            cls=utils.NumpyJSONEncoder)
+
+
+  def __str__(self):
+    return self.json_dumps()
+
+
+  def write(self, storage=None):
+    """Writes contents to prefix location
+
+    Parameters:
+
+      storage (Storage, optional): If you pass a new storage, then this object
+        will be written to that storage point rather than its default.
+
+    """
+
+    if self.data['language'] == 'unknown':
+      raise RuntimeError("algorithm has no programming language set")
+
+    if storage is None:
+      if not self._name:
+        raise RuntimeError("algorithm has no name")
+      storage = self.storage #overwrite
+
+    storage.save(str(self), self.code, self.description)
+
+
+  def export(self, prefix):
+    """Recursively exports itself into another prefix
+
+    Dataformats and associated libraries are also copied.
+
+
+    Parameters:
+
+      prefix (str): A path to a prefix that must be different from my own.
+
+
+    Returns:
+
+      None
+
+
+    Raises:
+
+      RuntimeError: If prefix and self.prefix point to the same directory.
+ + """ + + if not self._name: + raise RuntimeError("algorithm has no name") + + if not self.valid: + raise RuntimeError("algorithm is not valid") + + if os.path.samefile(prefix, self.prefix): + raise RuntimeError("Cannot export algorithm to the same prefix (%s == " \ + "%s)" % (prefix, self.prefix)) + + for k in self.libraries.values(): + k.export(prefix) + + for k in self.dataformats.values(): + k.export(prefix) + + self.write(Storage(prefix, self.name, self.language)) diff --git a/beat/backend/python/data.py b/beat/backend/python/data.py index 356189419c60584ebad571ec26fefc2305227692..b7883b5d99744fcefec949e1d2b31d1202c88b57 100755 --- a/beat/backend/python/data.py +++ b/beat/backend/python/data.py @@ -298,6 +298,9 @@ class CachedFileLoader(object): (self.filenames, indices_filenames, data_checksum_filenames, indices_checksum_filenames) = \ getAllFilenames(filename, start_index, end_index) + if len(self.filenames) == 0: + return False + check_consistency(self.filenames, data_checksum_filenames) @@ -487,6 +490,10 @@ class DataSink(object): pass + def close(self): + pass + + #---------------------------------------------------------- @@ -633,91 +640,21 @@ class CachedDataSink(DataSink): """ def __init__(self): - self.filename = None - self.process_id = None - self.split_id = None - self.max_size = None - - self._nb_bytes_written = 0 - self._write_duration = 0 - self._nb_bytes_written_split = 0 - - self._new_file = False - - self._cur_filename = None - self._cur_file = None - self._cur_indexname = None - self._cur_index = None - - self._cur_start_index = None - self._cur_end_index = None - self._filenames = [] - self._filenames_tmp = [] - - self._tmp_ext = '.tmp' - self.encoding = None self.dataformat = None + self.start_index = None + self.end_index = None - def _curTmpFilenameWithSplit(self): - - filename, data_ext = os.path.splitext(self.filename) - dirname = os.path.dirname(filename) - basename = os.path.basename(filename) - fd, tmp_file = tempfile.mkstemp( - dir=dirname, - prefix=basename+'.' + str(self.process_id)+'.'+ str(self.split_id)+'_', - suffix=data_ext + self._tmp_ext, - ) - os.close(fd) # Preserve only the name - os.unlink(tmp_file) - return tmp_file - - def _curFilenameWithIndices(self): + self.data_file = None + self.index_file = None + self.last_written_data_index = None - basename = os.path.basename(self.filename) - basename, data_ext = os.path.splitext(basename) - dirname = os.path.dirname(self.filename) - return os.path.join(dirname, basename + '.' + str(self._cur_start_index) + '.' 
diff --git a/beat/backend/python/data.py b/beat/backend/python/data.py
index 356189419c60584ebad571ec26fefc2305227692..b7883b5d99744fcefec949e1d2b31d1202c88b57 100755
--- a/beat/backend/python/data.py
+++ b/beat/backend/python/data.py
@@ -298,6 +298,9 @@ class CachedFileLoader(object):
     (self.filenames, indices_filenames, data_checksum_filenames, indices_checksum_filenames) = \
         getAllFilenames(filename, start_index, end_index)
 
+    if len(self.filenames) == 0:
+      return False
+
     check_consistency(self.filenames, data_checksum_filenames)
@@ -487,6 +490,10 @@ class DataSink(object):
 
     pass
 
+  def close(self):
+    pass
+
+
 #----------------------------------------------------------
@@ -633,91 +640,21 @@ class CachedDataSink(DataSink):
   """
 
   def __init__(self):
-    self.filename = None
-    self.process_id = None
-    self.split_id = None
-    self.max_size = None
-
-    self._nb_bytes_written = 0
-    self._write_duration = 0
-    self._nb_bytes_written_split = 0
-
-    self._new_file = False
-
-    self._cur_filename = None
-    self._cur_file = None
-    self._cur_indexname = None
-    self._cur_index = None
-
-    self._cur_start_index = None
-    self._cur_end_index = None
-    self._filenames = []
-    self._filenames_tmp = []
-
-    self._tmp_ext = '.tmp'
-
     self.encoding = None
     self.dataformat = None
+    self.start_index = None
+    self.end_index = None
 
-  def _curTmpFilenameWithSplit(self):
-
-    filename, data_ext = os.path.splitext(self.filename)
-    dirname = os.path.dirname(filename)
-    basename = os.path.basename(filename)
-    fd, tmp_file = tempfile.mkstemp(
-        dir=dirname,
-        prefix=basename+'.' + str(self.process_id)+'.'+ str(self.split_id)+'_',
-        suffix=data_ext + self._tmp_ext,
-    )
-    os.close(fd) # Preserve only the name
-    os.unlink(tmp_file)
-    return tmp_file
-
-  def _curFilenameWithIndices(self):
+    self.data_file = None
+    self.index_file = None
+    self.last_written_data_index = None
 
-    basename = os.path.basename(self.filename)
-    basename, data_ext = os.path.splitext(basename)
-    dirname = os.path.dirname(self.filename)
-    return os.path.join(dirname, basename + '.' + str(self._cur_start_index) + '.'
-                        + str(self._cur_end_index) + data_ext)
+    self.nb_bytes_written = 0
+    self.write_duration = 0
 
-  def _tmpIndexFilenameFromTmpFilename(self, tmp_filename):
-    return os.path.splitext(os.path.splitext(tmp_filename)[0])[0] + '.index' + self._tmp_ext
-
-  def _indexFilenameFromFilename(self, filename):
-    return os.path.splitext(filename)[0] + '.index'
-
-  def _openAndWriteHeader(self):
-    """Write the header of the current file"""
-
-    # Close current file if open
-    self._close_current()
-
-    # Open new file in writing mode
-    self._cur_filename = self._curTmpFilenameWithSplit()
-    self._cur_indexname = \
-        self._tmpIndexFilenameFromTmpFilename(self._cur_filename)
-    self._filenames_tmp.append(self._cur_filename)
-    try:
-      self._cur_file = open(self._cur_filename, 'wb')
-      self._cur_index = open(self._cur_indexname, 'wt')
-    except:
-      return
-
-    # Write dataformat
-    self._cur_file.write(six.b('%s\n%s\n' % \
-        (self.encoding, self.dataformat.name)))
-    self._cur_file.flush()
-
-    # Reset few flags
-    self._cur_start_index = None
-    self._cur_end_index = None
-    self._new_file = False
-    self._nb_bytes_written_split = 0
-
-  def setup(self, filename, dataformat, encoding='binary', process_id=0,
-      max_size=0):
+  def setup(self, filename, dataformat, start_index, end_index, encoding='binary'):
     """Configures the data sink
 
     Parameters:
@@ -734,127 +671,82 @@ class CachedDataSink(DataSink):
 
     """
 
+    # Close current file if open
+    self.close()
+
     if encoding not in ('binary', 'json'):
-      raise RuntimeError("valid formats for data writting are 'binary' "
+      raise RuntimeError("valid formats for data writing are 'binary' "
                          "or 'json': the format `%s' is invalid" % format)
 
     if dataformat.name == '__unnamed_dataformat__':
-      raise RuntimeError("cannot record data using an unnammed data format")
+      raise RuntimeError("cannot record data using an unnamed data format")
 
-    self.filename = filename
-    self.process_id = process_id
-    self.split_id = 0
-    self.max_size = max_size
+    filename, data_ext = os.path.splitext(filename)
 
-    self._nb_bytes_written = 0
-    self._write_duration = 0
-    self._new_file = True
+    self.filename = '%s.%d.%d%s' % (filename, start_index, end_index, data_ext)
+    self.encoding = encoding
+    self.dataformat = dataformat
+    self.start_index = start_index
+    self.end_index = end_index
 
-    self._cur_filename = None
-    self._cur_file = None
-    self._cur_indexname = None
-    self._cur_index = None
-    self._cur_start_index = None
-    self._cur_end_index = None
+    self.nb_bytes_written = 0
+    self.write_duration = 0
+    self.last_written_data_index = None
 
-    self._filenames = []
-    self._filenames_tmp = []
+    try:
+      self.data_file = open(self.filename, 'wb')
+      self.index_file = open(self.filename.replace('.data', '.index'), 'wt')
+    except:
+      return False
 
-    self.dataformat = dataformat
-    self.encoding = encoding
+    # Write the dataformat
+    self.data_file.write(six.b('%s\n%s\n' % (self.encoding, self.dataformat.name)))
+    self.data_file.flush()
 
     return True
 
-  def _close_current(self):
+
+  def close(self):
    """Closes the data sink
    """
 
-    if self._cur_file is not None:
-      self._cur_file.close()
-      self._cur_index.close()
+    if self.data_file is not None:
+      self.data_file.close()
+      self.index_file.close()
 
-      # If file is empty, remove it
-      if self._cur_start_index is None or self._cur_end_index is None:
+      # If file is not complete, delete it
+      if (self.last_written_data_index is None) or \
+         (self.last_written_data_index < self.end_index):
        try:
-          os.remove(self._cur_filename)
-          os.remove(self._cur_index)
+          os.remove(self.filename)
+          os.remove(self.filename.replace('.data', '.index'))
+          return True
        except:
          return False
 
-        self._filenames_tmp.pop()
-
-      # Otherwise, append final filename to list
-      else:
-        self._filenames.append(self._curFilenameWithIndices())
-
-      self._cur_filename = None
-      self._cur_file = None
-      self._cur_indexname = None
-      self._cur_index = None
-
-  def close(self):
-    """Move the files to final location
-    """
-
-    self._close_current()
-    assert len(self._filenames_tmp) == len(self._filenames)
-    for i in range(len(self._filenames_tmp)):
-
-      os.rename(self._filenames_tmp[i], self._filenames[i])
-      tmp_indexname = \
-          self._tmpIndexFilenameFromTmpFilename(self._filenames_tmp[i])
-      final_indexname = self._indexFilenameFromFilename(self._filenames[i])
-      os.rename(tmp_indexname, final_indexname)
-
-      # creates the checksums for all data and indexes
-      chksum_data = hashFileContents(self._filenames[i])
-      with open(self._filenames[i] + '.checksum', 'wt') as f:
+      # Creates the checksums for all data and indexes
+      chksum_data = hashFileContents(self.filename)
+      with open(self.filename + '.checksum', 'wt') as f:
        f.write(chksum_data)
 
-      chksum_index = hashFileContents(final_indexname)
-      with open(final_indexname + '.checksum', 'wt') as f:
-        f.write(chksum_index)
-
-      self._cur_filename = None
-      self._cur_file = None
-      self._cur_indexname = None
-      self._cur_index = None
-      self._cur_start_index = None
-      self._cur_end_index = None
-      self._filenames = []
-      self._filenames_tmp = []
+      index_filename = self.filename.replace('.data', '.index')
+      chksum_index = hashFileContents(index_filename)
+      with open(index_filename + '.checksum', 'wt') as f:
+        f.write(chksum_index)
 
-  def reset(self):
-    """Move the files to final location
-    """
+    self.data_file = None
+    self.index_file = None
+    self.last_written_data_index = None
 
-    self._close_current()
-    assert len(self._filenames_tmp) == len(self._filenames)
-    for i in range(len(self._filenames_tmp)):
-      try:
-        os.remove(self._filenames_tmp[i])
-        tmp_indexname = \
-            self._tmpIndexFilenameFromTmpFilename(self._filenames_tmp[i])
-        os.remove(tmp_indexname)
-      except:
-        return False
-
-    self._cur_filename = None
-    self._cur_file = None
-    self._cur_indexname = None
-    self._cur_index = None
+    return True
 
-    self._cur_start_index = None
-    self._cur_end_index = None
-    self._filenames = []
-    self._filenames_tmp = []
 
  def __del__(self):
-    """Make sure the files are close and renamed when the object is deleted
+    """Make sure the files are closed when the object is deleted
    """
-
    self.close()
+
 
  def write(self, data, start_data_index, end_data_index):
    """Writes a block of data to the filesystem
@@ -868,8 +760,8 @@ class CachedDataSink(DataSink):
 
    """
 
+    # If the user passed a dictionary - convert it
    if isinstance(data, dict):
-      # the user passed a dictionary - must convert
      data = self.dataformat.type(**data)
    else:
      # Checks that the input data conforms to the expected format
@@ -877,57 +769,43 @@ class CachedDataSink(DataSink):
        raise TypeError("input data uses format `%s' while this sink "
                        "expects `%s'" % (data.__class__._name, self.dataformat))
 
-    # If the flag new_file is set, open new file and write header
-    if self._new_file:
-      self._openAndWriteHeader()
-
-    if self._cur_file is None:
-      raise RuntimeError("no destination file")
+    if self.data_file is None:
+      raise RuntimeError("No destination file")
 
-    # encoding happens here
+    # Encoding
    if self.encoding == 'binary':
      encoded_data = data.pack()
    else:
      from .utils import NumpyJSONEncoder
      encoded_data = json.dumps(data.as_dict(), indent=4, cls=NumpyJSONEncoder)
 
-    # adds a new line by the end of the encoded data, for clarity
+    # Adds a new line at the end of the encoded data
    encoded_data += six.b('\n')
 
    informations = six.b('%d %d %d\n' % (start_data_index,
-        end_data_index, len(encoded_data)))
+                                         end_data_index, len(encoded_data)))
 
    t1 = time.time()
 
-    self._cur_file.write(informations + encoded_data)
-    self._cur_file.flush()
+    self.data_file.write(informations + encoded_data)
+    self.data_file.flush()
 
    indexes = '%d %d\n' % (start_data_index, end_data_index)
-    self._cur_index.write(indexes)
-    self._cur_index.flush()
+    self.index_file.write(indexes)
+    self.index_file.flush()
 
    t2 = time.time()
 
-    self._nb_bytes_written += \
-        len(informations) + len(encoded_data) + len(indexes)
-    self._nb_bytes_written_split += \
-        len(informations) + len(encoded_data) + len(indexes)
-    self._write_duration += t2 - t1
+    self.nb_bytes_written += len(informations) + len(encoded_data) + len(indexes)
+    self.write_duration += t2 - t1
 
-    # Update start and end indices
-    if self._cur_start_index is None:
-      self._cur_start_index = start_data_index
-    self._cur_end_index = end_data_index
+    self.last_written_data_index = end_data_index
 
-    # If file size exceeds max, sets the flag to create a new file
-    if self.max_size != 0 and self._nb_bytes_written >= self.max_size:
-      self._new_file = True
-      self.split_id += 1
 
  def statistics(self):
    """Return the statistics about the number of bytes written to the cache"""
+    return (self.nb_bytes_written, self.write_duration)
 
-    return (self._nb_bytes_written, self._write_duration)
 
  def isConnected(self):
    return (self.filename is not None)
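
Note: `CachedDataSink` no longer splits files on `max_size`; it writes exactly one `<name>.<start>.<end>.data` file plus its `.index`, both named from the indices given to `setup()`. A sketch of the new call sequence (assuming `dataformat` is a valid, loaded `DataFormat`):

    from beat.backend.python.data import CachedDataSink

    sink = CachedDataSink()
    # for '/cache/out.data' and indices 0..9 this creates '/cache/out.0.9.data'
    # and '/cache/out.0.9.index'
    assert sink.setup('/cache/out.data', dataformat, start_index=0, end_index=9)

    for i in range(10):
        block = dataformat.type()   # fill in the fields before writing
        sink.write(block, i, i)     # one data unit per index here

    sink.close()  # writes the '.checksum' files; deletes the file if end_index was not reached
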
+ + """ + + if not self._name: + raise RuntimeError("dataformat has no name") + + if not self.valid: + raise RuntimeError("dataformat is not valid") + + if os.path.samefile(prefix, self.prefix): + raise RuntimeError("Cannot dataformat object to the same prefix (%s " \ + "== %s)" % (prefix, self.prefix)) + + for k in self.referenced.values(): + k.export(prefix) + + self.write(Storage(prefix, self.name)) diff --git a/beat/backend/python/dbexecution.py b/beat/backend/python/dbexecution.py index 13b963ccd09bc90b7e9112e472629fccd1f1f56f..9b4babd78670a3725d6475b7280b9bdaf8dd9e54 100755 --- a/beat/backend/python/dbexecution.py +++ b/beat/backend/python/dbexecution.py @@ -260,9 +260,8 @@ class DBExecutor(object): group.add(inputs.Input(name, self.dataformat_cache[input_dataformat_name], data_source)) - def process(self, zmq_context, zmq_socket): - - self.handler = message_handler.MessageHandler(self.input_list, zmq_context, zmq_socket) + def process(self, address): + self.handler = message_handler.MessageHandler(address, inputs=self.input_list) self.handler.start() @@ -275,6 +274,7 @@ class DBExecutor(object): def wait(self): self.handler.join() + self.handler.destroy() self.handler = None diff --git a/beat/backend/python/executor.py b/beat/backend/python/executor.py index 071980bba87b1e58dd196b296bb21b00b1aea501..df0b929aff38dfb2e23ec76054c69374a81d6562 100755 --- a/beat/backend/python/executor.py +++ b/beat/backend/python/executor.py @@ -136,16 +136,23 @@ class Executor(object): if not self.input_list or not self.output_list: raise RuntimeError("I/O for execution block has not yet been set up") - using_output = self.output_list[0] if self.analysis else self.output_list - while self.input_list.hasMoreData(): main_group = self.input_list.main_group main_group.restricted_access = False main_group.next() main_group.restricted_access = True - if not self.runner.process(self.input_list, using_output): + + if self.analysis: + result = self.runner.process(inputs=self.input_list, output=self.output_list[0]) + else: + result = self.runner.process(inputs=self.input_list, outputs=self.output_list) + + if not result: return False + for output in self.output_list: + output.close() + missing_data_outputs = [x for x in self.output_list if x.isDataMissing()] if missing_data_outputs: diff --git a/beat/backend/python/helpers.py b/beat/backend/python/helpers.py index 80a94454912040b90850283891bddae9899dc149..9ac010d8a8cf2cbb114138bc125ec3c4aa6d99e6 100755 --- a/beat/backend/python/helpers.py +++ b/beat/backend/python/helpers.py @@ -247,14 +247,30 @@ def create_outputs_from_configuration(config, algorithm, prefix, cache_root, inp if exception.errno != errno.EEXIST: raise + if start_index is None: + for k, v in config['inputs'].items(): + if v['channel'] == config['channel']: + input_path = os.path.join(cache_root, v['path'] + '.data') + break + + (data_filenames, indices_filenames, data_checksum_filenames, indices_checksum_filenames) = \ + data.getAllFilenames(input_path) + + end_indices = [ int(x.split('.')[-2]) for x in indices_filenames ] + end_indices.sort() + + start_index = 0 + end_index = end_indices[-1] + data_sink = data.CachedDataSink() data_sinks.append(data_sink) status = data_sink.setup( filename=path, dataformat=dataformat, - encoding='binary', - max_size=0, # in bytes, for individual file chunks + start_index=start_index, + end_index=end_index, + encoding='binary' ) if not status: @@ -262,7 +278,7 @@ def create_outputs_from_configuration(config, algorithm, prefix, cache_root, inp 
diff --git a/beat/backend/python/executor.py b/beat/backend/python/executor.py
index 071980bba87b1e58dd196b296bb21b00b1aea501..df0b929aff38dfb2e23ec76054c69374a81d6562 100755
--- a/beat/backend/python/executor.py
+++ b/beat/backend/python/executor.py
@@ -136,16 +136,23 @@ class Executor(object):
    if not self.input_list or not self.output_list:
      raise RuntimeError("I/O for execution block has not yet been set up")
 
-    using_output = self.output_list[0] if self.analysis else self.output_list
-
    while self.input_list.hasMoreData():
      main_group = self.input_list.main_group
      main_group.restricted_access = False
      main_group.next()
      main_group.restricted_access = True
-      if not self.runner.process(self.input_list, using_output):
+
+      if self.analysis:
+        result = self.runner.process(inputs=self.input_list, output=self.output_list[0])
+      else:
+        result = self.runner.process(inputs=self.input_list, outputs=self.output_list)
+
+      if not result:
        return False
 
+    for output in self.output_list:
+      output.close()
+
    missing_data_outputs = [x for x in self.output_list if x.isDataMissing()]
 
    if missing_data_outputs:
diff --git a/beat/backend/python/helpers.py b/beat/backend/python/helpers.py
index 80a94454912040b90850283891bddae9899dc149..9ac010d8a8cf2cbb114138bc125ec3c4aa6d99e6 100755
--- a/beat/backend/python/helpers.py
+++ b/beat/backend/python/helpers.py
@@ -247,14 +247,30 @@ def create_outputs_from_configuration(config, algorithm, prefix, cache_root, inp
        if exception.errno != errno.EEXIST:
          raise
 
+      if start_index is None:
+        for k, v in config['inputs'].items():
+          if v['channel'] == config['channel']:
+            input_path = os.path.join(cache_root, v['path'] + '.data')
+            break
+
+        (data_filenames, indices_filenames, data_checksum_filenames, indices_checksum_filenames) = \
+            data.getAllFilenames(input_path)
+
+        end_indices = [ int(x.split('.')[-2]) for x in indices_filenames ]
+        end_indices.sort()
+
+        start_index = 0
+        end_index = end_indices[-1]
+
      data_sink = data.CachedDataSink()
      data_sinks.append(data_sink)
 
      status = data_sink.setup(
          filename=path,
          dataformat=dataformat,
-          encoding='binary',
-          max_size=0, # in bytes, for individual file chunks
+          start_index=start_index,
+          end_index=end_index,
+          encoding='binary'
      )
 
      if not status:
@@ -262,7 +278,7 @@ def create_outputs_from_configuration(config, algorithm, prefix, cache_root, inp
 
      output_list.add(outputs.Output(name, data_sink,
            synchronization_listener=synchronization_listener,
-            force_start_index=start_index or 0)
+            force_start_index=start_index)
      )
 
    if 'result' not in config:
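
Note: the new `start_index is None` fallback derives the output range from the synchronized input's cache files. Index files are named `<path>.<start>.<end>.index`, so the end index is the second-to-last dot-separated token; a worked example (file names hypothetical):

    indices_filenames = ['/cache/in.0.9.index', '/cache/in.10.19.index']
    end_indices = sorted(int(x.split('.')[-2]) for x in indices_filenames)
    start_index, end_index = 0, end_indices[-1]   # -> 0 and 19
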
diff --git a/beat/backend/python/library.py b/beat/backend/python/library.py
old mode 100644
new mode 100755
index cf9bad8cf24fc2afc8265f8a3cdd90f8b9252be2..51c41aa91af0b4145e2b1c1689d19fdc01f9a2a9
--- a/beat/backend/python/library.py
+++ b/beat/backend/python/library.py
@@ -292,3 +292,85 @@ class Library(object):
      raise RuntimeError("library has no name")
 
    return self.storage.hash()
+
+
+  def json_dumps(self, indent=4):
+    """Dumps the JSON declaration of this object in a string
+
+
+    Parameters:
+
+      indent (int): The number of indentation spaces at every indentation level
+
+
+    Returns:
+
+      str: The JSON representation for this object
+
+    """
+
+    return simplejson.dumps(self.data, indent=indent,
+                            cls=utils.NumpyJSONEncoder)
+
+
+  def __str__(self):
+    return self.json_dumps()
+
+
+  def write(self, storage=None):
+    """Writes contents to prefix location.
+
+    Parameters:
+
+      storage (Storage, optional): If you pass a new storage, then this object
+        will be written to that storage point rather than its default.
+
+    """
+
+    if self.data['language'] == 'unknown':
+      raise RuntimeError("library has no programming language set")
+
+    if storage is None:
+      if not self._name:
+        raise RuntimeError("library has no name")
+      storage = self.storage #overwrite
+
+    storage.save(str(self), self.code, self.description)
+
+
+  def export(self, prefix):
+    """Recursively exports itself into another prefix
+
+    Other required libraries are also copied.
+
+
+    Parameters:
+
+      prefix (str): A path to a prefix that must be different from my own.
+
+
+    Returns:
+
+      None
+
+
+    Raises:
+
+      RuntimeError: If prefix and self.prefix point to the same directory.
+
+    """
+
+    if not self._name:
+      raise RuntimeError("library has no name")
+
+    if not self.valid:
+      raise RuntimeError("library is not valid")
+
+    if os.path.samefile(prefix, self.prefix):
+      raise RuntimeError("Cannot export library to the same prefix (%s == " \
+          "%s)" % (prefix, self.prefix))
+
+    for k in self.libraries.values():
+      k.export(prefix)
+
+    self.write(Storage(prefix, self.name))
diff --git a/beat/backend/python/message_handler.py b/beat/backend/python/message_handler.py
index 09c6ddb99df2c3792045571a63782aeab7cdcb0d..368fb28724b2293593d30e52b4dc8dc39c7ffde1 100755
--- a/beat/backend/python/message_handler.py
+++ b/beat/backend/python/message_handler.py
@@ -39,13 +39,9 @@ from .inputs import RemoteException
 
 class MessageHandler(threading.Thread):
 
-  '''A 0MQ message handler for our communication with other processes
+  '''A 0MQ message handler for our communication with other processes'''
 
-  Support for more messages can be implemented by subclassing this class.
-  This one only support input-related messages.
-
-  '''
-
-  def __init__(self, input_list, zmq_context, zmq_socket, kill_callback=None):
+  def __init__(self, host_address, inputs=None, outputs=None, kill_callback=None):
 
    super(MessageHandler, self).__init__()
 
@@ -56,14 +52,29 @@ class MessageHandler(threading.Thread):
    self.must_kill = threading.Event()
    self.must_kill.clear()
 
-    # Starts our 0MQ server
-    self.context = zmq_context
-    self.socket = zmq_socket
+    # Either starts a 0MQ server or connect to an existing one
+    self.context = zmq.Context()
+    self.socket = self.context.socket(zmq.PAIR)
+
+    if not host_address.startswith('tcp://'):
+      self.address = 'tcp://' + host_address
+    else:
+      self.address = host_address
+
+    if len(self.address.split(':')) == 2:
+      port = self.socket.bind_to_random_port(self.address, min_port=50000)
+      self.address += ':%d' % port
+      logger.debug("zmq server bound to '%s'", self.address)
+    else:
+      self.socket.connect(self.address)
+      logger.debug("connected to '%s'", self.address)
 
    self.poller = zmq.Poller()
    self.poller.register(self.socket, zmq.POLLIN)
 
-    self.input_list = input_list
+    # Initialisations
+    self.input_list = inputs
+    self.output_list = outputs
 
    self.system_error = ''
    self.user_error = ''
@@ -73,15 +84,34 @@ class MessageHandler(threading.Thread):
    # implementations
    self.callbacks = dict(
-        nxt = self.next,
-        hmd = self.has_more_data,
        don = self.done,
        err = self.error,
    )
 
+    if self.input_list is not None:
+      self.callbacks.update(dict(
+          nxt = self.next,
+          hmd = self.has_more_data,
+      ))
 
-  def run(self):
+    if self.output_list is not None:
+      self.callbacks.update(dict(
+          wrt = self.write,
+          oic = self.output_is_connected,
+      ))
+
+
+  def destroy(self):
+    self.socket.setsockopt(zmq.LINGER, 0)
+    self.socket.close()
+    self.context.destroy()
+
+
+  def __str__(self):
+    return 'MessageHandler(%s)' % self.address
+
 
+  def run(self):
    logger.debug("0MQ server thread started")
 
    while not self.stop.is_set(): #keep on
@@ -161,7 +191,6 @@ class MessageHandler(threading.Thread):
 
  def _get_input_candidate(self, channel, name):
-
    channel_group = self.input_list.group(channel)
    retval = channel_group[name]
    if retval is None:
@@ -170,11 +199,55 @@ class MessageHandler(threading.Thread):
    return retval
 
+  def _get_output_candidate(self, name):
+    retval = self.output_list[name]
+    if retval is None:
+      raise RuntimeError("Could not find output `%s'" % name)
+    return retval
+
+
+  def _acknowledge(self):
+
+    logger.debug('send: ack')
+    self.socket.send('ack')
+    logger.debug('setting stop condition for 0MQ server thread')
+    self.stop.set()
+
+
+  def done(self, statistics=None):
+    """Syntax: don"""
+
+    logger.debug('recv: don %s', statistics)
+
+    if statistics is not None:
+      self.statistics = simplejson.loads(statistics)
+
+    self._acknowledge()
+
+
+  def error(self, t, msg):
+    """Syntax: err type message"""
+
+    logger.debug('recv: err %s <msg> (size=%d)', t, len(msg))
+
+    if t == 'usr':
+      self.user_error = msg
+    else:
+      self.system_error = msg
+
+    self.statistics = dict(network=dict(wait_time=0.))
+    self._acknowledge()
+
+
  def next(self, channel, name):
    """Syntax: nxt channel name ..."""
 
    logger.debug('recv: nxt %s %s', channel, name)
 
+    if self.input_list is None:
+      message = 'Unexpected message received: nxt %s %s' % (channel, name)
+      raise RemoteException('sys', message)
+
    channel_group = self.input_list.group(channel)
    restricted = channel_group.restricted_access
 
@@ -191,7 +264,6 @@ class MessageHandler(threading.Thread):
            "`%s' on input `%s', but it is over (no more data). This " \
            "normally indicates a programming error on the user " \
            "side." % (channel, name)
-        self.user_error += message + '\n'
        raise RuntimeError(message)
 
      if isinstance(input_candidate.data, baseformat.baseformat):
@@ -211,6 +283,11 @@ class MessageHandler(threading.Thread):
    """Syntax: hmd channel name"""
 
    logger.debug('recv: hmd %s %s', channel, name)
+
+    if self.input_list is None:
+      message = 'Unexpected message received: hmd %s %s' % (channel, name)
+      raise RemoteException('sys', message)
+
    input_candidate = self._get_input_candidate(channel, name)
 
    what = 'tru' if input_candidate.hasMoreData() else 'fal'
@@ -218,37 +295,43 @@ class MessageHandler(threading.Thread):
    self.socket.send(what)
 
-  def _acknowledge(self):
+  def write(self, name, end_data_index, packed):
+    """Syntax: wrt output data"""
 
-    logger.debug('send: ack')
-    self.socket.send('ack')
-    logger.debug('setting stop condition for 0MQ server thread')
-    self.stop.set()
+    end_data_index = int(end_data_index)
 
+    logger.debug('recv: wrt %s %d <bin> (size=%d)', name, end_data_index, len(packed))
 
-  def done(self, statistics=None):
-    """Syntax: don"""
+    if self.output_list is None:
+      message = 'Unexpected message received: wrt %s %d <bin> (size=%d)' % (name, end_data_index, len(packed))
+      raise RemoteException('sys', message)
 
-    logger.debug('recv: don %s', statistics)
+    # Get output object
+    output_candidate = self._get_output_candidate(name)
+    if output_candidate is None:
+      raise RuntimeError("Could not find output `%s' to write to" % name)
 
-    if statistics is not None:
-      self.statistics = simplejson.loads(statistics)
+    data = output_candidate.data_sink.dataformat.type()
+    data.unpack(packed)
+    output_candidate.write(data, end_data_index=end_data_index)
 
-    self._acknowledge()
+    logger.debug('send: ack')
+    self.socket.send('ack')
 
 
-  def error(self, t, msg):
-    """Syntax: err type message"""
+  def output_is_connected(self, name):
+    """Syntax: oic output"""
 
-    logger.debug('recv: err %s <msg> (size=%d)', t, len(msg))
+    logger.debug('recv: oic %s', name)
 
-    if t == 'usr':
-      self.user_error = msg
-    else:
-      self.system_error = msg
+    if self.output_list is None:
+      message = 'Unexpected message received: oic %s' % name
+      raise RemoteException('sys', message)
 
-    self.statistics = dict(network=dict(wait_time=0.))
-    self._acknowledge()
+    output_candidate = self._get_output_candidate(name)
 
+    what = 'tru' if output_candidate.isConnected() else 'fal'
+    logger.debug('send: %s', what)
+    self.socket.send(what)
 
  def kill(self):
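
Note: `MessageHandler` now owns its socket: with a port-less address it binds (server side), with a full `tcp://host:port` address it connects (client side). A lifecycle sketch based on the test changes below (`input_list` is assumed to be a populated `InputList`):

    from beat.backend.python.message_handler import MessageHandler

    handler = MessageHandler('127.0.0.1', inputs=input_list)
    print(handler.address)   # e.g. 'tcp://127.0.0.1:50042', chosen at bind time
    handler.start()

    # ... a remote PAIR socket connects to handler.address and sends 'nxt'/'hmd'
    # (inputs) or 'wrt'/'oic' (outputs) messages ...

    handler.kill()
    handler.join()
    handler.destroy()        # new: explicitly closes the socket and 0MQ context
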
This " \ "normally indicates a programming error on the user " \ "side." % (channel, name) - self.user_error += message + '\n' raise RuntimeError(message) if isinstance(input_candidate.data, baseformat.baseformat): @@ -211,6 +283,11 @@ class MessageHandler(threading.Thread): """Syntax: hmd channel name""" logger.debug('recv: hmd %s %s', channel, name) + + if self.input_list is None: + message = 'Unexpected message received: hmd %s %s' % (channel, name) + raise RemoteException('sys', message) + input_candidate = self._get_input_candidate(channel, name) what = 'tru' if input_candidate.hasMoreData() else 'fal' @@ -218,37 +295,43 @@ class MessageHandler(threading.Thread): self.socket.send(what) - def _acknowledge(self): + def write(self, name, end_data_index, packed): + """Syntax: wrt output data""" - logger.debug('send: ack') - self.socket.send('ack') - logger.debug('setting stop condition for 0MQ server thread') - self.stop.set() + end_data_index = int(end_data_index) + logger.debug('recv: wrt %s %d <bin> (size=%d)', name, end_data_index, len(packed)) - def done(self, statistics=None): - """Syntax: don""" + if self.output_list is None: + message = 'Unexpected message received: wrt %s %d <bin> (size=%d)' % (name, end_data_index, len(packed)) + raise RemoteException('sys', message) - logger.debug('recv: don %s', statistics) + # Get output object + output_candidate = self._get_output_candidate(name) + if output_candidate is None: + raise RuntimeError("Could not find output `%s' to write to" % name) - if statistics is not None: - self.statistics = simplejson.loads(statistics) + data = output_candidate.data_sink.dataformat.type() + data.unpack(packed) + output_candidate.write(data, end_data_index=end_data_index) - self._acknowledge() + logger.debug('send: ack') + self.socket.send('ack') - def error(self, t, msg): - """Syntax: err type message""" + def output_is_connected(self, name): + """Syntax: oic output""" - logger.debug('recv: err %s <msg> (size=%d)', t, len(msg)) + logger.debug('recv: oic %s', name) - if t == 'usr': - self.user_error = msg - else: - self.system_error = msg + if self.output_list is None: + message = 'Unexpected message received: oic %s' % name + raise RemoteException('sys', message) - self.statistics = dict(network=dict(wait_time=0.)) - self._acknowledge() + output_candidate = self._get_output_candidate(name) + what = 'tru' if output_candidate.isConnected() else 'fal' + logger.debug('send: %s', what) + self.socket.send(what) def kill(self): diff --git a/beat/backend/python/outputs.py b/beat/backend/python/outputs.py index d8b7a4242f5e9e53d4d6f3f31f088a79a9902ba3..cc6ed7b9ba814fb6e68a95cc0105d196e2682272 100755 --- a/beat/backend/python/outputs.py +++ b/beat/backend/python/outputs.py @@ -144,6 +144,10 @@ class BaseOutput(object): return end_data_index + def close(self): + pass + + #---------------------------------------------------------- @@ -233,6 +237,10 @@ class Output(BaseOutput): return self.data_sink.isConnected() + def close(self): + self.data_sink.close() + + #---------------------------------------------------------- diff --git a/beat/backend/python/test/test_algorithm.py b/beat/backend/python/test/test_algorithm.py index 2d84296dcefcfe4da02a7cfef4eb8a33fbb05c5c..f326f530b24d47d2c99a8613ae8ca112b377e403 100644 --- a/beat/backend/python/test/test_algorithm.py +++ b/beat/backend/python/test/test_algorithm.py @@ -537,7 +537,8 @@ class TestExecutionBase(unittest.TestCase): self.assertTrue(dataformat.valid) data_sink = CachedDataSink() - 
diff --git a/beat/backend/python/test/test_data_loaders.py b/beat/backend/python/test/test_data_loaders.py
index 6823fdfa8ce786ea216d532f9fa45d7d6bf3c56a..4fc011f07c476f27fa0a4b6473c4deb0751d4129 100644
--- a/beat/backend/python/test/test_data_loaders.py
+++ b/beat/backend/python/test/test_data_loaders.py
@@ -69,7 +69,8 @@ class DataLoaderBaseTest(unittest.TestCase):
      self.assertTrue(dataformat.valid)
 
      data_sink = CachedDataSink()
-      self.assertTrue(data_sink.setup(self.filenames[input_name], dataformat))
+      self.assertTrue(data_sink.setup(self.filenames[input_name], dataformat,
+                                      indices[0][0], indices[-1][1]))
 
      for i in indices:
        data = dataformat.type()
diff --git a/beat/backend/python/test/test_dbexecution.py b/beat/backend/python/test/test_dbexecution.py
index 87568808042ac5c166d7a26b11773b206561af72..f1aca8e3f53fb63a8cd297e54e89de864592115f 100644
--- a/beat/backend/python/test/test_dbexecution.py
+++ b/beat/backend/python/test/test_dbexecution.py
@@ -123,20 +123,18 @@ class HostSide(object):
 
 class ContainerSide(object):
 
-  def __init__(self, zmq_context, address):
+  def __init__(self, address):
 
    dataformat_cache = {}
    database_cache = {}
 
    self.dbexecutor = DBExecutor(prefix, CONFIGURATION,
-        dataformat_cache, database_cache)
-    assert self.dbexecutor.valid, '\n * %s' % '\n * '.join(self.dbexecutor.errors)
+                                 dataformat_cache, database_cache)
 
-    self.socket = zmq_context.socket(zmq.PAIR)
-    self.socket.connect(address)
+    assert self.dbexecutor.valid, '\n * %s' % '\n * '.join(self.dbexecutor.errors)
 
    with self.dbexecutor:
-      self.dbexecutor.process(zmq_context, self.socket)
+      self.dbexecutor.process(address)
 
  def wait(self):
@@ -153,7 +151,7 @@ class TestExecution(unittest.TestCase):
 
    context = zmq.Context()
    host = HostSide(context)
-    container = ContainerSide(context, host.address)
+    container = ContainerSide(host.address)
 
    while host.group.hasMoreData():
      host.group.next()
diff --git a/beat/backend/python/test/test_message_handler.py b/beat/backend/python/test/test_message_handler.py
index 37dd3fbcefb1502e1599207aa3694a50c7a873fe..2cfb3d2b464d87bfde913dda83f7688f2ab7ca91 100644
--- a/beat/backend/python/test/test_message_handler.py
+++ b/beat/backend/python/test/test_message_handler.py
@@ -61,19 +61,12 @@ class TestMessageHandlerBase(unittest.TestCase):
 
    self.input_list = InputList()
    self.input_list.add(group)
 
-
-    self.server_context = zmq.Context()
-    server_socket = self.server_context.socket(zmq.PAIR)
-    address = 'tcp://127.0.0.1'
-    port = server_socket.bind_to_random_port(address, min_port=50000)
-    address += ':%d' % port
-
-    self.message_handler = MessageHandler(self.input_list, self.server_context, server_socket)
+    self.message_handler = MessageHandler('127.0.0.1', inputs=self.input_list)
 
    self.client_context = zmq.Context()
    client_socket = self.client_context.socket(zmq.PAIR)
-    client_socket.connect(address)
+    client_socket.connect(self.message_handler.address)
 
    self.remote_group = InputGroup('channel', restricted_access=False)
@@ -369,19 +362,12 @@ class TestMessageHandlerErrorHandling(unittest.TestCase):
 
    self.input_list = InputList()
    self.input_list.add(group)
 
-
-    self.server_context = zmq.Context()
-    server_socket = self.server_context.socket(zmq.PAIR)
-    address = 'tcp://127.0.0.1'
-    port = server_socket.bind_to_random_port(address, min_port=50000)
-    address += ':%d' % port
-
-    self.message_handler = MessageHandler(self.input_list, self.server_context, server_socket)
+    self.message_handler = MessageHandler('127.0.0.1', inputs=self.input_list)
 
    self.client_context = zmq.Context()
    client_socket = self.client_context.socket(zmq.PAIR)
-    client_socket.connect(address)
+    client_socket.connect(self.message_handler.address)
 
    self.remote_input = RemoteInput('in', dataformat, client_socket)
 
@@ -397,6 +383,7 @@ class TestMessageHandlerErrorHandling(unittest.TestCase):
  def tearDown(self):
    self.message_handler.kill()
    self.message_handler.join()
+    self.message_handler.destroy()
    self.message_handler = None
diff --git a/beat/backend/python/utils.py b/beat/backend/python/utils.py
index 05ba3efafb0aaa7dbbd776bb482567c51d7e412d..04f096f29d52af7e07a9cc915d43857591210aba 100755
--- a/beat/backend/python/utils.py
+++ b/beat/backend/python/utils.py
@@ -297,3 +297,24 @@ class CodeStorage(object):
    self.json.remove()
    self.doc.remove()
    self.code.remove()
+
+
+#----------------------------------------------------------
+
+
+class NumpyJSONEncoder(simplejson.JSONEncoder):
+  """Encodes numpy arrays and scalars
+
+  See Also:
+
+    :py:class:`simplejson.JSONEncoder`
+
+  """
+
+  def default(self, obj):
+    if isinstance(obj, numpy.ndarray) or isinstance(obj, numpy.generic):
+      return obj.tolist()
+    elif isinstance(obj, numpy.dtype):
+      if obj.name == 'str': return 'string'
+      return obj.name
+    return simplejson.JSONEncoder.default(self, obj)
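
Note: `NumpyJSONEncoder` is what the new `json_dumps()` methods above pass as `cls`. A small self-contained example of what it serializes:

    import numpy
    import simplejson

    from beat.backend.python.utils import NumpyJSONEncoder

    payload = dict(
        values=numpy.arange(3),       # ndarray -> list
        scale=numpy.float64(0.5),     # numpy scalar -> Python float
        kind=numpy.dtype('float64'),  # dtype -> its name
    )
    print(simplejson.dumps(payload, cls=NumpyJSONEncoder))
    # {"values": [0, 1, 2], "scale": 0.5, "kind": "float64"}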