From a9164accbdbcc0a0792ee86459677e35a2c11953 Mon Sep 17 00:00:00 2001 From: Philip ABBET <philip.abbet@idiap.ch> Date: Wed, 6 Dec 2017 17:22:16 +0100 Subject: [PATCH] Refactoring: the 'Executor' class now supports sequential and autonomous algorithms --- beat/backend/python/executor.py | 108 ++++++++---- beat/backend/python/helpers.py | 182 +++++++++++++------- beat/backend/python/test/test_executor.py | 200 ++++++++++++++++++++++ 3 files changed, 398 insertions(+), 92 deletions(-) create mode 100644 beat/backend/python/test/test_executor.py diff --git a/beat/backend/python/executor.py b/beat/backend/python/executor.py index df0b929..7d7e865 100755 --- a/beat/backend/python/executor.py +++ b/beat/backend/python/executor.py @@ -37,13 +37,11 @@ import time import zmq import simplejson -from . import algorithm -from . import inputs -from . import outputs -from . import stats +from .algorithm import Algorithm from .helpers import create_inputs_from_configuration from .helpers import create_outputs_from_configuration from .helpers import AccessMode +from . import stats class Executor(object): @@ -90,34 +88,47 @@ class Executor(object): self.prefix = os.path.join(directory, 'prefix') self.runner = None - # temporary caches, if the user has not set them, for performance + # Temporary caches, if the user has not set them, for performance database_cache = database_cache if database_cache is not None else {} dataformat_cache = dataformat_cache if dataformat_cache is not None else {} library_cache = library_cache if library_cache is not None else {} - self.algorithm = algorithm.Algorithm(self.prefix, self.data['algorithm'], - dataformat_cache, library_cache) + # Load the algorithm + self.algorithm = Algorithm(self.prefix, self.data['algorithm'], + dataformat_cache, library_cache) - # Use algorithm names for inputs and outputs main_channel = self.data['channel'] - # Loads algorithm inputs - if self.data['proxy_mode']: - cache_access = AccessMode.REMOTE - else: - cache_access = AccessMode.LOCAL + if self.algorithm.type == Algorithm.LEGACY: + # Loads algorithm inputs + if self.data['proxy_mode']: + cache_access = AccessMode.REMOTE + else: + cache_access = AccessMode.LOCAL - (self.input_list, _) = create_inputs_from_configuration( - self.data, self.algorithm, self.prefix, cache_root, - cache_access=cache_access, db_access=AccessMode.REMOTE, - socket=self.socket - ) + (self.input_list, self.data_loaders, _) = create_inputs_from_configuration( + self.data, self.algorithm, self.prefix, cache_root, + cache_access=cache_access, db_access=AccessMode.REMOTE, + socket=self.socket + ) - # Loads algorithm outputs - (self.output_list, _) = create_outputs_from_configuration( - self.data, self.algorithm, self.prefix, cache_root, self.input_list, - cache_access=cache_access, socket=self.socket - ) + # Loads algorithm outputs + (self.output_list, _) = create_outputs_from_configuration( + self.data, self.algorithm, self.prefix, cache_root, self.input_list, + cache_access=cache_access, socket=self.socket + ) + + else: + (self.input_list, self.data_loaders, _) = create_inputs_from_configuration( + self.data, self.algorithm, self.prefix, cache_root, + cache_access=AccessMode.LOCAL, db_access=AccessMode.REMOTE + ) + + # Loads algorithm outputs + (self.output_list, _) = create_outputs_from_configuration( + self.data, self.algorithm, self.prefix, cache_root, self.input_list, + cache_access=AccessMode.LOCAL + ) def setup(self): @@ -129,27 +140,56 @@ class Executor(object): return retval + def prepare(self): + """Prepare the 
algorithm""" + + self.runner = self.algorithm.runner() + retval = self.runner.prepare(self.data_loaders) + logger.debug("User algorithm is prepared") + return retval + + def process(self): """Executes the user algorithm code using the current interpreter. """ - if not self.input_list or not self.output_list: - raise RuntimeError("I/O for execution block has not yet been set up") - - while self.input_list.hasMoreData(): - main_group = self.input_list.main_group - main_group.restricted_access = False - main_group.next() - main_group.restricted_access = True - + if self.algorithm.type == Algorithm.AUTONOMOUS: if self.analysis: - result = self.runner.process(inputs=self.input_list, output=self.output_list[0]) + result = self.runner.process(data_loaders=self.data_loaders, + output=self.output_list[0]) else: - result = self.runner.process(inputs=self.input_list, outputs=self.output_list) + result = self.runner.process(data_loaders=self.data_loaders, + outputs=self.output_list) if not result: return False + else: + while self.input_list.hasMoreData(): + main_group = self.input_list.main_group + main_group.restricted_access = False + main_group.next() + main_group.restricted_access = True + + if self.algorithm.type == Algorithm.LEGACY: + if self.analysis: + result = self.runner.process(inputs=self.input_list, output=self.output_list[0]) + else: + result = self.runner.process(inputs=self.input_list, outputs=self.output_list) + + elif self.algorithm.type == Algorithm.SEQUENTIAL: + if self.analysis: + result = self.runner.process(inputs=self.input_list, + data_loaders=self.data_loaders, + output=self.output_list[0]) + else: + result = self.runner.process(inputs=self.input_list, + data_loaders=self.data_loaders, + outputs=self.output_list) + + if not result: + return False + for output in self.output_list: output.close() diff --git a/beat/backend/python/helpers.py b/beat/backend/python/helpers.py index 9ac010d..e82bf19 100755 --- a/beat/backend/python/helpers.py +++ b/beat/backend/python/helpers.py @@ -32,11 +32,26 @@ import errno import logging logger = logging.getLogger(__name__) -from . import data -from . import inputs -from . 
import outputs +from .data import MemoryDataSource +from .data import CachedDataSource +from .data import CachedFileLoader +from .data import CachedDataSink +from .data import getAllFilenames +from .data_loaders import DataLoaderList +from .data_loaders import DataLoader +from .inputs import InputList +from .inputs import Input +from .inputs import RemoteInput +from .inputs import InputGroup +from .outputs import SynchronizationListener +from .outputs import OutputList +from .outputs import Output +from .outputs import RemoteOutput +from .algorithm import Algorithm +#---------------------------------------------------------- + def convert_experiment_configuration_to_container(config, proxy_mode): data = { @@ -80,13 +95,77 @@ def create_inputs_from_configuration(config, algorithm, prefix, cache_root, data_sources = [] views = {} - input_list = inputs.InputList() + input_list = InputList() + data_loader_list = DataLoaderList() # This is used for parallelization purposes start_index, end_index = config.get('range', (None, None)) + + def _create_local_input(details): + data_source = CachedDataSource() + data_sources.append(data_source) + + filename = os.path.join(cache_root, details['path'] + '.data') + + if details['channel'] == config['channel']: # synchronized + status = data_source.setup( + filename=filename, + prefix=prefix, + force_start_index=start_index, + force_end_index=end_index, + unpack=True, + ) + else: + status = data_source.setup( + filename=filename, + prefix=prefix, + unpack=True, + ) + + if not status: + raise IOError("cannot load cache file `%s'" % details['path']) + + input = Input(name, algorithm.input_map[name], data_source) + + logger.debug("Input '%s' created: group='%s', dataformat='%s', filename='%s'" % \ + (name, details['channel'], algorithm.input_map[name], filename)) + + return input + + + def _create_data_loader(details): + filename = os.path.join(cache_root, details['path'] + '.data') + + data_loader = data_loader_list[details['channel']] + if data_loader is None: + data_loader = DataLoader(details['channel']) + data_loader_list.add(data_loader) + + logger.debug("Data loader created: group='%s'" % details['channel']) + + cached_file = CachedFileLoader() + result = cached_file.setup( + filename=filename, + prefix=prefix, + start_index=start_index, + end_index=end_index, + unpack=True, + ) + + if not result: + raise IOError("cannot load cache file `%s'" % details['path']) + + data_loader.add(name, cached_file) + + logger.debug("Input '%s' added to data loader: group='%s', dataformat='%s', filename='%s'" % \ + (name, details['channel'], algorithm.input_map[name], filename)) + + for name, details in config['inputs'].items(): + input = None + if details.get('database', False): if db_access == AccessMode.LOCAL: if databases is None: @@ -114,12 +193,12 @@ def create_inputs_from_configuration(config, algorithm, prefix, cache_root, view = views[channel] # Creation of the input - data_source = data.MemoryDataSource(view.done, next_callback=view.next) + data_source = MemoryDataSource(view.done, next_callback=view.next) output = view.outputs[details['output']] output.data_sink.data_sources.append(data_source) - input = inputs.Input(name, algorithm.input_map[name], data_source) + input = Input(name, algorithm.input_map[name], data_source) logger.debug("Input '%s' created: group='%s', dataformat='%s', database-output='%s/%s/%s:%s'" % \ (name, channel, algorithm.input_map[name], details['database'], @@ -129,47 +208,34 @@ def create_inputs_from_configuration(config, algorithm, 
prefix, cache_root, if socket is None: raise IOError("No socket provided for remote inputs") - input = inputs.RemoteInput(name, algorithm.dataformats[algorithm.input_map[name]], - socket, unpack=unpack) + input = RemoteInput(name, algorithm.dataformats[algorithm.input_map[name]], + socket, unpack=unpack) logger.debug("RemoteInput '%s' created: group='%s', dataformat='%s', connected to a database" % \ (name, details['channel'], algorithm.input_map[name])) + elif cache_access == AccessMode.LOCAL: - data_source = data.CachedDataSource() - data_sources.append(data_source) - - filename = os.path.join(cache_root, details['path'] + '.data') - - if details['channel'] == config['channel']: # synchronized - status = data_source.setup( - filename=filename, - prefix=prefix, - force_start_index=start_index, - force_end_index=end_index, - unpack=True, - ) - else: - status = data_source.setup( - filename=filename, - prefix=prefix, - unpack=True, - ) - if not status: - raise IOError("cannot load cache file `%s'" % details['path']) + if algorithm.type == Algorithm.LEGACY: + input = _create_local_input(details) - input = inputs.Input(name, algorithm.input_map[name], data_source) + elif algorithm.type == Algorithm.SEQUENTIAL: + if details['channel'] == config['channel']: # synchronized + input = _create_local_input(details) + else: + _create_data_loader(details) + + elif algorithm.type == Algorithm.AUTONOMOUS: + _create_data_loader(details) - logger.debug("Input '%s' created: group='%s', dataformat='%s', filename='%s'" % \ - (name, details['channel'], algorithm.input_map[name], filename)) elif cache_access == AccessMode.REMOTE: if socket is None: raise IOError("No socket provided for remote inputs") - input = inputs.RemoteInput(name, algorithm.dataformats[algorithm.input_map[name]], - socket, unpack=unpack) + input = RemoteInput(name, algorithm.dataformats[algorithm.input_map[name]], + socket, unpack=unpack) logger.debug("RemoteInput '%s' created: group='%s', dataformat='%s'" % \ (name, details['channel'], algorithm.input_map[name])) @@ -178,24 +244,24 @@ def create_inputs_from_configuration(config, algorithm, prefix, cache_root, continue # Synchronization bits - group = input_list.group(details['channel']) - if group is None: - synchronization_listener = None - if not no_synchronisation_listeners: - synchronization_listener = outputs.SynchronizationListener() - - group = inputs.InputGroup( - details['channel'], - synchronization_listener=synchronization_listener, - restricted_access=(details['channel'] == config['channel']) - ) - input_list.add(group) - logger.debug("Group '%s' created" % details['channel']) + if input is not None: + group = input_list.group(details['channel']) + if group is None: + synchronization_listener = None + if not no_synchronisation_listeners: + synchronization_listener = SynchronizationListener() - group.add(input) + group = InputGroup( + details['channel'], + synchronization_listener=synchronization_listener, + restricted_access=(details['channel'] == config['channel']) + ) + input_list.add(group) + logger.debug("Group '%s' created" % details['channel']) - return (input_list, data_sources) + group.add(input) + return (input_list, data_loader_list, data_sources) #---------------------------------------------------------- @@ -205,7 +271,7 @@ def create_outputs_from_configuration(config, algorithm, prefix, cache_root, inp cache_access=AccessMode.NONE, socket=None): data_sinks = [] - output_list = outputs.OutputList() + output_list = OutputList() # This is used for parallelization 
purposes start_index, end_index = config.get('range', (None, None)) @@ -254,7 +320,7 @@ def create_outputs_from_configuration(config, algorithm, prefix, cache_root, inp break (data_filenames, indices_filenames, data_checksum_filenames, indices_checksum_filenames) = \ - data.getAllFilenames(input_path) + getAllFilenames(input_path) end_indices = [ int(x.split('.')[-2]) for x in indices_filenames ] end_indices.sort() @@ -262,7 +328,7 @@ def create_outputs_from_configuration(config, algorithm, prefix, cache_root, inp start_index = 0 end_index = end_indices[-1] - data_sink = data.CachedDataSink() + data_sink = CachedDataSink() data_sinks.append(data_sink) status = data_sink.setup( @@ -276,9 +342,9 @@ def create_outputs_from_configuration(config, algorithm, prefix, cache_root, inp if not status: raise IOError("Cannot create cache sink '%s'" % details['path']) - output_list.add(outputs.Output(name, data_sink, - synchronization_listener=synchronization_listener, - force_start_index=start_index) + output_list.add(Output(name, data_sink, + synchronization_listener=synchronization_listener, + force_start_index=start_index) ) if 'result' not in config: @@ -292,9 +358,9 @@ def create_outputs_from_configuration(config, algorithm, prefix, cache_root, inp if socket is None: raise IOError("No socket provided for remote outputs") - output_list.add(outputs.RemoteOutput(name, dataformat, socket, - synchronization_listener=synchronization_listener, - force_start_index=start_index or 0) + output_list.add(RemoteOutput(name, dataformat, socket, + synchronization_listener=synchronization_listener, + force_start_index=start_index or 0) ) logger.debug("RemoteOutput '%s' created: group='%s', dataformat='%s'" % \ diff --git a/beat/backend/python/test/test_executor.py b/beat/backend/python/test/test_executor.py new file mode 100644 index 0000000..a155c4a --- /dev/null +++ b/beat/backend/python/test/test_executor.py @@ -0,0 +1,200 @@ +#!/usr/bin/env python +# vim: set fileencoding=utf-8 : + +############################################################################### +# # +# Copyright (c) 2017 Idiap Research Institute, http://www.idiap.ch/ # +# Contact: beat.support@idiap.ch # +# # +# This file is part of the beat.backend.python module of the BEAT platform. # +# # +# Commercial License Usage # +# Licensees holding valid commercial BEAT licenses may use this file in # +# accordance with the terms contained in a written agreement between you # +# and Idiap. For further information contact tto@idiap.ch # +# # +# Alternatively, this file may be used under the terms of the GNU Affero # +# Public License version 3 as published by the Free Software and appearing # +# in the file LICENSE.AGPL included in the packaging of this file. # +# The BEAT platform is distributed in the hope that it will be useful, but # +# WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY # +# or FITNESS FOR A PARTICULAR PURPOSE. # +# # +# You should have received a copy of the GNU Affero Public License along # +# with the BEAT platform. If not, see http://www.gnu.org/licenses/. 
# +# # +############################################################################### + + +import unittest +import tempfile +import simplejson +import os +import zmq +import shutil +import numpy as np +from copy import deepcopy + +from ..executor import Executor +from ..message_handler import MessageHandler +from ..inputs import InputList +from ..algorithm import Algorithm +from ..dataformat import DataFormat +from ..data import CachedDataSink +from ..data import CachedFileLoader +from ..helpers import convert_experiment_configuration_to_container +from ..helpers import create_inputs_from_configuration +from ..helpers import create_outputs_from_configuration +from ..helpers import AccessMode + +from . import prefix + + +CONFIGURATION = { + 'algorithm': '', + 'channel': 'main', + 'parameters': { + }, + 'inputs': { + 'in': { + 'path': 'INPUT', + 'channel': 'main', + } + }, + 'outputs': { + 'out': { + 'path': 'OUTPUT', + 'channel': 'main' + } + }, +} + + +#---------------------------------------------------------- + + +class TestExecutor(unittest.TestCase): + + def setUp(self): + self.cache_root = tempfile.mkdtemp(prefix=__name__) + self.working_dir = tempfile.mkdtemp(prefix=__name__) + self.message_handler = None + self.executor_socket = None + self.zmq_context = None + + + def tearDown(self): + shutil.rmtree(self.cache_root) + shutil.rmtree(self.working_dir) + + if self.message_handler is not None: + self.message_handler.kill() + self.message_handler.join() + self.message_handler.destroy() + self.message_handler = None + + if self.executor_socket is not None: + self.executor_socket.setsockopt(zmq.LINGER, 0) + self.executor_socket.close() + self.zmq_context.destroy() + self.executor_socket = None + self.zmq_context = None + + + def writeData(self, input_name, indices, start_value): + filename = os.path.join(self.cache_root, CONFIGURATION['inputs'][input_name]['path'] + '.data') + + dataformat = DataFormat(prefix, 'user/single_integer/1') + self.assertTrue(dataformat.valid) + + data_sink = CachedDataSink() + self.assertTrue(data_sink.setup(filename, dataformat, indices[0][0], indices[-1][1])) + + for i in indices: + data = dataformat.type() + data.value = np.int32(start_value + i[0]) + data_sink.write(data, i[0], i[1]) + + (nb_bytes, duration) = data_sink.statistics() + self.assertTrue(nb_bytes > 0) + self.assertTrue(duration > 0) + + data_sink.close() + del data_sink + + + def process(self, algorithm_name, proxy_mode=False): + self.writeData('in', [(0, 0), (1, 1), (2, 2), (3, 3)], 1000) + + config = deepcopy(CONFIGURATION) + config['algorithm'] = algorithm_name + config = convert_experiment_configuration_to_container(config, proxy_mode) + + with open(os.path.join(self.working_dir, 'configuration.json'), 'wb') as f: + simplejson.dump(config, f, indent=4) + + working_prefix = os.path.join(self.working_dir, 'prefix') + if not os.path.exists(working_prefix): + os.makedirs(working_prefix) + + algorithm = Algorithm(prefix, algorithm_name) + algorithm.export(working_prefix) + + if proxy_mode: + cache_access = AccessMode.LOCAL + + (input_list, _, data_sources) = create_inputs_from_configuration( + config, algorithm, prefix, self.cache_root, + cache_access=cache_access, + no_synchronisation_listeners=True + ) + + (output_list, data_sinks) = create_outputs_from_configuration( + config, algorithm, prefix, self.cache_root, + input_list, cache_access=cache_access + ) + + self.message_handler = MessageHandler('127.0.0.1', inputs=input_list, outputs=output_list) + else: + self.message_handler = 
MessageHandler('127.0.0.1') + + self.message_handler.start() + + self.zmq_context = zmq.Context() + self.executor_socket = self.zmq_context.socket(zmq.PAIR) + self.executor_socket.connect(self.message_handler.address) + + executor = Executor(self.executor_socket, self.working_dir, cache_root=self.cache_root) + + self.assertTrue(executor.setup()) + self.assertTrue(executor.prepare()) + self.assertTrue(executor.process()) + + if proxy_mode: + for output in output_list: + output.close() + + cached_file = CachedFileLoader() + self.assertTrue(cached_file.setup(os.path.join(self.cache_root, CONFIGURATION['outputs']['out']['path'] + '.data'), prefix)) + + for i in range(len(cached_file)): + data, start, end = cached_file[i] + self.assertEqual(data.value, 1000 + i) + self.assertEqual(start, i) + self.assertEqual(end, i) + + + def test_legacy_echo_1_local(self): + self.process('legacy/echo/1') + + + def test_legacy_echo_1_remote(self): + self.process('legacy/echo/1', True) + + + def test_sequential_echo_1(self): + self.process('sequential/echo/1') + + + def test_autonomous_echo_1(self): + self.process('autonomous/echo/1') -- GitLab
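
For reference, a minimal sketch of the user-facing entry points that the
refactored 'Executor' dispatches to for each algorithm type. Only the keyword
arguments ('inputs', 'data_loaders', 'outputs', and a single 'output' for
analyzers) are taken from the Executor.prepare() and Executor.process() code
above; the echo-style bodies and the data-loader accessors (loaderOf(),
count(), indexing) are assumptions about the beat.backend.python API. In BEAT
each of these classes would be the 'Algorithm' class of its own algorithm
file; they are renamed here only so the sketch forms a single valid module.

# Legacy: driven from outside, one synchronized block at a time, and sees
# only 'inputs' and 'outputs'; no prepare() step is involved.
class LegacyEcho:

    def process(self, inputs, outputs):
        outputs['out'].write(inputs['in'].data)
        return True


# Sequential: same block-by-block loop on the synchronized channel, but the
# unsynchronized channels are exposed through data loaders, and an optional
# prepare() runs once before processing starts.
class SequentialEcho:

    def prepare(self, data_loaders):
        return True

    def process(self, inputs, data_loaders, outputs):
        outputs['out'].write(inputs['in'].data)
        return True


# Autonomous: process() is called a single time and drives the iteration
# itself, reading every input through the data loaders.
class AutonomousEcho:

    def prepare(self, data_loaders):
        return True

    def process(self, data_loaders, outputs):
        loader = data_loaders.loaderOf('in')   # assumed accessor
        for i in range(loader.count()):        # assumed accessor
            data, start, end = loader[i]
            outputs['out'].write(data['in'], end)
        return True

The 'sequential/echo/1' and 'autonomous/echo/1' algorithms exercised by the
new test file presumably follow these shapes; analyzers receive the same
arguments with 'outputs' replaced by a single 'output'.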