Commit 14bb7f2d authored by Philip ABBET's avatar Philip ABBET

Refactoring: the 'MessageHandler' class is now part of this package

parent 4e46c864
#!/usr/bin/env python
# vim: set fileencoding=utf-8 :
###############################################################################
# #
# Copyright (c) 2016 Idiap Research Institute, http://www.idiap.ch/ #
# Contact: beat.support@idiap.ch #
# #
# This file is part of the beat.core module of the BEAT platform. #
# #
# Commercial License Usage #
# Licensees holding valid commercial BEAT licenses may use this file in #
# accordance with the terms contained in a written agreement between you #
# and Idiap. For further information contact tto@idiap.ch #
# #
# Alternatively, this file may be used under the terms of the GNU Affero #
# Public License version 3 as published by the Free Software and appearing #
# in the file LICENSE.AGPL included in the packaging of this file. #
# The BEAT platform is distributed in the hope that it will be useful, but #
# WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY #
# or FITNESS FOR A PARTICULAR PURPOSE. #
# #
# You should have received a copy of the GNU Affero Public License along #
# with the BEAT platform. If not, see http://www.gnu.org/licenses/. #
# #
###############################################################################
import logging
logger = logging.getLogger(__name__)
import gevent
import zmq.green as zmq
import requests
from gevent import monkey
monkey.patch_socket(dns=False)
monkey.patch_ssl()
from . import baseformat
class MessageHandler(gevent.Greenlet):
'''A 0MQ message handler for our communication with other processes
Support for more messages can be implemented by subclassing this class.
This one only support input-related messages.
'''
def __init__(self, input_list, zmq_context, zmq_socket):
super(MessageHandler, self).__init__()
# An event unblocking a graceful stop
self.stop = gevent.event.Event()
self.stop.clear()
self.must_kill = gevent.event.Event()
self.must_kill.clear()
# Starts our 0MQ server
self.context = zmq_context
self.socket = zmq_socket
self.poller = zmq.Poller()
self.poller.register(self.socket, zmq.POLLIN)
self.input_list = input_list
self.system_error = ''
self.user_error = ''
self.last_statistics = {}
self.process = None
# implementations
self.callbacks = dict(
nxt = self.next,
hmd = self.has_more_data,
idd = self.is_dataunit_done,
don = self.done,
err = self.error,
)
def set_process(self, process):
self.process = process
self.process.statistics() # initialize internal statistics
def _run(self):
logger.debug("0MQ server thread started")
while not self.stop.is_set(): #keep on
if self.must_kill.is_set():
if self.process is not None:
self.process.kill()
self.must_kill.clear()
timeout = 1000 #ms
socks = dict(self.poller.poll(timeout)) #yields to the next greenlet
if self.socket in socks and socks[self.socket] == zmq.POLLIN:
# incomming
more = True
parts = []
while more:
parts.append(self.socket.recv())
more = self.socket.getsockopt(zmq.RCVMORE)
command = parts[0]
logger.debug("recv: %s", command)
if command in self.callbacks:
try: #to handle command
self.callbacks[command](*parts[1:])
except:
import traceback
parser = lambda s: s if len(s)<20 else s[:20] + '...'
parsed_parts = ' '.join([parser(k) for k in parts])
message = "A problem occurred while performing command `%s' " \
"killing user process. Exception:\n %s" % \
(parsed_parts, traceback.format_exc())
logger.error(message, exc_info=True)
self.system_error = message
if self.process is not None:
self.process.kill()
self.stop.set()
break
else:
message = "Command `%s' is not implemented - stopping user process" \
% command
logger.error(message)
self.system_error = message
if self.process is not None:
self.process.kill()
self.stop.set()
break
self.socket.setsockopt(zmq.LINGER, 0)
self.socket.close()
logger.debug("0MQ server thread stopped")
def _get_input_candidate(self, channel, name):
channel_group = self.input_list.group(channel)
retval = channel_group[name]
if retval is None:
raise RuntimeError("Could not find input `%s' at channel `%s'" % \
(name, channel))
return retval
def next(self, channel, name=None):
"""Syntax: nxt channel [name] ..."""
if name is not None: #single input
logger.debug('recv: nxt %s %s', channel, name)
input_candidate = self._get_input_candidate(channel, name)
input_candidate.next()
if input_candidate.data is None: #error
message = "User algorithm asked for more data for channel " \
"`%s' on input `%s', but it is over (no more data). This " \
"normally indicates a programming error on the user " \
"side." % (channel, name)
self.user_error += message + '\n'
raise RuntimeError(message)
if isinstance(input_candidate.data, baseformat.baseformat):
packed = input_candidate.data.pack()
else:
packed = input_candidate.data
logger.debug('send: <bin> (size=%d), indexes=(%d, %d)', len(packed),
input_candidate.data_index, input_candidate.data_index_end)
self.socket.send('%d' % input_candidate.data_index, zmq.SNDMORE)
self.socket.send('%d' % input_candidate.data_index_end, zmq.SNDMORE)
self.socket.send(packed)
else: #whole group data
logger.debug('recv: nxt %s', channel)
channel_group = self.input_list.group(channel)
# Call next() on the group
channel_group.restricted_access = False
channel_group.next()
channel_group.restricted_access = True
# Loop over the inputs
inputs_to_go = len(channel_group)
self.socket.send(str(inputs_to_go), zmq.SNDMORE)
for inp in channel_group:
logger.debug('send: %s', inp.name)
self.socket.send(str(inp.name), zmq.SNDMORE)
if inp.data is None:
message = "User algorithm process asked for more data on channel " \
"`%s' (all inputs), but input `%s' has nothing. This " \
"normally indicates a programming error on the user " \
"side." % (channel, inp.name)
self.user_error += message + '\n'
raise RuntimeError(message)
elif isinstance(inp.data, baseformat.baseformat):
packed = inp.data.pack()
else:
packed = inp.data
logger.debug('send: <bin> (size=%d), indexes=(%d, %d)', len(packed),
inp.data_index, inp.data_index_end)
self.socket.send('%d' % inp.data_index, zmq.SNDMORE)
self.socket.send('%d' % inp.data_index_end, zmq.SNDMORE)
inputs_to_go -= 1
if inputs_to_go > 0:
self.socket.send(packed, zmq.SNDMORE)
else:
self.socket.send(packed)
def has_more_data(self, channel, name=None):
"""Syntax: hmd channel [name]"""
if name: #single input
logger.debug('recv: hmd %s %s', channel, name)
input_candidate = self._get_input_candidate(channel, name)
what = 'tru' if input_candidate.hasMoreData() else 'fal'
else: #for all channel names
logger.debug('recv: hmd %s', channel)
channel_group = self.input_list.group(channel)
what = 'tru' if channel_group.hasMoreData() else 'fal'
logger.debug('send: %s', what)
self.socket.send(what)
def is_dataunit_done(self, channel, name):
"""Syntax: idd channel name"""
logger.debug('recv: idd %s %s', channel, name)
input_candidate = self._get_input_candidate(channel, name)
what = 'tru' if input_candidate.isDataUnitDone() else 'fal'
logger.debug('send: %s', what)
self.socket.send(what)
def _collect_statistics(self):
logger.debug('collecting user process statistics...')
if self.process is not None:
self.last_statistics = self.process.statistics()
def _acknowledge(self):
logger.debug('send: ack')
self.socket.send('ack')
logger.debug('setting stop condition for 0MQ server thread')
self.stop.set()
def done(self, wait_time=None):
"""Syntax: don"""
logger.debug('recv: don %s', wait_time)
if wait_time is not None:
self._collect_statistics()
# collect I/O stats from client
wait_time = float(wait_time)
self.last_statistics['data'] = dict(network=dict(wait_time=wait_time))
self._acknowledge()
def error(self, t, msg):
"""Syntax: err type message"""
logger.debug('recv: err %s <msg> (size=%d)', t, len(msg))
if t == 'usr': self.user_error = msg
else: self.system_error = msg
self._collect_statistics()
self.last_statistics['data'] = dict(network=dict(wait_time=0.))
self._acknowledge()
def kill(self):
self.must_kill.set()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment