Skip to content
Snippets Groups Projects
Commit 996e6904 authored by Samuel GAIST's avatar Samuel GAIST
Browse files

[loop_executor] Ran black

parent 7767894f
No related branches found
No related tags found
2 merge requests!281.6.x,!27Soft loop
...@@ -26,13 +26,13 @@ ...@@ -26,13 +26,13 @@
############################################################################### ###############################################################################
''' """
======== ========
executor executor
======== ========
A class that can setup and execute loop algorithm blocks on the backend A class that can setup and execute loop algorithm blocks on the backend
''' """
import logging import logging
import os import os
...@@ -55,7 +55,7 @@ logger = logging.getLogger(__name__) ...@@ -55,7 +55,7 @@ logger = logging.getLogger(__name__)
def make_data_format(data, data_format): def make_data_format(data, data_format):
if isinstance(data, dict): if isinstance(data, dict):
d = data_format.type() d = data_format.type()
d.from_dict(data, casting='safe', add_defaults=False) d.from_dict(data, casting="safe", add_defaults=False)
data = d data = d
else: else:
raise TypeError("Can't transform {} in dataformat".format(data.__class__)) raise TypeError("Can't transform {} in dataformat".format(data.__class__))
...@@ -64,27 +64,25 @@ def make_data_format(data, data_format): ...@@ -64,27 +64,25 @@ def make_data_format(data, data_format):
class LoopChannel(object): class LoopChannel(object):
def __init__(self, socket): def __init__(self, socket):
self.socket = socket self.socket = socket
def setup(self, configuration, algorithm, prefix): def setup(self, configuration, algorithm, prefix):
input_format_name = algorithm.loop_map['input'] input_format_name = algorithm.loop_map["input"]
self.input_data_format = DataFormat(prefix, input_format_name) self.input_data_format = DataFormat(prefix, input_format_name)
ouput_format_name = algorithm.loop_map['output'] ouput_format_name = algorithm.loop_map["output"]
self.output_data_format = DataFormat(prefix, ouput_format_name) self.output_data_format = DataFormat(prefix, ouput_format_name)
def is_result_valid(self, result): def is_result_valid(self, result):
data = make_data_format(result, self.output_data_format) data = make_data_format(result, self.output_data_format)
self.socket.send_string('val', zmq.SNDMORE) self.socket.send_string("val", zmq.SNDMORE)
self.socket.send(data.pack()) self.socket.send(data.pack())
answer = self.socket.recv() answer = self.socket.recv()
if answer == b'err': if answer == b"err":
kind = self.socket.recv() kind = self.socket.recv()
message = self.socket.recv() message = self.socket.recv()
raise RemoteException(kind, message) raise RemoteException(kind, message)
...@@ -94,34 +92,35 @@ class LoopChannel(object): ...@@ -94,34 +92,35 @@ class LoopChannel(object):
data = self.input_data_format.type() data = self.input_data_format.type()
retval = data.unpack(packed) retval = data.unpack(packed)
return answer == 'True', retval return answer == "True", retval
class LoopMessageHandler(MessageHandler): class LoopMessageHandler(MessageHandler):
def __init__(
self, host_address, data_sources=None, kill_callback=None, context=None
):
super(LoopMessageHandler, self).__init__(
host_address, data_sources, kill_callback, context
)
def __init__(self, host_address, data_sources=None, kill_callback=None, context=None): self.callbacks.update({"val": self.validate})
super(LoopMessageHandler, self).__init__(host_address, data_sources, kill_callback, context)
self.callbacks.update({'val': self.validate })
self.executor = None self.executor = None
def setup(self, configuration, algorithm, prefix): def setup(self, configuration, algorithm, prefix):
input_format_name = algorithm.loop_map['input'] input_format_name = algorithm.loop_map["input"]
self.input_data_format = DataFormat(prefix, input_format_name) self.input_data_format = DataFormat(prefix, input_format_name)
ouput_format_name = algorithm.loop_map['output'] ouput_format_name = algorithm.loop_map["output"]
self.output_data_format = DataFormat(prefix, ouput_format_name) self.output_data_format = DataFormat(prefix, ouput_format_name)
def set_executor(self, executor): def set_executor(self, executor):
self.executor = executor self.executor = executor
def validate(self, result): def validate(self, result):
"""Syntax: val""" """Syntax: val"""
result = result.encode('utf-8') result = result.encode("utf-8")
logger.debug('recv: val %s', result) logger.debug("recv: val %s", result)
data = self.input_data_format.type() data = self.input_data_format.type()
data.unpack(result) data.unpack(result)
...@@ -166,9 +165,16 @@ class Executor(object): ...@@ -166,9 +165,16 @@ class Executor(object):
guarantee that the cache is refreshed as appropriate in case the guarantee that the cache is refreshed as appropriate in case the
underlying libraries change. """ underlying libraries change. """
def __init__(self, message_handler, directory, dataformat_cache=None, def __init__(
database_cache=None, library_cache=None, cache_root='/cache', self,
db_socket=None): message_handler,
directory,
dataformat_cache=None,
database_cache=None,
library_cache=None,
cache_root="/cache",
db_socket=None,
):
self._runner = None self._runner = None
self.algorithm = None self.algorithm = None
...@@ -176,12 +182,12 @@ class Executor(object): ...@@ -176,12 +182,12 @@ class Executor(object):
self.db_socket = db_socket self.db_socket = db_socket
self.configuration = os.path.join(directory, 'configuration.json') self.configuration = os.path.join(directory, "configuration.json")
with open(self.configuration, 'rb') as f: with open(self.configuration, "rb") as f:
conf_data = f.read().decode('utf-8') conf_data = f.read().decode("utf-8")
self.data = json.loads(conf_data) self.data = json.loads(conf_data)
self.prefix = os.path.join(directory, 'prefix') self.prefix = os.path.join(directory, "prefix")
# Temporary caches, if the user has not set them, for performance # Temporary caches, if the user has not set them, for performance
database_cache = database_cache if database_cache is not None else {} database_cache = database_cache if database_cache is not None else {}
...@@ -189,20 +195,24 @@ class Executor(object): ...@@ -189,20 +195,24 @@ class Executor(object):
library_cache = library_cache if library_cache is not None else {} library_cache = library_cache if library_cache is not None else {}
# Load the algorithm # Load the algorithm
self.algorithm = Algorithm(self.prefix, self.data['algorithm'], self.algorithm = Algorithm(
dataformat_cache, library_cache) self.prefix, self.data["algorithm"], dataformat_cache, library_cache
)
self.input_list, self.data_loaders = create_inputs_from_configuration( self.input_list, self.data_loaders = create_inputs_from_configuration(
self.data, self.algorithm, self.prefix, cache_root, self.data,
cache_access=AccessMode.LOCAL, db_access=AccessMode.REMOTE, self.algorithm,
socket=self.db_socket self.prefix,
cache_root,
cache_access=AccessMode.LOCAL,
db_access=AccessMode.REMOTE,
socket=self.db_socket,
) )
self.message_handler = message_handler self.message_handler = message_handler
self.message_handler.setup(self.data, self.algorithm, self.prefix) self.message_handler.setup(self.data, self.algorithm, self.prefix)
self.message_handler.set_executor(self) self.message_handler.set_executor(self)
@property @property
def runner(self): def runner(self):
"""Returns the algorithm runner """Returns the algorithm runner
...@@ -214,15 +224,13 @@ class Executor(object): ...@@ -214,15 +224,13 @@ class Executor(object):
self._runner = self.algorithm.runner() self._runner = self.algorithm.runner()
return self._runner return self._runner
def setup(self): def setup(self):
"""Sets up the algorithm to start processing""" """Sets up the algorithm to start processing"""
retval = self.runner.setup(self.data['parameters']) retval = self.runner.setup(self.data["parameters"])
logger.debug("User loop is setup: {}".format(retval)) logger.debug("User loop is setup: {}".format(retval))
return retval return retval
def prepare(self): def prepare(self):
"""Prepare the algorithm""" """Prepare the algorithm"""
...@@ -230,14 +238,12 @@ class Executor(object): ...@@ -230,14 +238,12 @@ class Executor(object):
logger.debug("User loop is prepared: {}".format(retval)) logger.debug("User loop is prepared: {}".format(retval))
return retval return retval
def process(self): def process(self):
"""Executes the user algorithm code using the current interpreter. """Executes the user algorithm code using the current interpreter.
""" """
self.message_handler.start() self.message_handler.start()
def validate(self, result): def validate(self, result):
"""Executes the loop validation code""" """Executes the loop validation code"""
...@@ -245,7 +251,6 @@ class Executor(object): ...@@ -245,7 +251,6 @@ class Executor(object):
logger.debug("User loop has validated: {}".format(is_valid, retval)) logger.debug("User loop has validated: {}".format(is_valid, retval))
return is_valid, retval return is_valid, retval
@property @property
def address(self): def address(self):
""" Address of the message handler""" """ Address of the message handler"""
...@@ -257,7 +262,6 @@ class Executor(object): ...@@ -257,7 +262,6 @@ class Executor(object):
"""A boolean that indicates if this executor is valid or not""" """A boolean that indicates if this executor is valid or not"""
return not bool(self.errors) return not bool(self.errors)
def wait(self): def wait(self):
"""Wait for the message handle to finish""" """Wait for the message handle to finish"""
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment