From 33a1ce0176e989b5a9e0577dc25387f4148d1d4b Mon Sep 17 00:00:00 2001
From: Philip ABBET <philip.abbet@idiap.ch>
Date: Tue, 16 May 2017 11:23:58 +0200
Subject: [PATCH] Better error handling for the remote inputs

---
 beat/backend/python/dbexecution.py     |  0
 beat/backend/python/inputs.py          | 62 +++++++++++++++++++++++++-
 beat/backend/python/message_handler.py | 33 +++++++++++++-
 3 files changed, 92 insertions(+), 3 deletions(-)
 mode change 100644 => 100755 beat/backend/python/dbexecution.py

diff --git a/beat/backend/python/dbexecution.py b/beat/backend/python/dbexecution.py
old mode 100644
new mode 100755
diff --git a/beat/backend/python/inputs.py b/beat/backend/python/inputs.py
index dec8aa6..2b43b24 100755
--- a/beat/backend/python/inputs.py
+++ b/beat/backend/python/inputs.py
@@ -251,6 +251,31 @@ class InputGroup:
 #----------------------------------------------------------
 
 
+class RemoteException(Exception):
+    """Error whose message is stored in system_error (kind 'sys') or user_error."""
+    def __init__(self, kind, message):
+        super(RemoteException, self).__init__(message)  # lets str(exc) show the text
+
+        if kind == 'sys':
+            self.system_error = message
+            self.user_error = ''
+        else:
+            self.system_error = ''
+            self.user_error = message
+
+
+#----------------------------------------------------------
+
+
+def process_error(socket):
+    """Receives an error kind, then its message, from `socket` and raises RemoteException"""
+    kind, message = socket.recv(), socket.recv()  # evaluated left to right: kind first
+    raise RemoteException(kind, message)
+
+
+#----------------------------------------------------------
+
+
 class RemoteInput:
   """Allows to access the input of a processing block, via a socket.
 
@@ -297,13 +322,18 @@ class RemoteInput:
     """Indicates if the current data unit will change at the next iteration"""
 
     logger.debug('send: (idd) is-dataunit-done %s', self.name)
+
     _start = time.time()
+
     self.socket.send('idd', zmq.SNDMORE)
     self.socket.send(self.group.channel, zmq.SNDMORE)
     self.socket.send(self.name)
+
     answer = self.socket.recv()
+
     self.comm_time += time.time() - _start
     logger.debug('recv: %s', answer)
+
     return answer == 'tru'
 
 
@@ -311,13 +341,21 @@ class RemoteInput:
     """Indicates if there is more data to process on the input"""
 
     logger.debug('send: (hmd) has-more-data %s %s', self.group.channel, self.name)
+
     _start = time.time()
+
     self.socket.send('hmd', zmq.SNDMORE)
     self.socket.send(self.group.channel, zmq.SNDMORE)
     self.socket.send(self.name)
+
     answer = self.socket.recv()
+
     self.comm_time += time.time() - _start
     logger.debug('recv: %s', answer)
+
+    if answer == 'err':
+        process_error(self.socket)
+
     return answer == 'tru'
 
 
@@ -325,13 +363,23 @@ class RemoteInput:
     """Retrieves the next block of data"""
 
     logger.debug('send: (nxt) next %s %s', self.group.channel, self.name)
+
     _start = time.time()
+
     self.socket.send('nxt', zmq.SNDMORE)
     self.socket.send(self.group.channel, zmq.SNDMORE)
     self.socket.send(self.name)
-    self.data_index = int(self.socket.recv())
+
+    answer = self.socket.recv()
+
+    if answer == 'err':
+        self.comm_time += time.time() - _start
+        process_error(self.socket)
+
+    self.data_index = int(answer)
     self.data_index_end = int(self.socket.recv())
     self.unpack(self.socket.recv())
+
     self.comm_time += time.time() - _start
     self.nb_data_blocks_read += 1
 
@@ -434,12 +482,20 @@ class RemoteInputGroup:
     """Indicates if there is more data to process in the group"""
 
     logger.debug('send: (hmd) has-more-data %s', self.channel)
+
     _start = time.time()
+
     self.socket.send('hmd', zmq.SNDMORE)
     self.socket.send(self.channel)
+
     answer = self.socket.recv()
+
     self.comm_time += time.time() - _start
     logger.debug('recv: %s', answer)
+
+    if answer == 'err':
+        process_error(self.socket)
+
     return answer == 'tru'
 
 
@@ -456,6 +512,10 @@ class RemoteInputGroup:
     parts = []
     while more:
       parts.append(self.socket.recv())
+      if parts[-1] == 'err':
+        self.comm_time += time.time() - _start
+        process_error(self.socket)
+
       more = self.socket.getsockopt(zmq.RCVMORE)
 
     n = int(parts.pop(0))
diff --git a/beat/backend/python/message_handler.py b/beat/backend/python/message_handler.py
index b643e55..69470da 100755
--- a/beat/backend/python/message_handler.py
+++ b/beat/backend/python/message_handler.py
@@ -124,6 +124,7 @@ class MessageHandler(gevent.Greenlet):
                     "killing user process. Exception:\n %s" % \
                     (parsed_parts, traceback.format_exc())
             logger.error(message, exc_info=True)
+            self.send_error(message)
             self.system_error = message
             if self.process is not None:
               self.process.kill()
@@ -134,14 +135,13 @@ class MessageHandler(gevent.Greenlet):
           message = "Command `%s' is not implemented - stopping user process" \
                   % command
           logger.error(message)
+          self.send_error(message)
           self.system_error = message
           if self.process is not None:
             self.process.kill()
           self.stop.set()
           break
 
-    self.socket.setsockopt(zmq.LINGER, 0)
-    self.socket.close()
     logger.debug("0MQ server thread stopped")
 
 
@@ -289,3 +289,32 @@ class MessageHandler(gevent.Greenlet):
 
   def kill(self):
     self.must_kill.set()
+
+
+  def send_error(self, message, kind='usr'):
+    """Sends a user (usr) or system (sys) error message to the infrastructure
+
+    Parameters:
+      message (str): Text of the error
+      kind (str): 'usr' (the default) or 'sys', sent right before the message
+    """
+    logger.debug('send: (err) error')
+    self.socket.send('err', zmq.SNDMORE)
+    self.socket.send(kind, zmq.SNDMORE)
+    logger.debug('send: """%s"""' % message.rstrip())
+    self.socket.send(message)
+    # Poll for the "ack" instead of blocking indefinitely on recv()
+    poller = zmq.Poller()
+    poller.register(self.socket, zmq.POLLIN)
+    max_tries = 5
+    timeout = 1000 #ms
+    for this_try in range(1, max_tries + 1):
+      socks = dict(poller.poll(timeout)) #blocks up to `timeout` ms per try
+      if self.socket in socks and socks[self.socket] == zmq.POLLIN:
+        answer = self.socket.recv() #ack
+        logger.debug('recv: %s', answer)
+        break
+      logger.warning('(try %d) waited %d ms for "ack" from server', this_try, timeout)
+    else: # every try timed out
+      logger.error('could not send error message to server')
+      logger.error('stopping 0MQ client anyway')
-- 
GitLab