Commit 078cf0fd authored by Philip ABBET

Refactoring of the 'CachedDataSink' class

parent 73fcb84a
......@@ -355,6 +355,7 @@ class Algorithm(object):
self.data = simplejson.load(f)
self.code_path = self.storage.code.path
self.code = self.storage.code.load()
self.groups = self.data['groups']
......@@ -772,3 +773,88 @@ class Algorithm(object):
raise #just re-raise the user exception
return Runner(self.__module, klass, self, exc)
def json_dumps(self, indent=4):
"""Dumps the JSON declaration of this object in a string
Parameters:
indent (int): The number of indentation spaces at every indentation level
Returns:
str: The JSON representation for this object
"""
return simplejson.dumps(self.data, indent=indent,
cls=utils.NumpyJSONEncoder)
def __str__(self):
return self.json_dumps()
def write(self, storage=None):
"""Writes contents to prefix location
Parameters:
storage (Storage, optional): If you pass a new storage, then this object
will be written to that storage point rather than its default.
"""
if self.data['language'] == 'unknown':
raise RuntimeError("algorithm has no programming language set")
if storage is None:
if not self._name:
raise RuntimeError("algorithm has no name")
storage = self.storage #overwrite
storage.save(str(self), self.code, self.description)
def export(self, prefix):
"""Recursively exports itself into another prefix
Dataformats and associated libraries are also copied.
Parameters:
prefix (str): A path to a prefix that must be different from my own.
Returns:
None
Raises:
RuntimeError: If prefix and self.prefix point to the same directory.
"""
if not self._name:
raise RuntimeError("algorithm has no name")
if not self.valid:
raise RuntimeError("algorithm is not valid")
if os.path.samefile(prefix, self.prefix):
raise RuntimeError("Cannot export algorithm to the same prefix (%s == " \
"%s)" % (prefix, self.prefix))
for k in self.libraries.values():
k.export(prefix)
for k in self.dataformats.values():
k.export(prefix)
self.write(Storage(prefix, self.name, self.language))
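For orientation, the new `json_dumps`/`write`/`export` trio added above can be used roughly as follows. This is a minimal sketch only: the prefix paths and the algorithm name are invented, and it assumes the usual `Algorithm(prefix, name)` constructor importable from `beat.core.algorithm`.

# Illustrative only: exercise the new serialization helpers on an algorithm.
from beat.core.algorithm import Algorithm

algorithm = Algorithm('/path/to/prefix', 'user/my_algorithm/1')
assert algorithm.valid

print(algorithm.json_dumps(indent=2))      # JSON declaration as a string
algorithm.write()                          # rewrite to its own storage
algorithm.export('/path/to/other_prefix')  # also copies dataformats and libraries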
This diff is collapsed; the changes to this file are not shown.
......@@ -430,3 +430,81 @@ class DataFormat(object):
return self.isparent(other.referenced[other.extends])
return False
def json_dumps(self, indent=4):
"""Dumps the JSON declaration of this object in a string
Parameters:
indent (int): The number of indentation spaces at every indentation level
Returns:
str: The JSON representation for this object
"""
return simplejson.dumps(self.data, indent=indent,
cls=utils.NumpyJSONEncoder)
def __str__(self):
return self.json_dumps()
def write(self, storage=None):
"""Writes contents to prefix location
Parameters:
storage (Storage, optional): If you pass a new storage, then this object
will be written to that storage point rather than its default.
"""
if storage is None:
if not self._name:
raise RuntimeError("dataformat has no name")
storage = self.storage #overwrite
storage.save(str(self), self.description)
def export(self, prefix):
"""Recursively exports itself into another prefix
Other required dataformats are also copied.
Parameters:
prefix (str): A path to a prefix that must be different from my own.
Returns:
None
Raises:
RuntimeError: If prefix and self.prefix point to the same directory.
"""
if not self._name:
raise RuntimeError("dataformat has no name")
if not self.valid:
raise RuntimeError("dataformat is not valid")
if os.path.samefile(prefix, self.prefix):
raise RuntimeError("Cannot dataformat object to the same prefix (%s " \
"== %s)" % (prefix, self.prefix))
for k in self.referenced.values():
k.export(prefix)
self.write(Storage(prefix, self.name))
......@@ -260,9 +260,8 @@ class DBExecutor(object):
group.add(inputs.Input(name, self.dataformat_cache[input_dataformat_name], data_source))
def process(self, zmq_context, zmq_socket):
self.handler = message_handler.MessageHandler(self.input_list, zmq_context, zmq_socket)
def process(self, address):
self.handler = message_handler.MessageHandler(address, inputs=self.input_list)
self.handler.start()
......@@ -275,6 +274,7 @@ class DBExecutor(object):
def wait(self):
self.handler.join()
self.handler.destroy()
self.handler = None
......
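With this change, callers of `DBExecutor` no longer create a 0MQ context and socket themselves; they pass an address and the executor's internal `MessageHandler` binds (or connects) on its own. A hedged sketch of the new calling convention, mirroring the updated test further down, with the prefix, configuration and caches as placeholders:

# Sketch of the new DBExecutor API: an address replaces the zmq context/socket pair.
dataformat_cache = {}
database_cache = {}

dbexecutor = DBExecutor(prefix, CONFIGURATION, dataformat_cache, database_cache)
assert dbexecutor.valid, '\n * %s' % '\n * '.join(dbexecutor.errors)

with dbexecutor:
    # A bare host ('127.0.0.1') makes the handler bind to a random port >= 50000;
    # 'tcp://' is prepended automatically when missing.
    dbexecutor.process('127.0.0.1')

dbexecutor.wait()  # joins and destroys the internal MessageHandler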
......@@ -136,16 +136,23 @@ class Executor(object):
if not self.input_list or not self.output_list:
raise RuntimeError("I/O for execution block has not yet been set up")
using_output = self.output_list[0] if self.analysis else self.output_list
while self.input_list.hasMoreData():
main_group = self.input_list.main_group
main_group.restricted_access = False
main_group.next()
main_group.restricted_access = True
if not self.runner.process(self.input_list, using_output):
if self.analysis:
result = self.runner.process(inputs=self.input_list, output=self.output_list[0])
else:
result = self.runner.process(inputs=self.input_list, outputs=self.output_list)
if not result:
return False
for output in self.output_list:
output.close()
missing_data_outputs = [x for x in self.output_list if x.isDataMissing()]
if missing_data_outputs:
......
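The loop above now calls the user algorithm with keyword arguments and distinguishes analyzer blocks (which receive a single `output`) from regular blocks (which receive the full `outputs` list), and it closes every output once the data is exhausted. Hypothetical user algorithms matching the two calling conventions might look like this (class names, endpoint names and payloads are invented for illustration):

# Hypothetical user code only; endpoint names and payloads are invented.
class MyBlockAlgorithm(object):
    # Regular block: receives the whole output list.
    def process(self, inputs, outputs):
        outputs['my_output'].write({'value': inputs['my_input'].data.value})
        return True

class MyAnalyzerAlgorithm(object):
    # Analyzer block: receives a single result output.
    def process(self, inputs, output):
        output.write({'score': 0.5})
        return True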
......@@ -247,14 +247,30 @@ def create_outputs_from_configuration(config, algorithm, prefix, cache_root, inp
if exception.errno != errno.EEXIST:
raise
if start_index is None:
for k, v in config['inputs'].items():
if v['channel'] == config['channel']:
input_path = os.path.join(cache_root, v['path'] + '.data')
break
(data_filenames, indices_filenames, data_checksum_filenames, indices_checksum_filenames) = \
data.getAllFilenames(input_path)
end_indices = [ int(x.split('.')[-2]) for x in indices_filenames ]
end_indices.sort()
start_index = 0
end_index = end_indices[-1]
data_sink = data.CachedDataSink()
data_sinks.append(data_sink)
status = data_sink.setup(
filename=path,
dataformat=dataformat,
encoding='binary',
max_size=0, # in bytes, for individual file chunks
start_index=start_index,
end_index=end_index,
encoding='binary'
)
if not status:
......@@ -262,7 +278,7 @@ def create_outputs_from_configuration(config, algorithm, prefix, cache_root, inp
output_list.add(outputs.Output(name, data_sink,
synchronization_listener=synchronization_listener,
force_start_index=start_index or 0)
force_start_index=start_index)
)
if 'result' not in config:
......
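The refactored `CachedDataSink.setup()` now requires explicit `start_index`/`end_index` bounds; when no start index is supplied, the code above recovers them from the indices files of a cached input on the same synchronization channel. A minimal sketch of the new call, with the path and dataformat as placeholders (and assuming the class is importable from `beat.core.data`, as in the tests below):

# Placeholder path and dataformat; illustrates the new setup() signature only.
from beat.core.data import CachedDataSink

data_sink = CachedDataSink()
ok = data_sink.setup(
    filename='/path/to/cache/my_output.data',
    dataformat=dataformat,   # a valid DataFormat instance
    start_index=0,
    end_index=10,
    encoding='binary',
)
assert ok
# ... write data through an outputs.Output bound to this sink, then:
data_sink.close()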
......@@ -292,3 +292,85 @@ class Library(object):
raise RuntimeError("library has no name")
return self.storage.hash()
def json_dumps(self, indent=4):
"""Dumps the JSON declaration of this object in a string
Parameters:
indent (int): The number of indentation spaces at every indentation level
Returns:
str: The JSON representation for this object
"""
return simplejson.dumps(self.data, indent=indent,
cls=utils.NumpyJSONEncoder)
def __str__(self):
return self.json_dumps()
def write(self, storage=None):
"""Writes contents to prefix location.
Parameters:
storage (Storage, optional): If you pass a new storage, then this object
will be written to that storage point rather than its default.
"""
if self.data['language'] == 'unknown':
raise RuntimeError("library has no programming language set")
if storage is None:
if not self._name:
raise RuntimeError("library has no name")
storage = self.storage #overwrite
storage.save(str(self), self.code, self.description)
def export(self, prefix):
"""Recursively exports itself into another prefix
Other required libraries are also copied.
Parameters:
prefix (str): A path to a prefix that must be different from my own.
Returns:
None
Raises:
RuntimeError: If prefix and self.prefix point to the same directory.
"""
if not self._name:
raise RuntimeError("library has no name")
if not self.valid:
raise RuntimeError("library is not valid")
if os.path.samefile(prefix, self.prefix):
raise RuntimeError("Cannot library object to the same prefix (%s == " \
"%s)" % (prefix, self.prefix))
for k in self.libraries.values():
k.export(prefix)
self.write(Storage(prefix, self.name))
......@@ -39,13 +39,9 @@ from .inputs import RemoteException
class MessageHandler(threading.Thread):
'''A 0MQ message handler for our communication with other processes
'''A 0MQ message handler for our communication with other processes'''
Support for more messages can be implemented by subclassing this class.
This one only supports input-related messages.
'''
def __init__(self, input_list, zmq_context, zmq_socket, kill_callback=None):
def __init__(self, host_address, inputs=None, outputs=None, kill_callback=None):
super(MessageHandler, self).__init__()
......@@ -56,14 +52,29 @@ class MessageHandler(threading.Thread):
self.must_kill = threading.Event()
self.must_kill.clear()
# Starts our 0MQ server
self.context = zmq_context
self.socket = zmq_socket
# Either starts a 0MQ server or connects to an existing one
self.context = zmq.Context()
self.socket = self.context.socket(zmq.PAIR)
if not host_address.startswith('tcp://'):
self.address = 'tcp://' + host_address
else:
self.address = host_address
if len(self.address.split(':')) == 2:
port = self.socket.bind_to_random_port(self.address, min_port=50000)
self.address += ':%d' % port
logger.debug("zmq server bound to '%s'", self.address)
else:
self.socket.connect(self.address)
logger.debug("connected to '%s'", self.address)
self.poller = zmq.Poller()
self.poller.register(self.socket, zmq.POLLIN)
self.input_list = input_list
# Initialisations
self.input_list = inputs
self.output_list = outputs
self.system_error = ''
self.user_error = ''
......@@ -73,15 +84,34 @@ class MessageHandler(threading.Thread):
# implementations
self.callbacks = dict(
nxt = self.next,
hmd = self.has_more_data,
don = self.done,
err = self.error,
)
if self.input_list is not None:
self.callbacks.update(dict(
nxt = self.next,
hmd = self.has_more_data,
))
def run(self):
if self.output_list is not None:
self.callbacks.update(dict(
wrt = self.write,
oic = self.output_is_connected,
))
def destroy(self):
self.socket.setsockopt(zmq.LINGER, 0)
self.socket.close()
self.context.destroy()
def __str__(self):
return 'MessageHandler(%s)' % self.address
def run(self):
logger.debug("0MQ server thread started")
while not self.stop.is_set(): #keep on
......@@ -161,7 +191,6 @@ class MessageHandler(threading.Thread):
def _get_input_candidate(self, channel, name):
channel_group = self.input_list.group(channel)
retval = channel_group[name]
if retval is None:
......@@ -170,11 +199,55 @@ class MessageHandler(threading.Thread):
return retval
def _get_output_candidate(self, name):
retval = self.output_list[name]
if retval is None:
raise RuntimeError("Could not find output `%s'" % name)
return retval
def _acknowledge(self):
logger.debug('send: ack')
self.socket.send('ack')
logger.debug('setting stop condition for 0MQ server thread')
self.stop.set()
def done(self, statistics=None):
"""Syntax: don"""
logger.debug('recv: don %s', statistics)
if statistics is not None:
self.statistics = simplejson.loads(statistics)
self._acknowledge()
def error(self, t, msg):
"""Syntax: err type message"""
logger.debug('recv: err %s <msg> (size=%d)', t, len(msg))
if t == 'usr':
self.user_error = msg
else:
self.system_error = msg
self.statistics = dict(network=dict(wait_time=0.))
self._acknowledge()
def next(self, channel, name):
"""Syntax: nxt channel name ..."""
logger.debug('recv: nxt %s %s', channel, name)
if self.input_list is None:
message = 'Unexpected message received: nxt %s %s' % (channel, name)
raise RemoteException('sys', message)
channel_group = self.input_list.group(channel)
restricted = channel_group.restricted_access
......@@ -191,7 +264,6 @@ class MessageHandler(threading.Thread):
"`%s' on input `%s', but it is over (no more data). This " \
"normally indicates a programming error on the user " \
"side." % (channel, name)
self.user_error += message + '\n'
raise RuntimeError(message)
if isinstance(input_candidate.data, baseformat.baseformat):
......@@ -211,6 +283,11 @@ class MessageHandler(threading.Thread):
"""Syntax: hmd channel name"""
logger.debug('recv: hmd %s %s', channel, name)
if self.input_list is None:
message = 'Unexpected message received: hmd %s %s' % (channel, name)
raise RemoteException('sys', message)
input_candidate = self._get_input_candidate(channel, name)
what = 'tru' if input_candidate.hasMoreData() else 'fal'
......@@ -218,37 +295,43 @@ class MessageHandler(threading.Thread):
self.socket.send(what)
def _acknowledge(self):
def write(self, name, end_data_index, packed):
"""Syntax: wrt output data"""
logger.debug('send: ack')
self.socket.send('ack')
logger.debug('setting stop condition for 0MQ server thread')
self.stop.set()
end_data_index = int(end_data_index)
logger.debug('recv: wrt %s %d <bin> (size=%d)', name, end_data_index, len(packed))
def done(self, statistics=None):
"""Syntax: don"""
if self.output_list is None:
message = 'Unexpected message received: wrt %s %d <bin> (size=%d)' % (name, end_data_index, len(packed))
raise RemoteException('sys', message)
logger.debug('recv: don %s', statistics)
# Get output object
output_candidate = self._get_output_candidate(name)
if output_candidate is None:
raise RuntimeError("Could not find output `%s' to write to" % name)
if statistics is not None:
self.statistics = simplejson.loads(statistics)
data = output_candidate.data_sink.dataformat.type()
data.unpack(packed)
output_candidate.write(data, end_data_index=end_data_index)
self._acknowledge()
logger.debug('send: ack')
self.socket.send('ack')
def error(self, t, msg):
"""Syntax: err type message"""
def output_is_connected(self, name):
"""Syntax: oic output"""
logger.debug('recv: err %s <msg> (size=%d)', t, len(msg))
logger.debug('recv: oic %s', name)
if t == 'usr':
self.user_error = msg
else:
self.system_error = msg
if self.output_list is None:
message = 'Unexpected message received: oic %s' % name
raise RemoteException('sys', message)
self.statistics = dict(network=dict(wait_time=0.))
self._acknowledge()
output_candidate = self._get_output_candidate(name)
what = 'tru' if output_candidate.isConnected() else 'fal'
logger.debug('send: %s', what)
self.socket.send(what)
def kill(self):
......
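To summarize the new `MessageHandler` lifecycle: it now owns its 0MQ context and PAIR socket, binds to a random port (>= 50000) and records the result in `self.address` when the address carries no port, or connects when a full address is given; input- and output-related callbacks are only registered when `inputs`/`outputs` are passed. A hedged sketch mirroring the updated tests further down (the `input_list` is assumed to be built beforehand):

# Sketch only: server/client pairing with the refactored MessageHandler.
import zmq

# Server side: no port given, so the handler binds to a random port and
# exposes the chosen address as message_handler.address.
message_handler = MessageHandler('127.0.0.1', inputs=input_list)
message_handler.start()

# Client side: connect a PAIR socket to the address the handler picked.
client_context = zmq.Context()
client_socket = client_context.socket(zmq.PAIR)
client_socket.connect(message_handler.address)

# ... exchange nxt/hmd/don/err messages (plus wrt/oic when outputs are given);
# a 'don' message makes the handler acknowledge and set its stop flag ...

message_handler.join()
message_handler.destroy()  # closes the socket and terminates the 0MQ context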
......@@ -144,6 +144,10 @@ class BaseOutput(object):
return end_data_index
def close(self):
pass
#----------------------------------------------------------
......@@ -233,6 +237,10 @@ class Output(BaseOutput):
return self.data_sink.isConnected()
def close(self):
self.data_sink.close()
#----------------------------------------------------------
......
......@@ -537,7 +537,8 @@ class TestExecutionBase(unittest.TestCase):
self.assertTrue(dataformat.valid)
data_sink = CachedDataSink()
self.assertTrue(data_sink.setup(self.filenames[input_name], dataformat))
self.assertTrue(data_sink.setup(self.filenames[input_name], dataformat,
indices[0][0], indices[-1][1]))
for i in indices:
data = dataformat.type()
......
......@@ -116,7 +116,7 @@ class TestCachedDataBase(unittest.TestCase):
self.assertTrue(dataformat.valid)
data_sink = CachedDataSink()
self.assertTrue(data_sink.setup(self.filename, dataformat))
self.assertTrue(data_sink.setup(self.filename, dataformat, start_index, end_index))
all_data = []
for i in range(start_index, end_index + 1):
......@@ -352,7 +352,7 @@ class TestDataSink(TestCachedDataBase):
self.assertTrue(dataformat.valid)
data_sink = CachedDataSink()
self.assertTrue(data_sink.setup(self.filename, dataformat))
self.assertTrue(data_sink.setup(self.filename, dataformat, 0, 10))
#----------------------------------------------------------
......
......@@ -69,7 +69,8 @@ class DataLoaderBaseTest(unittest.TestCase):
self.assertTrue(dataformat.valid)
data_sink = CachedDataSink()
self.assertTrue(data_sink.setup(self.filenames[input_name], dataformat))
self.assertTrue(data_sink.setup(self.filenames[input_name], dataformat,
indices[0][0], indices[-1][1]))
for i in indices:
data = dataformat.type()
......
......@@ -123,20 +123,18 @@ class HostSide(object):
class ContainerSide(object):
def __init__(self, zmq_context, address):
def __init__(self, address):
dataformat_cache = {}
database_cache = {}
self.dbexecutor = DBExecutor(prefix, CONFIGURATION,
dataformat_cache, database_cache)
assert self.dbexecutor.valid, '\n * %s' % '\n * '.join(self.dbexecutor.errors)
dataformat_cache, database_cache)
self.socket = zmq_context.socket(zmq.PAIR)
self.socket.connect(address)
assert self.dbexecutor.valid, '\n * %s' % '\n * '.join(self.dbexecutor.errors)
with self.dbexecutor:
self.dbexecutor.process(zmq_context, self.socket)
self.dbexecutor.process(address)
def wait(self):
......@@ -153,7 +151,7 @@ class TestExecution(unittest.TestCase):
context = zmq.Context()
host = HostSide(context)
container = ContainerSide(context, host.address)
container = ContainerSide(host.address)
while host.group.hasMoreData():
host.group.next()
......
......@@ -61,19 +61,12 @@ class TestMessageHandlerBase(unittest.TestCase):
self.input_list = InputList()
self.input_list.add(group)
self.server_context = zmq.Context()
server_socket = self.server_context.socket(zmq.PAIR)
address = 'tcp://127.0.0.1'
port = server_socket.bind_to_random_port(address, min_port=50000)
address += ':%d' % port
self.message_handler = MessageHandler(self.input_list, self.server_context, server_socket)
self.message_handler = MessageHandler('127.0.0.1', inputs=self.input_list)
self.client_context = zmq.Context()
client_socket = self.client_context.socket(zmq.PAIR)
client_socket.connect(address)
client_socket.connect(self.message_handler.address)