#!/usr/bin/env python
# vim: set fileencoding=utf-8 :

###############################################################################
#                                                                             #
# Copyright (c) 2016 Idiap Research Institute, http://www.idiap.ch/           #
# Contact: beat.support@idiap.ch                                              #
#                                                                             #
# This file is part of the beat.core module of the BEAT platform.             #
#                                                                             #
# Commercial License Usage                                                    #
# Licensees holding valid commercial BEAT licenses may use this file in       #
# accordance with the terms contained in a written agreement between you      #
# and Idiap. For further information contact tto@idiap.ch                     #
#                                                                             #
# Alternatively, this file may be used under the terms of the GNU Affero      #
# Public License version 3 as published by the Free Software and appearing    #
# in the file LICENSE.AGPL included in the packaging of this file.            #
# The BEAT platform is distributed in the hope that it will be useful, but    #
# WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY  #
# or FITNESS FOR A PARTICULAR PURPOSE.                                        #
#                                                                             #
# You should have received a copy of the GNU Affero Public License along      #
# with the BEAT platform. If not, see http://www.gnu.org/licenses/.           #
#                                                                             #
###############################################################################

import os
import shutil
import simplejson
import glob  # only used by the currently disabled cache-mounting code in Agent.run

import logging
logger = logging.getLogger(__name__)

import gevent
import zmq.green as zmq
import requests

from gevent import monkey
monkey.patch_socket(dns=False)
monkey.patch_ssl()

from . import utils
from . import dock
from . import baseformat

from beat.backend.python.message_handler import MessageHandler


class Server(MessageHandler):
    '''A 0MQ server for our communication with the user process'''

    def __init__(self, input_list, output_list, host_address):

        # Starts our 0MQ server
        self.context = zmq.Context()
        self.socket = self.context.socket(zmq.PAIR)

        self.address = 'tcp://' + host_address
        port = self.socket.bind_to_random_port(self.address, min_port=50000)
        self.address += ':%d' % port
        logger.debug("zmq server bound to `%s'", self.address)

        super(Server, self).__init__(input_list, self.context, self.socket)

        self.output_list = output_list

        # implementations
        self.callbacks.update(dict(
            wrt=self.write,
            idm=self.is_data_missing,
            oic=self.output_is_connected,
        ))

    def destroy(self):
        self.context.destroy()

    def __str__(self):
        return 'Server(%s)' % self.address

    def _get_output_candidate(self, name):

        retval = self.output_list[name]
        if retval is None:
            raise RuntimeError("Could not find output `%s'" % name)
        return retval

    def write(self, name, packed):
        """Syntax: wrt output data"""

        logger.debug('recv: wrt %s (size=%d)', name, len(packed))

        # Get output object
        output_candidate = self._get_output_candidate(name)
        if output_candidate is None:
            raise RuntimeError("Could not find output `%s' to write to" % name)

        data = output_candidate.data_sink.dataformat.type()
        data.unpack(packed)
        output_candidate.write(data)

        logger.debug('send: ack')
        self.socket.send('ack')

    def is_data_missing(self, name):
        """Syntax: idm output"""

        logger.debug('recv: idm %s', name)

        output_candidate = self._get_output_candidate(name)
        what = 'tru' if output_candidate.isDataMissing() else 'fal'

        logger.debug('send: %s', what)
        self.socket.send(what)

    def output_is_connected(self, name):
        """Syntax: oic output"""

        logger.debug('recv: oic %s', name)

        output_candidate = self._get_output_candidate(name)
        what = 'tru' if output_candidate.isConnected() else 'fal'

        logger.debug('send: %s', what)
        self.socket.send(what)

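# A minimal sketch of the client side of the protocol served above, for
# reference only (nothing below is executed by this module). It assumes the
# multipart framing suggested by the `Syntax:' docstrings -- one frame for
# the command, then one frame per argument -- and a Server already bound on
# `address':
#
#   import zmq
#
#   context = zmq.Context()
#   socket = context.socket(zmq.PAIR)
#   socket.connect(address)
#
#   socket.send_multipart([b'oic', b'my_output'])  # "is this output connected?"
#   print(socket.recv())                           # b'tru' or b'fal'
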
class Agent(object):
    '''Handles synchronous commands.

    We use greenlets for this implementation. Objects of this class are in
    charge of two separate tasks:

    1. Handling the execution of the user process (in a docker container)
    2. Implementing a pipe-based API for I/O that the user process can query


    Parameters:

      virtual_memory_in_megabytes (int, optional): The amount of virtual
        memory (in Megabytes) available for the job. If set to zero, no limit
        will be applied.

      max_cpu_percent (int): The maximum amount of CPU usage allowed in a
        system. This must be an integer between 0 and
        ``100*number_of_cores`` in your system. For instance, if your system
        has 2 cores, this number can go between 0 and 200. If it is <= 0,
        then we don't track CPU usage.

    '''

    def __init__(self, virtual_memory_in_megabytes, max_cpu_percent):

        self.virtual_memory_in_megabytes = virtual_memory_in_megabytes
        self.max_cpu_percent = max_cpu_percent
        self.tempdir = None
        self.db_tempdir = None
        self.process = None
        self.db_process = None
        self.server = None

    def __enter__(self):
        '''Start of context manager'''

        logger.debug("Entering processing context...")

        # Creates a temporary directory for the user process
        self.tempdir = utils.temporary_directory()
        logger.debug("Created temporary directory `%s'", self.tempdir)
        self.db_tempdir = utils.temporary_directory()
        logger.debug("Created temporary directory `%s'", self.db_tempdir)

        self.process = None
        self.db_process = None

        return self

    def __exit__(self, exc_type, exc_value, traceback):

        if self.tempdir is not None and os.path.exists(self.tempdir):
            shutil.rmtree(self.tempdir)
        self.tempdir = None

        if self.db_tempdir is not None and os.path.exists(self.db_tempdir):
            shutil.rmtree(self.db_tempdir)
        self.db_tempdir = None

        self.process = None
        self.db_process = None

        logger.debug("Exiting processing context...")

    def run(self, configuration, host, timeout_in_minutes=0, daemon=0,
            db_address=None):
        """Runs the algorithm code


        Parameters:

          configuration (object): A *valid*, preloaded
            :py:class:`beat.core.execution.Executor` object.

          host (:py:class:`Host`): A configured docker host that will execute
            the user process. If the host does not have access to the
            required environment, an exception will be raised.

          timeout_in_minutes (int): The number of minutes to wait for the
            user process to execute. After this amount of time, the user
            process is killed with :py:attr:`signal.SIGKILL`. If set to zero,
            no timeout will be applied.

          daemon (int): If this variable is set, then we don't really start
            the user process, but just kick off the 0MQ server, print the
            command-line and sleep for that many seconds. You're supposed to
            start the client by hand then and debug it.

          db_address (str, optional): If set, a separate
            ``databases_provider`` container is started to serve the database
            data, listening on this address.


        Returns:

          dict: The standard output and error streams of the user process
            (``stdout``, ``stderr``), its exit ``status``, a ``timed_out``
            flag, the last ``statistics`` reported by the message handler and
            any collected ``system_error`` or ``user_error``.

        """

        # Recursively copies configuration data to /prefix
        configuration.dump_runner_configuration(self.tempdir)

        if db_address is not None:
            configuration.dump_databases_provider_configuration(self.db_tempdir)

            # Modify the paths to the databases in the dumped configuration
            # files
            root_folder = os.path.join(self.db_tempdir, 'prefix', 'databases')

            database_paths = {}

            if 'datasets_root_path' not in configuration.data:
                for db_name in configuration.databases.keys():
                    json_path = os.path.join(root_folder, db_name + '.json')

                    with open(json_path, 'r') as f:
                        db_data = simplejson.load(f)

                    database_paths[db_name] = db_data['root_folder']
                    db_data['root_folder'] = os.path.join('/databases', db_name)

                    with open(json_path, 'w') as f:
                        simplejson.dump(db_data, f, indent=4)

        # Server for our single client
        self.server = Server(configuration.input_list,
                             configuration.output_list, host.ip)

        # Figures out the images to use
        envkey = '%(name)s (%(version)s)' % configuration.data['environment']
        if envkey not in host:
            raise RuntimeError("Environment `%s' is not available on docker "
                               "host `%s' - available environments are %s" %
                               (envkey, host,
                                ", ".join(host.environments.keys())))

        if db_address is not None:
            try:
                db_envkey = host.db2docker(database_paths.keys())
            except Exception:
                raise RuntimeError("No environment found for the databases "
                                   "`%s' - available environments are %s" %
                                   (", ".join(database_paths.keys()),
                                    ", ".join(host.db_environments.keys())))

        # Launches the process (0MQ client)
        tmp_dir = os.path.join('/tmp', os.path.basename(self.tempdir))
        cmd = ['execute', self.server.address, tmp_dir]

        if logger.getEffectiveLevel() <= logging.DEBUG:
            cmd.insert(1, '--debug')

        if daemon > 0:
            image = host.env2docker(envkey)
            logger.debug("Daemon mode: start the user process with the "
                         "following command: `docker run -ti %s %s'",
                         image, ' '.join(cmd))
            cmd = ['sleep', str(daemon)]
            logger.debug("Daemon mode: sleeping for %d seconds", daemon)
        else:
            if db_address is not None:
                tmp_dir = os.path.join('/tmp',
                                       os.path.basename(self.db_tempdir))
                db_cmd = ['databases_provider', db_address, tmp_dir]

                volumes = {}
                if 'datasets_root_path' not in configuration.data:
                    for db_name, db_path in database_paths.items():
                        volumes[db_path] = {
                            'bind': os.path.join('/databases', db_name),
                            'mode': 'ro',
                        }
                else:
                    volumes[configuration.data['datasets_root_path']] = {
                        'bind': configuration.data['datasets_root_path'],
                        'mode': 'ro',
                    }

                # Note: we only support one databases image loaded at the
                # same time
                self.db_process = dock.Popen(
                    host,
                    db_envkey,
                    command=db_cmd,
                    tmp_archive=self.db_tempdir,
                    volumes=volumes
                )

        volumes = {}
        if not configuration.proxy_mode:
            volumes[configuration.cache] = {
                'bind': '/cache',
                'mode': 'rw',
            }

            # for name, details in configuration.data['inputs'].items():
            #     if 'database' in details:
            #         continue
            #
            #     basename = os.path.join(configuration.cache, details['path'])
            #     filenames = glob.glob(basename + '*.data')
            #     filenames.extend(glob.glob(basename + '*.data.checksum'))
            #     filenames.extend(glob.glob(basename + '*.data.index'))
            #     filenames.extend(glob.glob(basename + '*.data.index.checksum'))
            #
            #     for filename in filenames:
            #         volumes[filename] = {
            #             'bind': os.path.join('/cache',
            #                 filename.replace(configuration.cache + '/', '')),
            #             'mode': 'ro',
            #         }
            #
            # if 'result' in configuration.data:
            #     outputs_config = {
            #         'result': configuration.data['result']
            #     }
            # else:
            #     outputs_config = configuration.data['outputs']
            #
            # for name, details in outputs_config.items():
            #     basename = os.path.join(configuration.cache, details['path'])
            #     dirname = os.path.dirname(basename)
            #
            #     volumes[dirname] = {
            #         'bind': os.path.join('/cache',
            #             dirname.replace(configuration.cache + '/', '')),
            #         'mode': 'rw',
            #     }

        self.process = dock.Popen(
            host,
            envkey,
            command=cmd,
            tmp_archive=self.tempdir,
            virtual_memory_in_megabytes=self.virtual_memory_in_megabytes,
            max_cpu_percent=self.max_cpu_percent,
            volumes=volumes
        )

        # provide a tip on how to stop the test
        if daemon > 0:
            logger.debug("To stop the daemon, press CTRL-c or kill the user "
                         "process with `docker kill %s`", self.process.pid)

        # Serve asynchronously
        self.server.set_process(self.process)
        self.server.start()

        timed_out = False

        try:
            timeout = (60 * timeout_in_minutes) if timeout_in_minutes else None
            status = self.process.wait(timeout)

        except requests.exceptions.ReadTimeout:
            logger.warning("user process has timed out after %d minutes",
                           timeout_in_minutes)
            self.process.kill()
            status = self.process.wait()

            if self.db_process is not None:
                self.db_process.kill()
                self.db_process.wait()

            timed_out = True

        except KeyboardInterrupt:  # developer pushed CTRL-C
            logger.info("stopping user process on CTRL-C console request")
            self.process.kill()
            status = self.process.wait()

            if self.db_process is not None:
                self.db_process.kill()
                self.db_process.wait()

        finally:
            self.server.stop.set()

        # Collects final information and returns to caller
        process = self.process
        self.process = None
        retval = dict(
            stdout=process.stdout,
            stderr=process.stderr,
            status=status,
            timed_out=timed_out,
            statistics=self.server.last_statistics,
            system_error=self.server.system_error,
            user_error=self.server.user_error,
        )
        process.rm()

        if self.db_process is not None:
            retval['stdout'] += '\n' + self.db_process.stdout
            retval['stderr'] += '\n' + self.db_process.stderr
            self.db_process.rm()
            self.db_process = None

        self.server.destroy()
        self.server = None

        return retval

    def kill(self):
        """Stops the user process by force - to be called from signal handlers"""

        if self.server is not None:
            self.server.kill()
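
# A minimal driving sketch, for reference only (nothing below is executed by
# this module). `executor' stands for a preloaded beat.core.execution.Executor
# and `host' for a configured docker host, neither of which is constructed
# here; the signal hookup follows the hint in Agent.kill()'s docstring:
#
#   import signal
#
#   with Agent(virtual_memory_in_megabytes=0, max_cpu_percent=0) as agent:
#       signal.signal(signal.SIGTERM, lambda signum, frame: agent.kill())
#       result = agent.run(executor, host, timeout_in_minutes=5)
#       print(result['status'], result['timed_out'])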