Commit ca1dc069 authored by Philip ABBET's avatar Philip ABBET

Refactoring: Move some classes and functions into beat.backend.python

parent c023dfa4
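In short, the pieces removed from beat.core below are now pulled in from beat.backend.python; the replacement imports appearing in this diff boil down to:

from beat.backend.python.message_handler import MessageHandler
from beat.backend.python.database import Storage, View
from beat.backend.python.database import Database as BackendDatabase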
......@@ -44,251 +44,7 @@ from . import utils
from . import dock
from . import baseformat
class MessageHandler(gevent.Greenlet):
'''A 0MQ message handler for our communication with other processes
Support for more messages can be implemented by subclassing this class.
This one only supports input-related messages.
'''
def __init__(self, input_list, zmq_context, zmq_socket):
super(MessageHandler, self).__init__()
# An event unblocking a graceful stop
self.stop = gevent.event.Event()
self.stop.clear()
self.must_kill = gevent.event.Event()
self.must_kill.clear()
# Starts our 0MQ server
self.context = zmq_context
self.socket = zmq_socket
self.poller = zmq.Poller()
self.poller.register(self.socket, zmq.POLLIN)
self.input_list = input_list
self.system_error = ''
self.user_error = ''
self.last_statistics = {}
self.process = None
# implementations
self.callbacks = dict(
nxt = self.next,
hmd = self.has_more_data,
idd = self.is_dataunit_done,
don = self.done,
err = self.error,
)
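# Illustrative sketch (assumption, not part of this commit): a client following
# the same convention would send one of these multipart requests over the
# paired 0MQ socket, with the three-letter command as the first frame:
#
#   socket.send('nxt', zmq.SNDMORE)
#   socket.send('channel', zmq.SNDMORE)
#   socket.send('input')              # reply: one packed data block
#
#   socket.send('hmd', zmq.SNDMORE)
#   socket.send('channel')            # reply: 'tru' or 'fal'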
def set_process(self, process):
self.process = process
self.process.statistics() # initialize internal statistics
def _run(self):
logger.debug("0MQ server thread started")
while not self.stop.is_set(): #keep on
if self.must_kill.is_set():
if self.process is not None:
self.process.kill()
self.must_kill.clear()
timeout = 1000 #ms
socks = dict(self.poller.poll(timeout)) #yields to the next greenlet
if self.socket in socks and socks[self.socket] == zmq.POLLIN:
# incoming message
more = True
parts = []
while more:
parts.append(self.socket.recv())
more = self.socket.getsockopt(zmq.RCVMORE)
command = parts[0]
logger.debug("recv: %s", command)
if command in self.callbacks:
try: #to handle command
self.callbacks[command](*parts[1:])
except:
import traceback
parser = lambda s: s if len(s)<20 else s[:20] + '...'
parsed_parts = ' '.join([parser(k) for k in parts])
message = "A problem occurred while performing command `%s' " \
"killing user process. Exception:\n %s" % \
(parsed_parts, traceback.format_exc())
logger.error(message, exc_info=True)
self.system_error = message
if self.process is not None:
self.process.kill()
self.stop.set()
break
else:
message = "Command `%s' is not implemented - stopping user process" \
% command
logger.error(message)
self.system_error = message
if self.process is not None:
self.process.kill()
self.stop.set()
break
self.socket.setsockopt(zmq.LINGER, 0)
self.socket.close()
logger.debug("0MQ server thread stopped")
def _get_input_candidate(self, channel, name):
channel_group = self.input_list.group(channel)
retval = channel_group[name]
if retval is None:
raise RuntimeError("Could not find input `%s' at channel `%s'" % \
(name, channel))
return retval
def next(self, channel, name=None):
"""Syntax: nxt channel [name] ..."""
if name is not None: #single input
logger.debug('recv: nxt %s %s', channel, name)
input_candidate = self._get_input_candidate(channel, name)
input_candidate.next()
if input_candidate.data is None: #error
message = "User algorithm asked for more data for channel " \
"`%s' on input `%s', but it is over (no more data). This " \
"normally indicates a programming error on the user " \
"side." % (channel, name)
self.user_error += message + '\n'
raise RuntimeError(message)
if isinstance(input_candidate.data, baseformat.baseformat):
packed = input_candidate.data.pack()
else:
packed = input_candidate.data
logger.debug('send: <bin> (size=%d)', len(packed))
self.socket.send(packed)
else: #whole group data
logger.debug('recv: nxt %s', channel)
channel_group = self.input_list.group(channel)
# Call next() on the group
channel_group.restricted_access = False
channel_group.next()
channel_group.restricted_access = True
# Loop over the inputs
inputs_to_go = len(channel_group)
self.socket.send(str(inputs_to_go), zmq.SNDMORE)
for inp in channel_group:
logger.debug('send: %s', inp.name)
self.socket.send(str(inp.name), zmq.SNDMORE)
if inp.data is None:
message = "User algorithm process asked for more data on channel " \
"`%s' (all inputs), but input `%s' has nothing. This " \
"normally indicates a programming error on the user " \
"side." % (channel, inp.name)
self.user_error += message + '\n'
raise RuntimeError(message)
elif isinstance(inp.data, baseformat.baseformat):
packed = inp.data.pack()
else:
packed = inp.data
logger.debug('send: <bin> (size=%d)', len(packed))
inputs_to_go -= 1
if inputs_to_go > 0:
self.socket.send(packed, zmq.SNDMORE)
else:
self.socket.send(packed)
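# Illustrative sketch (assumption, not part of this commit): a client reading
# the whole-group reply built above would see the input count followed by
# alternating (name, packed data) frames:
#
#   parts = socket.recv_multipart()   # [count, name1, data1, name2, data2, ...]
#   count = int(parts[0])
#   pairs = zip(parts[1::2], parts[2::2])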
def has_more_data(self, channel, name=None):
"""Syntax: hmd channel [name]"""
if name: #single input
logger.debug('recv: hmd %s %s', channel, name)
input_candidate = self._get_input_candidate(channel, name)
what = 'tru' if input_candidate.hasMoreData() else 'fal'
else: #for all channel names
logger.debug('recv: hmd %s', channel)
channel_group = self.input_list.group(channel)
what = 'tru' if channel_group.hasMoreData() else 'fal'
logger.debug('send: %s', what)
self.socket.send(what)
def is_dataunit_done(self, channel, name):
"""Syntax: idd channel name"""
logger.debug('recv: idd %s %s', channel, name)
input_candidate = self._get_input_candidate(channel, name)
what = 'tru' if input_candidate.isDataUnitDone() else 'fal'
logger.debug('send: %s', what)
self.socket.send(what)
def _collect_statistics(self):
logger.debug('collecting user process statistics...')
if self.process is not None:
self.last_statistics = self.process.statistics()
def _acknowledge(self):
logger.debug('send: ack')
self.socket.send('ack')
logger.debug('setting stop condition for 0MQ server thread')
self.stop.set()
def done(self, wait_time=None):
"""Syntax: don"""
logger.debug('recv: don %s', wait_time)
if wait_time is not None:
self._collect_statistics()
# collect I/O stats from client
wait_time = float(wait_time)
self.last_statistics['data'] = dict(network=dict(wait_time=wait_time))
self._acknowledge()
def error(self, t, msg):
"""Syntax: err type message"""
logger.debug('recv: err %s <msg> (size=%d)', t, len(msg))
if t == 'usr': self.user_error = msg
else: self.system_error = msg
self._collect_statistics()
self.last_statistics['data'] = dict(network=dict(wait_time=0.))
self._acknowledge()
def kill(self):
self.must_kill.set()
from beat.backend.python.message_handler import MessageHandler
class Server(MessageHandler):
......@@ -398,7 +154,9 @@ class Agent(object):
self.virtual_memory_in_megabytes = virtual_memory_in_megabytes
self.max_cpu_percent = max_cpu_percent
self.tempdir = None
self.db_tempdir = None
self.process = None
self.db_process = None
self.server = None
......@@ -410,7 +168,10 @@ class Agent(object):
# Creates a temporary directory for the user process
self.tempdir = utils.temporary_directory()
logger.debug("Created temporary directory `%s'", self.tempdir)
self.db_tempdir = utils.temporary_directory()
logger.debug("Created temporary directory `%s'", self.db_tempdir)
self.process = None
self.db_process = None
return self
......@@ -420,11 +181,16 @@ class Agent(object):
shutil.rmtree(self.tempdir)
self.tempdir = None
if self.db_tempdir is not None and os.path.exists(self.db_tempdir):
shutil.rmtree(self.db_tempdir)
self.db_tempdir = None
self.process = None
self.db_process = None
logger.debug("Exiting processing context...")
def run(self, configuration, host, timeout_in_minutes=0, daemon=0):
def run(self, configuration, host, timeout_in_minutes=0, daemon=0, db_address=None):
"""Runs the algorithm code
......@@ -452,6 +218,9 @@ class Agent(object):
# Recursively copies configuration data to <tempdir>/prefix
configuration.dump_runner_configuration(self.tempdir)
if db_address is not None:
configuration.dump_databases_provider_configuration(self.db_tempdir)
# Server for our single client
self.server = Server(configuration.input_list, configuration.output_list,
host.ip)
......@@ -475,14 +244,27 @@ class Agent(object):
cmd = ['sleep', str(daemon)]
logger.debug("Daemon mode: sleeping for %d seconds", daemon)
else:
if db_address is not None:
tmp_dir = os.path.join('/tmp', os.path.basename(self.db_tempdir))
# db_cmd = ['bash', '-c', 'source activate beat_env; databases_provider %s %s' % (db_address, tmp_dir)]
db_cmd = ['bash', '-c', 'databases_provider %s %s' % (db_address, tmp_dir)]
self.db_process = dock.Popen(
host,
'docker.idiap.ch/beat/beat.env.db.examples:1.0.0',
# 'docker.idiap.ch/beat/beat.env.db:1.0.0',
command=db_cmd,
tmp_archive=self.db_tempdir,
)
self.process = dock.Popen(
host,
envkey,
command=cmd,
tmp_archive=self.tempdir,
virtual_memory_in_megabytes=self.virtual_memory_in_megabytes,
max_cpu_percent=self.max_cpu_percent,
)
# provide a tip on how to stop the test
if daemon > 0:
......@@ -504,6 +286,11 @@ class Agent(object):
timeout_in_minutes)
self.process.kill()
status = self.process.wait()
if self.db_process is not None:
self.db_process.kill()
self.db_process.wait()
timed_out = True
except KeyboardInterrupt: #developer pushed CTRL-C
......@@ -511,6 +298,10 @@ class Agent(object):
self.process.kill()
status = self.process.wait()
if self.db_process is not None:
self.db_process.kill()
self.db_process.wait()
finally:
self.server.stop.set()
......@@ -527,6 +318,13 @@ class Agent(object):
user_error = self.server.user_error,
)
process.rm()
if self.db_process is not None:
retval['stdout'] += '\n' + self.db_process.stdout
retval['stderr'] += '\n' + self.db_process.stderr
self.db_process.rm()
self.db_process = None
self.server = None
return retval
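# Illustrative sketch (assumptions: `agent` is a configured Agent instance,
# `configuration` and `host` are the usual execution objects, and the
# 'tcp://...' value is a made-up placeholder for the databases provider
# endpoint):
#
#   result = agent.run(configuration, host, timeout_in_minutes=5,
#                      db_address='tcp://127.0.0.1:5555')
#   print(result['stdout'])
#   print(result['user_error'])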
......
......@@ -43,162 +43,13 @@ from . import hash
from . import utils
from . import prototypes
class Storage(utils.CodeStorage):
"""Resolves paths for databases
Parameters:
prefix (str): Establishes the prefix of your installation.
name (str): The name of the database object in the format
``<name>/<version>``.
"""
def __init__(self, prefix, name):
if name.count('/') != 1:
raise RuntimeError("invalid database name: `%s'" % name)
self.name, self.version = name.split('/')
self.fullname = name
path = os.path.join(prefix, 'databases', name)
super(Storage, self).__init__(path, 'python') #views are coded in Python
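# Illustrative sketch (assuming an installation prefix of '/prefix' and a
# database named 'mnist/1'): the view code is resolved under
# '/prefix/databases/mnist/1' and stored as Python:
#
#   storage = Storage('/prefix', 'mnist/1')
#   storage.name, storage.version    # ('mnist', '1')
#   storage.fullname                 # 'mnist/1'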
class View(object):
'''A special loader class for database views, with specialized methods
Parameters:
db_name (str): The full name of the database object for this view
module (module): The preloaded module containing the database views as
returned by :py:func:`beat.core.loader.load_module`.
prefix (str, path): The prefix path for the current installation
root_folder (str, path): The path pointing to the root folder of this
database
exc (class): The class to use as base exception when translating the
exception from the user code. Read the documention of :py:func:`run`
for more details.
*args: Constructor parameters for the database view. Normally, none.
**kwargs: Constructor parameters for the database view. Normally, none.
'''
def __init__(self, module, definition, prefix, root_folder, exc=None,
*args, **kwargs):
try:
class_ = getattr(module, definition['view'])
except Exception as e:
if exc is not None:
type, value, traceback = sys.exc_info()
six.reraise(exc, exc(value), traceback)
else:
raise #just re-raise the user exception
self.obj = loader.run(class_, '__new__', exc, *args, **kwargs)
self.ready = False
self.prefix = prefix
self.root_folder = root_folder
self.definition = definition
self.exc = exc or RuntimeError
self.outputs = None
def prepare_outputs(self):
'''Prepares the outputs of the dataset'''
from .outputs import Output, OutputList
from .data import MemoryDataSink
from .dataformat import DataFormat
# create the stock outputs for this dataset, so data is dumped
# on an in-memory sink
self.outputs = OutputList()
for out_name, out_format in self.definition.get('outputs', {}).items():
data_sink = MemoryDataSink()
data_sink.dataformat = DataFormat(self.prefix, out_format)
data_sink.setup([])
self.outputs.add(Output(out_name, data_sink, dataset_output=True))
def setup(self, *args, **kwargs):
'''Sets up the view'''
kwargs.setdefault('root_folder', self.root_folder)
kwargs.setdefault('parameters', self.definition.get('parameters', {}))
if 'outputs' not in kwargs:
kwargs['outputs'] = self.outputs
else:
self.outputs = kwargs['outputs'] #record outputs nevertheless
self.ready = loader.run(self.obj, 'setup', self.exc, *args, **kwargs)
if not self.ready:
raise self.exc("unknow setup failure")
return self.ready
def input_group(self, name='default', exclude_outputs=[]):
'''A memory-source input group matching the outputs from the view'''
if not self.ready:
raise self.exc("database view not yet setup")
from .data import MemoryDataSource
from .outputs import SynchronizationListener
from .inputs import Input, InputGroup
# Setup the inputs
synchronization_listener = SynchronizationListener()
input_group = InputGroup(name,
synchronization_listener=synchronization_listener,
restricted_access=False)
for output in self.outputs:
if output.name in exclude_outputs: continue
data_source = MemoryDataSource(self.done, next_callback=self.next)
output.data_sink.data_sources.append(data_source)
input_group.add(Input(output.name,
output.data_sink.dataformat, data_source))
return input_group
def done(self, *args, **kwargs):
'''Checks if the view is done'''
if not self.ready:
raise self.exc("database view not yet setup")
return loader.run(self.obj, 'done', self.exc, *args, **kwargs)
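# Illustrative lifecycle sketch (assumption: `module` and `definition` come
# from a loaded database declaration, as in the Database class below):
#
#   view = View(module, definition, prefix, root_folder)
#   view.prepare_outputs()              # one in-memory sink per declared output
#   view.setup()                        # user code reports readiness
#   group = view.input_group('default')
#   while group.hasMoreData():
#       group.next()                    # pulls the next chunk through the sinks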
from beat.backend.python.database import Storage
from beat.backend.python.database import View
from beat.backend.python.database import Database as BackendDatabase
def next(self, *args, **kwargs):
'''Runs through the next data chunk'''
if not self.ready:
raise self.exc("database view not yet setup")
return loader.run(self.obj, 'next', self.exc, *args, **kwargs)
def __getattr__(self, key):
'''Returns an attribute of the view - only called at last resort'''
return getattr(self.obj, key)
class Database(object):
class Database(BackendDatabase):
"""Databases define the start point of the dataflow in an experiment.
......@@ -240,20 +91,8 @@ class Database(object):
"""
def __init__(self, prefix, data, dataformat_cache=None):
super(Database, self).__init__(prefix, data, dataformat_cache)
self._name = None
self.storage = None
self.prefix = prefix
self.dataformats = {} # preloaded dataformats
self.errors = []
self.data = None
self.code = None
# if the user has not provided a cache, still use one for performance
dataformat_cache = dataformat_cache if dataformat_cache is not None else {}
self._load(data, dataformat_cache)
def _load(self, data, dataformat_cache):
"""Loads the database"""
......@@ -353,111 +192,20 @@ class Database(object):
"unsupported by this version" % (_set['view'],)
)
@property
def name(self):
"""Returns the name of this object
"""
return self._name or '__unnamed_database__'
@name.setter
def name(self, value):
self._name = value
self.storage = Storage(self.prefix, value)
@property
def schema_version(self):
"""Returns the schema version"""
return self.data.get('schema_version', 1)
@property
def protocols(self):
"""The declaration of all the protocols of the database"""
data = self.data['protocols']
return dict(zip([k['name'] for k in data], data))
def protocol(self, name):
"""The declaration of a specific protocol in the database"""
return self.protocols[name]
@property
def protocol_names(self):
"""Names of protocols declared for this database"""
data = self.data['protocols']
return [k['name'] for k in data]
def sets(self, protocol):
"""The declaration of a specific set in the database protocol"""
data = self.protocol(protocol)['sets']
return dict(zip([k['name'] for k in data], data))
def set(self, protocol, name):
"""The declaration of all the protocols of the database"""
return self.sets(protocol)[name]
def set_names(self, protocol):
"""The names of sets in a given protocol for this database"""
data = self.protocol(protocol)['sets']
return [k['name'] for k in data]
@property
def valid(self):
return not bool(self.errors)
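# Illustrative sketch (assumption: `db` is a loaded Database declaring a
# protocol named 'idiap' with a set named 'training'):
#
#   db.protocol_names               # e.g. ['idiap']
#   db.protocol('idiap')            # the full protocol declaration
#   db.sets('idiap')                # dict: set name -> set declaration
#   db.set('idiap', 'training')     # one specific set declaration
#   db.valid                        # True when no errors were recorded at load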