Commit 89f83797 authored by André Anjos

Merge branch 'implement_loop_output' into 'master'

Implement loop output

See merge request !54
parents 3e146f57 4231c55a
Pipeline #32107 passed with stages in 7 minutes and 19 seconds
@@ -325,6 +325,22 @@ class Runner(object):
return answer
def write(self, outputs, end_data_index):
"""Write to the outputs"""
exc = self.exc or RuntimeError
if self.algorithm.type != Algorithm.LOOP:
raise exc("Wrong algorithm type: %s" % self.algorithm.type)
# setup() must have run
if not self.ready:
raise exc("Algorithm '%s' is not yet setup" % self.name)
# prepare() must have run
if not self.prepared:
raise exc("Algorithm '%s' is not yet prepared" % self.name)
return loader.run(self.obj, "write", self.exc, outputs, end_data_index)
def __getattr__(self, key):
"""Returns an attribute of the algorithm - only called at last resort
"""
@@ -415,6 +431,8 @@ class Algorithm(object):
SEQUENTIAL_LOOP_USER = "sequential_loop_user"
AUTONOMOUS_LOOP_USER = "autonomous_loop_user"
dataformat_klass = dataformat.DataFormat
def __init__(self, prefix, name, dataformat_cache=None, library_cache=None):
self._name = None
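The new `dataformat_klass` class attribute makes the format loader overridable: `_update_dataformat_cache` below instantiates `self.dataformat_klass` rather than `dataformat.DataFormat` directly. A plausible use, purely illustrative (StubFormat and TestAlgorithm are hypothetical):

# Sketch: a subclass can substitute a lighter loader, e.g. for tests.
class StubFormat:
    def __init__(self, prefix, name):
        self.prefix, self.name = prefix, name

class TestAlgorithm(Algorithm):
    dataformat_klass = StubFormat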
@@ -475,71 +493,47 @@ class Algorithm(object):
self._convert_parameter_types()
self._load_libraries(library_cache)
def _update_dataformat_cache(self, type_name, dataformat_cache):
"""Update the data format cache based on the type name"""
if type_name not in self.dataformats:
if dataformat_cache and type_name in dataformat_cache: # reuse
thisformat = dataformat_cache[type_name]
else: # load it
thisformat = self.dataformat_klass(self.prefix, type_name)
if dataformat_cache is not None: # update it
dataformat_cache[type_name] = thisformat
self.dataformats[type_name] = thisformat
return self.dataformats[type_name]
def _update_dataformat_cache_for_group(self, group, dataformat_cache):
for _, entry in group.items():
self._update_dataformat_cache(entry["type"], dataformat_cache)
def _load_dataformats(self, dataformat_cache):
"""Makes sure we can load all requested formats
"""
for group in self.groups:
-             for name, input in group["inputs"].items():
-                 if input["type"] in self.dataformats:
-                     continue
-                 if dataformat_cache and input["type"] in dataformat_cache:  # reuse
-                     thisformat = dataformat_cache[input["type"]]
-                 else:  # load it
-                     thisformat = dataformat.DataFormat(self.prefix, input["type"])
-                     if dataformat_cache is not None:  # update it
-                         dataformat_cache[input["type"]] = thisformat
-                 self.dataformats[input["type"]] = thisformat
+             self._update_dataformat_cache_for_group(group["inputs"], dataformat_cache)
if "outputs" in group:
-                 for name, output in group["outputs"].items():
-                     if output["type"] in self.dataformats:
-                         continue
-                     if dataformat_cache and output["type"] in dataformat_cache:  # reuse
-                         thisformat = dataformat_cache[output["type"]]
-                     else:  # load it
-                         thisformat = dataformat.DataFormat(self.prefix, output["type"])
-                         if dataformat_cache is not None:  # update it
-                             dataformat_cache[output["type"]] = thisformat
-                     self.dataformats[output["type"]] = thisformat
+                 self._update_dataformat_cache_for_group(
+                     group["outputs"], dataformat_cache
+                 )
if "loop" in group:
-                 for name, entry in group["loop"].items():
-                     entry_format = entry["type"]
-                     if entry_format in self.dataformats:
-                         continue
-                     if dataformat_cache and entry_format in dataformat_cache:
-                         thisformat = dataformat_cache[entry_format]
-                     else:
-                         thisformat = dataformat.DataFormat(self.prefix, entry_format)
-                         if dataformat_cache is not None:
-                             dataformat_cache[entry_format] = thisformat
-                     self.dataformats[entry_format] = thisformat
+                 self._update_dataformat_cache_for_group(group["loop"], dataformat_cache)
if self.results:
for name, result in self.results.items():
if result["type"].find("/") != -1:
if result["type"] in self.dataformats:
continue
if dataformat_cache and result["type"] in dataformat_cache: # reuse
thisformat = dataformat_cache[result["type"]]
else:
thisformat = dataformat.DataFormat(self.prefix, result["type"])
if dataformat_cache is not None: # update it
dataformat_cache[result["type"]] = thisformat
self.dataformats[result["type"]] = thisformat
result_type = result["type"]
# results can only contain base types and plots therefore, only
# process plots
if result_type.find("/") != -1:
self._update_dataformat_cache(result_type, dataformat_cache)
def _convert_parameter_types(self):
"""Converts types to numpy equivalents, checks defaults, ranges and choices
@@ -157,7 +157,7 @@ def getAllFilenames(filename, start_index=None, end_index=None):
# Retrieve all the related files
basename, ext = os.path.splitext(filename)
filenames = sorted(glob.glob(basename + "*"), key=file_start)
filenames = sorted(glob.glob(basename + ".*"), key=file_start)
# (If necessary) Only keep files containing the desired indices
if (start_index is not None) or (end_index is not None):
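The extra dot in the glob pattern is the actual fix here: with a bare `*`, files belonging to a sibling cache entry whose name merely starts with the same basename are swept in as well. A quick runnable demonstration (file names made up):

import glob
import os
import tempfile

# With filename "foo.data", splitext() yields basename "foo"; the old
# pattern "foo*" also matches "foobar.*", the new "foo.*" does not.
d = tempfile.mkdtemp()
for name in ("foo.0.9.data", "foo.0.9.data.checksum", "foobar.0.9.data"):
    open(os.path.join(d, name), "w").close()

base = os.path.join(d, "foo")
assert any("foobar" in f for f in glob.glob(base + "*"))      # old: too greedy
assert not any("foobar" in f for f in glob.glob(base + ".*")) # fixed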
@@ -399,7 +399,9 @@ class CachedDataSource(DataSource):
)
if not ok_indices:
raise IOError("data file `%s' have missing indices." % f_data)
raise IOError(
"data file `%s|%s' has missing indices." % (f_data, f_chck)
)
self.prefix = prefix
self.unpack = unpack
@@ -413,6 +415,7 @@ class CachedDataSource(DataSource):
) = getAllFilenames(filename, start_index, end_index)
if len(self.filenames) == 0:
logger.warn("No files found for %s" % filename)
return False
check_consistency(self.filenames, data_checksum_filenames)
@@ -174,6 +174,7 @@ class AlgorithmExecutor(object):
cache_root,
input_list=self.input_list,
data_loaders=self.data_loaders,
loop_socket=self.loop_socket,
)
if self.loop_socket:
@@ -255,14 +256,18 @@ class AlgorithmExecutor(object):
output=self.output_list[0],
)
else:
-             result = self.runner.process(
-                 inputs=self.input_list,
-                 data_loaders=self.data_loaders,
-                 outputs=self.output_list,
-                 loop_channel=self.loop_channel,
-             )
+             try:
+                 result = self.runner.process(
+                     inputs=self.input_list,
+                     data_loaders=self.data_loaders,
+                     outputs=self.output_list,
+                     loop_channel=self.loop_channel,
+                 )
+             except Exception:
+                 result = None
if not result:
self.done({})
return False
for output in self.output_list:
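The try/except around `process` funnels user-code crashes into the same failure path as a falsy return value. The guard in isolation, with the executor pieces reduced to hypothetical stand-ins:

# Any exception from the user's process() becomes result = None, so a
# single `if not result` branch handles both explicit failure and
# crashes (the real code reports failure via self.done({})).
def run_process(process):
    try:
        result = process()
    except Exception:
        result = None
    if not result:
        return False
    return True

assert run_process(lambda: True) is True
assert run_process(lambda: 1 / 0) is False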
@@ -48,7 +48,6 @@ import json
import zmq
from ..algorithm import Algorithm
from ..baseformat import baseformat
- from ..dataformat import DataFormat
from ..helpers import create_inputs_from_configuration
from ..helpers import create_outputs_from_configuration
@@ -209,6 +208,15 @@ class LoopExecutor(object):
databases=databases,
)
self.output_list, _ = create_outputs_from_configuration(
self.data,
self.algorithm,
self.prefix,
cache_root,
input_list=self.input_list,
data_loaders=self.data_loaders,
)
self.message_handler = message_handler
self.message_handler.setup(self.algorithm, self.prefix)
self.message_handler.set_executor(self)
@@ -251,6 +259,13 @@ class LoopExecutor(object):
logger.debug("User loop has validated: {}\n{}".format(is_valid, answer))
return is_valid, answer
def write(self, end_data_index=None):
"""Write the loop output"""
retval = self.runner.write(self.output_list, end_data_index)
logger.debug("User loop wrote output: {}".format(retval))
return retval
@property
def address(self):
""" Address of the message handler"""
@@ -267,3 +282,9 @@ class LoopExecutor(object):
self.message_handler.join()
self.message_handler = None
def close(self):
"""Close all outputs"""
for output in self.output_list:
output.close()
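Taken together, the loop executor now owns a full output life cycle: outputs are built in `__init__`, written on demand via `write()`, and released in `close()`. A sketch of the expected call order, with a hypothetical stand-in for a cache-backed output:

# FakeOutput stands in for a real Output bound to a data sink.
class FakeOutput:
    def __init__(self):
        self.blocks, self.closed = [], False

    def write(self, data, end_data_index=None):
        self.blocks.append((data, end_data_index))

    def close(self):
        self.closed = True

output_list = [FakeOutput()]
output_list[0].write({"value": 42}, end_data_index=0)  # LoopExecutor.write
for output in output_list:                             # LoopExecutor.close
    output.close()
assert output_list[0].blocks and output_list[0].closed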
@@ -47,7 +47,6 @@ import logging
import zmq
import simplejson
import requests
import threading
from ..dataformat import DataFormat
@@ -162,7 +161,7 @@ class MessageHandler(threading.Thread):
self.kill_callback()
self.stop.set()
break
- except RuntimeError as e:
+ except RuntimeError:
import traceback
message = traceback.format_exc()
@@ -173,10 +172,12 @@ class MessageHandler(threading.Thread):
self.kill_callback()
self.stop.set()
break
- except:
+ except Exception:
import traceback
- parser = lambda s: s if len(s) < 20 else s[:20] + "..."
+ def parser(s):
+     return s if len(s) < 20 else s[:20] + "..."
parsed_parts = " ".join([parser(k) for k in parts])
message = (
"A problem occurred while performing command `%s' "
@@ -249,7 +250,7 @@ class MessageHandler(threading.Thread):
try:
data_source = self.data_sources[name]
- except:
+ except Exception:
raise RemoteException("sys", "Unknown input: %s" % name)
logger.debug("send: %d infos", len(data_source))
@@ -275,12 +276,12 @@ class MessageHandler(threading.Thread):
try:
data_source = self.data_sources[name]
- except:
+ except Exception:
raise RemoteException("sys", "Unknown input: %s" % name)
try:
index = int(index)
- except:
+ except Exception:
raise RemoteException("sys", "Invalid index: %s" % index)
(data, start_index, end_index) = data_source[index]
@@ -351,6 +352,7 @@ class LoopMessageHandler(MessageHandler):
)
self.callbacks.update({"val": self.validate})
self.callbacks.update({"wrt": self.write})
self.executor = None
def setup(self, algorithm, prefix):
@@ -401,3 +403,20 @@ class LoopMessageHandler(MessageHandler):
self.socket.send_string("True" if is_valid else "False", zmq.SNDMORE)
self.socket.send(data.pack())
def write(self, end_data_index):
""" Trigger a write on the output"""
try:
end_data_index = int(end_data_index)
except ValueError:
logger.warning("recv: wrt invalid value {}".format(end_data_index))
end_data_index = None
logger.debug("recv: wrt {}".format(end_data_index))
try:
self.executor.write(end_data_index)
except Exception:
raise
finally:
self.socket.send_string("ack")
@@ -58,6 +58,7 @@ from .inputs import InputGroup
from .outputs import SynchronizationListener
from .outputs import OutputList
from .outputs import Output
from .outputs import RemotelySyncedOutput
from .algorithm import Algorithm
logger = logging.getLogger(__name__)
@@ -81,6 +82,12 @@ def parse_inputs(inputs):
return data
def parse_outputs(outputs):
return dict(
[(k, {"channel": v["channel"], "path": v["path"]}) for k, v in outputs.items()]
)
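The new `parse_outputs` mirrors `parse_inputs` and is reused below to deduplicate `convert_experiment_configuration_to_container`. Its effect on a configuration entry (the example configuration is made up):

# parse_outputs keeps only the channel and path of each output entry.
outputs = {"out": {"channel": "main", "path": "3f/2a/71", "extra": True}}
parsed = dict(
    [(k, {"channel": v["channel"], "path": v["path"]}) for k, v in outputs.items()]
)
assert parsed == {"out": {"channel": "main", "path": "3f/2a/71"}}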
def convert_loop_to_container(config):
data = {
"algorithm": config["algorithm"],
@@ -90,7 +97,7 @@ def convert_loop_to_container(config):
}
data["inputs"] = parse_inputs(config["inputs"])
data["outputs"] = parse_outputs(config["outputs"])
return data
@@ -108,12 +115,7 @@ def convert_experiment_configuration_to_container(config):
data["inputs"] = parse_inputs(config["inputs"])
if "outputs" in config:
data["outputs"] = dict(
[
(k, {"channel": v["channel"], "path": v["path"]})
for k, v in config["outputs"].items()
]
)
data["outputs"] = parse_outputs(config["outputs"])
else:
data["result"] = {
"channel": config["channel"],
@@ -379,7 +381,13 @@ def create_inputs_from_configuration(
def create_outputs_from_configuration(
-     config, algorithm, prefix, cache_root, input_list=None, data_loaders=None
+     config,
+     algorithm,
+     prefix,
+     cache_root,
+     input_list=None,
+     data_loaders=None,
+     loop_socket=None,
):
data_sinks = []
@@ -474,14 +482,25 @@ def create_outputs_from_configuration(
if not status:
raise IOError("Cannot create cache sink '%s'" % details["path"])
-             output_list.add(
-                 Output(
-                     name,
-                     data_sink,
-                     synchronization_listener=synchronization_listener,
-                     force_start_index=start_index,
-                 )
-             )
+             if loop_socket is not None:
+                 output_list.add(
+                     RemotelySyncedOutput(
+                         name,
+                         data_sink,
+                         loop_socket,
+                         synchronization_listener=synchronization_listener,
+                         force_start_index=start_index,
+                     )
+                 )
+             else:
+                 output_list.add(
+                     Output(
+                         name,
+                         data_sink,
+                         synchronization_listener=synchronization_listener,
+                         force_start_index=start_index,
+                     )
+                 )
if "result" not in config:
logger.debug(
@@ -45,6 +45,8 @@ This module implements output related classes
import six
import logging
import zmq
logger = logging.getLogger(__name__)
@@ -101,16 +103,16 @@ class Output(object):
"""
- def __init__(self, name, data_sink, synchronization_listener=None,
-              force_start_index=0):
-
-     self.name = str(name)
-     self.last_written_data_index = force_start_index - 1
-     self.nb_data_blocks_written = 0
-     self.data_sink = data_sink
+ def __init__(
+     self, name, data_sink, synchronization_listener=None, force_start_index=0
+ ):
+
+     self.name = str(name)
+     self.last_written_data_index = force_start_index - 1
+     self.nb_data_blocks_written = 0
+     self.data_sink = data_sink
self._synchronization_listener = synchronization_listener
def _createData(self):
"""Retrieves an uninitialized block of data corresponding to the data
format of the output
@@ -118,13 +120,14 @@ class Output(object):
This method must be called to correctly create a new block of data
"""
- if hasattr(self.data_sink, 'dataformat'):
+ if hasattr(self.data_sink, "dataformat"):
return self.data_sink.dataformat.type()
else:
- raise RuntimeError("The currently used data sink is not bound to " \
-     "a dataformat - you cannot create uninitialized data under " \
-     "these circumstances")
+ raise RuntimeError(
+     "The currently used data sink is not bound to "
+     "a dataformat - you cannot create uninitialized data under "
+     "these circumstances"
+ )
def write(self, data, end_data_index=None):
"""Write a block of data on the output
@@ -145,7 +148,7 @@ class Output(object):
# if the user passes a dictionary, converts to the proper baseformat type
if isinstance(data, dict):
d = self.data_sink.dataformat.type()
- d.from_dict(data, casting='safe', add_defaults=False)
+ d.from_dict(data, casting="safe", add_defaults=False)
data = d
self.data_sink.write(data, self.last_written_data_index + 1, end_data_index)
@@ -153,27 +156,29 @@ class Output(object):
self.last_written_data_index = end_data_index
self.nb_data_blocks_written += 1
def isDataMissing(self):
"""Returns whether data are missing"""
- return (self._synchronization_listener is not None) and \
-     (self._synchronization_listener.data_index_end != self.last_written_data_index)
+ return (self._synchronization_listener is not None) and (
+     self._synchronization_listener.data_index_end
+     != self.last_written_data_index
+ )
def isConnected(self):
"""Returns whether the associated data sink is connected"""
return self.data_sink.isConnected()
def _compute_end_data_index(self, end_data_index):
if end_data_index is not None:
- if (end_data_index < self.last_written_data_index + 1) or \
-     ((self._synchronization_listener is not None) and \
-     (end_data_index > self._synchronization_listener.data_index_end)):
-     raise KeyError("Algorithm logic error on write(): `end_data_index' " \
-         "is not consistent with last written index")
+ if (end_data_index < self.last_written_data_index + 1) or (
+     (self._synchronization_listener is not None)
+     and (end_data_index > self._synchronization_listener.data_index_end)
+ ):
+     raise KeyError(
+         "Algorithm logic error on write(): `end_data_index' "
+         "is not consistent with last written index"
+     )
elif self._synchronization_listener is not None:
end_data_index = self._synchronization_listener.data_index_end
@@ -183,14 +188,37 @@ class Output(object):
return end_data_index
def close(self):
"""Closes the associated data sink"""
self.data_sink.close()
- #----------------------------------------------------------
+ # ----------------------------------------------------------
class RemotelySyncedOutput(Output):
def __init__(
self,
name,
data_sink,
socket,
synchronization_listener=None,
force_start_index=0,
):
super(RemotelySyncedOutput, self).__init__(
name, data_sink, synchronization_listener, force_start_index
)
self.socket = socket
def write(self, data, end_data_index=None):
super(RemotelySyncedOutput, self).write(data, end_data_index)
self.socket.send_string("wrt", zmq.SNDMORE)
self.socket.send_string("{}".format(end_data_index))
self.socket.recv()
# ----------------------------------------------------------
class OutputList:
@@ -223,7 +251,6 @@ class OutputList:
def __init__(self):
self._outputs = []
def __getitem__(self, index):
if isinstance(index, six.string_types):
@@ -232,18 +259,17 @@ class OutputList:
except IndexError:
pass
elif isinstance(index, int):
- if index < len(self._outputs): return self._outputs[index]
+ if index < len(self._outputs):
+     return self._outputs[index]
return None
def __iter__(self):
- for k in self._outputs: yield k
+ for k in self._outputs:
+     yield k
def __len__(self):
return len(self._outputs)
def add(self, output):
"""Adds an output to the list
@@ -37,7 +37,7 @@
"""Executes a single algorithm. (%(version)s)
usage:
- %(prog)s [--debug] [--cache=<path>] <addr> <dir> [<db_addr>] [<loop_addr>]
+ %(prog)s [--debug] [--cache=<path>] [--loop=<loop_addr>] <addr> <dir> [<db_addr>]
%(prog)s (--help)
%(prog)s (--version)
@@ -47,14 +47,13 @@ arguments:
<dir> Directory containing all configuration required to run the user
algorithm
<db_addr> Address for databases-related I/O requests
- <loop_addr> Address for loop-related I/O requests
options:
-h, --help Shows this help message and exit
-V, --version Shows program's version number and exit
-d, --debug Runs executor in debugging mode
- -c, --cache=<path> Cache prefix, otherwise defaults to '/cache'
+ -c, --cache=path Cache prefix [default: /cache].
+ --loop=loop_addr Address for loop-related I/O requests
"""
@@ -63,10 +62,8 @@ import logging
import os
import sys