Commit 6b8a780b authored by Samuel GAIST's avatar Samuel GAIST
Browse files

[scripts][execute] pre-commit cleanup

parent 122f0a4f
...@@ -63,10 +63,8 @@ import logging ...@@ -63,10 +63,8 @@ import logging
import os import os
import sys import sys
import docopt import docopt
import pwd
import stat
import simplejson import simplejson
import subprocess import subprocess # nosec
import zmq import zmq
...@@ -74,14 +72,14 @@ from beat.backend.python.execution import AlgorithmExecutor ...@@ -74,14 +72,14 @@ from beat.backend.python.execution import AlgorithmExecutor
from beat.backend.python.exceptions import UserError from beat.backend.python.exceptions import UserError
#---------------------------------------------------------- # ----------------------------------------------------------
def send_error(logger, socket, tp, message): def send_error(logger, socket, tp, message):
"""Sends a user (usr) or system (sys) error message to the infrastructure""" """Sends a user (usr) or system (sys) error message to the infrastructure"""
logger.debug('send: (err) error') logger.debug("send: (err) error")
socket.send_string('err', zmq.SNDMORE) socket.send_string("err", zmq.SNDMORE)
socket.send_string(tp, zmq.SNDMORE) socket.send_string(tp, zmq.SNDMORE)
logger.debug('send: """%s"""' % message.rstrip()) logger.debug('send: """%s"""' % message.rstrip())
socket.send_string(message) socket.send_string(message)
...@@ -91,22 +89,21 @@ def send_error(logger, socket, tp, message): ...@@ -91,22 +89,21 @@ def send_error(logger, socket, tp, message):
this_try = 1 this_try = 1
max_tries = 5 max_tries = 5
timeout = 1000 #ms timeout = 1000 # ms
while this_try <= max_tries: while this_try <= max_tries:
socks = dict(poller.poll(timeout)) #blocks here, for 5 seconds at most socks = dict(poller.poll(timeout)) # blocks here, for 5 seconds at most
if socket in socks and socks[socket] == zmq.POLLIN: if socket in socks and socks[socket] == zmq.POLLIN:
answer = socket.recv() #ack answer = socket.recv() # ack
logger.debug('recv: %s', answer) logger.debug("recv: %s", answer)
break break
logger.warn('(try %d) waited %d ms for "ack" from server', logger.warn('(try %d) waited %d ms for "ack" from server', this_try, timeout)
this_try, timeout)
this_try += 1 this_try += 1
if this_try > max_tries: if this_try > max_tries:
logger.error('could not send error message to server') logger.error("could not send error message to server")
logger.error('stopping 0MQ client anyway') logger.error("stopping 0MQ client anyway")
#---------------------------------------------------------- # ----------------------------------------------------------
def close(logger, sockets, context): def close(logger, sockets, context):
...@@ -119,25 +116,25 @@ def close(logger, sockets, context): ...@@ -119,25 +116,25 @@ def close(logger, sockets, context):
logger.debug("0MQ client finished") logger.debug("0MQ client finished")
#---------------------------------------------------------- # ----------------------------------------------------------
def process_traceback(tb, prefix): def process_traceback(tb, prefix):
import traceback import traceback
algorithms_prefix = os.path.join(prefix, 'algorithms') + os.sep algorithms_prefix = os.path.join(prefix, "algorithms") + os.sep
for first_line, line in enumerate(tb): for first_line, line in enumerate(tb):
if line[0].startswith(algorithms_prefix): if line[0].startswith(algorithms_prefix):
break break
s = ''.join(traceback.format_list(tb[first_line:])) s = "".join(traceback.format_list(tb[first_line:]))
s = s.replace(algorithms_prefix, '').strip() s = s.replace(algorithms_prefix, "").strip()
return s return s
#---------------------------------------------------------- # ----------------------------------------------------------
def main(): def main():
...@@ -147,101 +144,112 @@ def main(): ...@@ -147,101 +144,112 @@ def main():
# to different processing phases of this script # to different processing phases of this script
""" """
package = __name__.rsplit('.', 2)[0] package = __name__.rsplit(".", 2)[0]
version = package + ' v' + \ version = package + " v" + __import__("pkg_resources").require(package)[0].version
__import__('pkg_resources').require(package)[0].version
prog = os.path.basename(sys.argv[0]) prog = os.path.basename(sys.argv[0])
args = docopt.docopt(__doc__ % dict(prog=prog, version=version), args = docopt.docopt(__doc__ % dict(prog=prog, version=version), version=version)
version=version)
# Setup the logging system # Setup the logging system
formatter = logging.Formatter(fmt="[%(asctime)s - execute.py - " \ formatter = logging.Formatter(
"%(name)s] %(levelname)s: %(message)s", fmt="[%(asctime)s - execute.py - " "%(name)s] %(levelname)s: %(message)s",
datefmt="%d/%b/%Y %H:%M:%S") datefmt="%d/%b/%Y %H:%M:%S",
)
handler = logging.StreamHandler() handler = logging.StreamHandler()
handler.setFormatter(formatter) handler.setFormatter(formatter)
root_logger = logging.getLogger('beat.backend.python') root_logger = logging.getLogger("beat.backend.python")
root_logger.addHandler(handler) root_logger.addHandler(handler)
if args['--debug']: if args["--debug"]:
root_logger.setLevel(logging.DEBUG) root_logger.setLevel(logging.DEBUG)
else: else:
root_logger.setLevel(logging.INFO) root_logger.setLevel(logging.INFO)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# Retrieve the cache path # Retrieve the cache path
cache = args['--cache'] if args['--cache'] is not None else '/cache' cache = args["--cache"] if args["--cache"] is not None else "/cache"
# Creates the 0MQ socket for communication with BEAT # Creates the 0MQ socket for communication with BEAT
context = zmq.Context() context = zmq.Context()
socket = context.socket(zmq.PAIR) socket = context.socket(zmq.PAIR)
address = args['<addr>'] address = args["<addr>"]
socket.connect(address) socket.connect(address)
logger.debug("zmq client connected to `%s'", address) logger.debug("zmq client connected to `%s'", address)
# Creates the 0MQ socket for communication with the databases (if necessary) # Creates the 0MQ socket for communication with the databases (if necessary)
db_socket = None db_socket = None
if args['<db_addr>']: if args["<db_addr>"]:
db_socket = context.socket(zmq.PAIR) db_socket = context.socket(zmq.PAIR)
db_socket.connect(args['<db_addr>']) db_socket.connect(args["<db_addr>"])
logger.debug("zmq client connected to db `%s'", args['<db_addr>']) logger.debug("zmq client connected to db `%s'", args["<db_addr>"])
loop_socket = None loop_socket = None
if args['<loop_addr>']: if args["<loop_addr>"]:
loop_socket = context.socket(zmq.PAIR) loop_socket = context.socket(zmq.PAIR)
loop_socket.connect(args['<loop_addr>']) loop_socket.connect(args["<loop_addr>"])
logger.debug("zmq client connected to loop `%s'", args['<loop_addr>']) logger.debug("zmq client connected to loop `%s'", args["<loop_addr>"])
# Check the dir # Check the dir
if not os.path.exists(args['<dir>']): if not os.path.exists(args["<dir>"]):
send_error(logger, socket, 'sys', "Running directory `%s' not found" % args['<dir>']) send_error(
logger, socket, "sys", "Running directory `%s' not found" % args["<dir>"]
)
close(logger, [socket, db_socket, loop_socket], context) close(logger, [socket, db_socket, loop_socket], context)
return 1 return 1
# Load the configuration # Load the configuration
with open(os.path.join(args['<dir>'], 'configuration.json'), 'r') as f: with open(os.path.join(args["<dir>"], "configuration.json"), "r") as f:
cfg = simplejson.load(f) cfg = simplejson.load(f)
# Create a new user with less privileges (if necessary) # Create a new user with less privileges (if necessary)
if os.getuid() != cfg['uid']: if os.getuid() != cfg["uid"]:
retcode = subprocess.call(['adduser', '--uid', str(cfg['uid']), retcode = subprocess.call( # nosec
'--no-create-home', '--disabled-password', [
'--disabled-login', '--gecos', '""', '-q', "adduser",
'beat-nobody']) "--uid",
str(cfg["uid"]),
"--no-create-home",
"--disabled-password",
"--disabled-login",
"--gecos",
'""',
"-q",
"beat-nobody",
]
)
if retcode != 0: if retcode != 0:
send_error(logger, socket, 'sys', 'Failed to create an user with the UID %d' % cfg['uid']) send_error(
logger,
socket,
"sys",
"Failed to create an user with the UID %d" % cfg["uid"],
)
close(logger, [socket, db_socket, loop_socket], context) close(logger, [socket, db_socket, loop_socket], context)
return 1 return 1
# Change to the user with less privileges # Change to the user with less privileges
try: try:
os.setgid(cfg['uid']) os.setgid(cfg["uid"])
os.setuid(cfg['uid']) os.setuid(cfg["uid"])
except: except Exception:
import traceback import traceback
send_error(logger, socket, 'sys', traceback.format_exc())
send_error(logger, socket, "sys", traceback.format_exc())
close(logger, [socket, db_socket, loop_socket], context) close(logger, [socket, db_socket, loop_socket], context)
return 1 return 1
try: try:
# Sets up the execution # Sets up the execution
executor = AlgorithmExecutor(socket, executor = AlgorithmExecutor(
args['<dir>'], socket,
cache_root=cache, args["<dir>"],
db_socket=db_socket, cache_root=cache,
loop_socket=loop_socket) db_socket=db_socket,
loop_socket=loop_socket,
)
try: try:
status = executor.setup() status = executor.setup()
...@@ -251,6 +259,7 @@ def main(): ...@@ -251,6 +259,7 @@ def main():
raise raise
except Exception as e: except Exception as e:
import traceback import traceback
exc_type, exc_value, exc_traceback = sys.exc_info() exc_type, exc_value, exc_traceback = sys.exc_info()
tb = traceback.extract_tb(exc_traceback) tb = traceback.extract_tb(exc_traceback)
s = process_traceback(tb, executor.prefix) s = process_traceback(tb, executor.prefix)
...@@ -265,6 +274,7 @@ def main(): ...@@ -265,6 +274,7 @@ def main():
raise raise
except Exception as e: except Exception as e:
import traceback import traceback
exc_type, exc_value, exc_traceback = sys.exc_info() exc_type, exc_value, exc_traceback = sys.exc_info()
tb = traceback.extract_tb(exc_traceback) tb = traceback.extract_tb(exc_traceback)
s = process_traceback(tb, executor.prefix) s = process_traceback(tb, executor.prefix)
...@@ -279,27 +289,31 @@ def main(): ...@@ -279,27 +289,31 @@ def main():
raise raise
except Exception as e: except Exception as e:
import traceback import traceback
exc_type, exc_value, exc_traceback = sys.exc_info() exc_type, exc_value, exc_traceback = sys.exc_info()
tb = traceback.extract_tb(exc_traceback) tb = traceback.extract_tb(exc_traceback)
s = process_traceback(tb, executor.prefix) s = process_traceback(tb, executor.prefix)
raise UserError("%s%s: %s" % (s, type(e).__name__, e)) raise UserError("%s%s: %s" % (s, type(e).__name__, e))
except UserError as e: except UserError as e:
send_error(logger, socket, 'usr', str(e)) send_error(logger, socket, "usr", str(e))
return 1 return 1
except MemoryError as e: except MemoryError:
# Say something meaningful to the user # Say something meaningful to the user
msg = "The user process for this block ran out of memory. We " \ msg = (
"suggest you optimise your code to reduce memory usage or, " \ "The user process for this block ran out of memory. We "
"if this is not an option, choose an appropriate processing " \ "suggest you optimise your code to reduce memory usage or, "
"if this is not an option, choose an appropriate processing "
"queue with enough memory." "queue with enough memory."
send_error(logger, socket, 'usr', msg) )
send_error(logger, socket, "usr", msg)
return 1 return 1
except Exception as e: except Exception:
import traceback import traceback
send_error(logger, socket, 'sys', traceback.format_exc())
send_error(logger, socket, "sys", traceback.format_exc())
return 1 return 1
finally: finally:
...@@ -308,8 +322,8 @@ def main(): ...@@ -308,8 +322,8 @@ def main():
return 0 return 0
#---------------------------------------------------------- # ----------------------------------------------------------
if __name__ == '__main__': if __name__ == "__main__":
sys.exit(main()) sys.exit(main())
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment