From a9164accbdbcc0a0792ee86459677e35a2c11953 Mon Sep 17 00:00:00 2001
From: Philip ABBET <philip.abbet@idiap.ch>
Date: Wed, 6 Dec 2017 17:22:16 +0100
Subject: [PATCH] Refactoring: the 'Executor' class now supports sequential and
 autonomous algorithms

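The 'Algorithm' class now reports its type, and the 'Executor' adapts
its behavior accordingly:

  - LEGACY: the historical behavior; process() is called once per unit
    of the synchronized channel and receives all the data as inputs
  - SEQUENTIAL: like LEGACY, but the unsynchronized channels are
    accessed through data loaders instead of inputs
  - AUTONOMOUS: process() is called only once, and all the data is
    accessed through data loaders

A new prepare() step runs before any processing to give the algorithm
access to its data loaders, and create_inputs_from_configuration() now
additionally returns a DataLoaderList.

For illustration, here is a rough sketch of a sequential user
algorithm. The hook signatures follow the calls made by the executor
in this patch; the dictionary-style access to inputs and outputs and
the 'in'/'out' names are assumptions made for the example:

    class Algorithm(object):

        def prepare(self, data_loaders):
            # Called once, before any unit is processed
            return True

        def process(self, inputs, data_loaders, outputs):
            # Called once per unit of the synchronized channel
            outputs['out'].write({'value': inputs['in'].data.value})
            return True

An autonomous algorithm would instead implement
process(self, data_loaders, outputs) and loop over the data itself.
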
---
 beat/backend/python/executor.py           | 108 ++++++++----
 beat/backend/python/helpers.py            | 182 +++++++++++++-------
 beat/backend/python/test/test_executor.py | 200 ++++++++++++++++++++++
 3 files changed, 398 insertions(+), 92 deletions(-)
 create mode 100644 beat/backend/python/test/test_executor.py

diff --git a/beat/backend/python/executor.py b/beat/backend/python/executor.py
index df0b929..7d7e865 100755
--- a/beat/backend/python/executor.py
+++ b/beat/backend/python/executor.py
@@ -37,13 +37,11 @@ import time
 import zmq
 import simplejson
 
-from . import algorithm
-from . import inputs
-from . import outputs
-from . import stats
+from .algorithm import Algorithm
 from .helpers import create_inputs_from_configuration
 from .helpers import create_outputs_from_configuration
 from .helpers import AccessMode
+from . import stats
 
 
 class Executor(object):
@@ -90,34 +88,47 @@ class Executor(object):
         self.prefix = os.path.join(directory, 'prefix')
         self.runner = None
 
-        # temporary caches, if the user has not set them, for performance
+        # Temporary caches, if the user has not set them, for performance
         database_cache = database_cache if database_cache is not None else {}
         dataformat_cache = dataformat_cache if dataformat_cache is not None else {}
         library_cache = library_cache if library_cache is not None else {}
 
-        self.algorithm = algorithm.Algorithm(self.prefix, self.data['algorithm'],
-                dataformat_cache, library_cache)
+        # Load the algorithm
+        self.algorithm = Algorithm(self.prefix, self.data['algorithm'],
+                                   dataformat_cache, library_cache)
 
-        # Use algorithm names for inputs and outputs
         main_channel = self.data['channel']
 
-        # Loads algorithm inputs
-        if self.data['proxy_mode']:
-            cache_access = AccessMode.REMOTE
-        else:
-            cache_access = AccessMode.LOCAL
+        if self.algorithm.type == Algorithm.LEGACY:
+            # Loads algorithm inputs
+            if self.data['proxy_mode']:
+                cache_access = AccessMode.REMOTE
+            else:
+                cache_access = AccessMode.LOCAL
 
-        (self.input_list, _) = create_inputs_from_configuration(
-            self.data, self.algorithm, self.prefix, cache_root,
-            cache_access=cache_access, db_access=AccessMode.REMOTE,
-            socket=self.socket
-        )
+            (self.input_list, self.data_loaders, _) = create_inputs_from_configuration(
+                self.data, self.algorithm, self.prefix, cache_root,
+                cache_access=cache_access, db_access=AccessMode.REMOTE,
+                socket=self.socket
+            )
 
-        # Loads algorithm outputs
-        (self.output_list, _) = create_outputs_from_configuration(
-            self.data, self.algorithm, self.prefix, cache_root, self.input_list,
-            cache_access=cache_access, socket=self.socket
-        )
+            # Loads algorithm outputs
+            (self.output_list, _) = create_outputs_from_configuration(
+                self.data, self.algorithm, self.prefix, cache_root, self.input_list,
+                cache_access=cache_access, socket=self.socket
+            )
+
+        else:
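+            # Sequential and autonomous algorithms always access the cache
+            # locally (proxy mode only applies to legacy algorithms)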
+            (self.input_list, self.data_loaders, _) = create_inputs_from_configuration(
+                self.data, self.algorithm, self.prefix, cache_root,
+                cache_access=AccessMode.LOCAL, db_access=AccessMode.REMOTE
+            )
+
+            # Loads algorithm outputs
+            (self.output_list, _) = create_outputs_from_configuration(
+                self.data, self.algorithm, self.prefix, cache_root, self.input_list,
+                cache_access=AccessMode.LOCAL
+            )
 
 
     def setup(self):
@@ -129,27 +140,56 @@ class Executor(object):
         return retval
 
 
+    def prepare(self):
+        """Prepare the algorithm"""
+
+        self.runner = self.algorithm.runner()
+        retval = self.runner.prepare(self.data_loaders)
+        logger.debug("User algorithm is prepared")
+        return retval
+
+
     def process(self):
         """Executes the user algorithm code using the current interpreter.
         """
 
-        if not self.input_list or not self.output_list:
-            raise RuntimeError("I/O for execution block has not yet been set up")
-
-        while self.input_list.hasMoreData():
-            main_group = self.input_list.main_group
-            main_group.restricted_access = False
-            main_group.next()
-            main_group.restricted_access = True
-
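+        # An autonomous algorithm is called only once and retrieves the data
+        # by itself, using its data loaders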
+        if self.algorithm.type == Algorithm.AUTONOMOUS:
             if self.analysis:
-                result = self.runner.process(inputs=self.input_list, output=self.output_list[0])
+                result = self.runner.process(data_loaders=self.data_loaders,
+                                             output=self.output_list[0])
             else:
-                result = self.runner.process(inputs=self.input_list, outputs=self.output_list)
+                result = self.runner.process(data_loaders=self.data_loaders,
+                                             outputs=self.output_list)
 
             if not result:
                 return False
 
+        else:
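+            # Legacy and sequential algorithms are driven by the synchronized
+            # channel: process() is called once per unit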
+            while self.input_list.hasMoreData():
+                main_group = self.input_list.main_group
+                main_group.restricted_access = False
+                main_group.next()
+                main_group.restricted_access = True
+
+                if self.algorithm.type == Algorithm.LEGACY:
+                    if self.analysis:
+                        result = self.runner.process(inputs=self.input_list, output=self.output_list[0])
+                    else:
+                        result = self.runner.process(inputs=self.input_list, outputs=self.output_list)
+
+                elif self.algorithm.type == Algorithm.SEQUENTIAL:
+                    if self.analysis:
+                        result = self.runner.process(inputs=self.input_list,
+                                                     data_loaders=self.data_loaders,
+                                                     output=self.output_list[0])
+                    else:
+                        result = self.runner.process(inputs=self.input_list,
+                                                     data_loaders=self.data_loaders,
+                                                     outputs=self.output_list)
+
+                if not result:
+                    return False
+
         for output in self.output_list:
             output.close()
 
diff --git a/beat/backend/python/helpers.py b/beat/backend/python/helpers.py
index 9ac010d..e82bf19 100755
--- a/beat/backend/python/helpers.py
+++ b/beat/backend/python/helpers.py
@@ -32,11 +32,26 @@ import errno
 import logging
 logger = logging.getLogger(__name__)
 
-from . import data
-from . import inputs
-from . import outputs
+from .data import MemoryDataSource
+from .data import CachedDataSource
+from .data import CachedFileLoader
+from .data import CachedDataSink
+from .data import getAllFilenames
+from .data_loaders import DataLoaderList
+from .data_loaders import DataLoader
+from .inputs import InputList
+from .inputs import Input
+from .inputs import RemoteInput
+from .inputs import InputGroup
+from .outputs import SynchronizationListener
+from .outputs import OutputList
+from .outputs import Output
+from .outputs import RemoteOutput
+from .algorithm import Algorithm
 
 
+#----------------------------------------------------------
+
 
 def convert_experiment_configuration_to_container(config, proxy_mode):
     data = {
@@ -80,13 +95,77 @@ def create_inputs_from_configuration(config, algorithm, prefix, cache_root,
 
     data_sources = []
     views = {}
-    input_list = inputs.InputList()
+    input_list = InputList()
+    data_loader_list = DataLoaderList()
 
     # This is used for parallelization purposes
     start_index, end_index = config.get('range', (None, None))
 
+
+    def _create_local_input(details):
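+        # Creates an input backed by a cached file on disk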
+        data_source = CachedDataSource()
+        data_sources.append(data_source)
+
+        filename = os.path.join(cache_root, details['path'] + '.data')
+
+        if details['channel'] == config['channel']: # synchronized
+            status = data_source.setup(
+                      filename=filename,
+                      prefix=prefix,
+                      force_start_index=start_index,
+                      force_end_index=end_index,
+                      unpack=True,
+                     )
+        else:
+            status = data_source.setup(
+                      filename=filename,
+                      prefix=prefix,
+                      unpack=True,
+                     )
+
+        if not status:
+            raise IOError("cannot load cache file `%s'" % details['path'])
+
+        input = Input(name, algorithm.input_map[name], data_source)
+
+        logger.debug("Input '%s' created: group='%s', dataformat='%s', filename='%s'" % \
+                        (name, details['channel'], algorithm.input_map[name], filename))
+
+        return input
+
+
+    def _create_data_loader(details):
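+        # Data loaders are shared by all the inputs of a given channel
+        # (group), each input being loaded from its own cached file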
+        filename = os.path.join(cache_root, details['path'] + '.data')
+
+        data_loader = data_loader_list[details['channel']]
+        if data_loader is None:
+            data_loader = DataLoader(details['channel'])
+            data_loader_list.add(data_loader)
+
+            logger.debug("Data loader created: group='%s'" % details['channel'])
+
+        cached_file = CachedFileLoader()
+        result = cached_file.setup(
+            filename=filename,
+            prefix=prefix,
+            start_index=start_index,
+            end_index=end_index,
+            unpack=True,
+        )
+
+        if not result:
+            raise IOError("cannot load cache file `%s'" % details['path'])
+
+        data_loader.add(name, cached_file)
+
+        logger.debug("Input '%s' added to data loader: group='%s', dataformat='%s', filename='%s'" % \
+                        (name, details['channel'], algorithm.input_map[name], filename))
+
+
     for name, details in config['inputs'].items():
 
+        input = None
+
         if details.get('database', False):
             if db_access == AccessMode.LOCAL:
                 if databases is None:
@@ -114,12 +193,12 @@ def create_inputs_from_configuration(config, algorithm, prefix, cache_root,
                     view = views[channel]
 
                 # Creation of the input
-                data_source = data.MemoryDataSource(view.done, next_callback=view.next)
+                data_source = MemoryDataSource(view.done, next_callback=view.next)
 
                 output = view.outputs[details['output']]
                 output.data_sink.data_sources.append(data_source)
 
-                input = inputs.Input(name, algorithm.input_map[name], data_source)
+                input = Input(name, algorithm.input_map[name], data_source)
 
                 logger.debug("Input '%s' created: group='%s', dataformat='%s', database-output='%s/%s/%s:%s'" % \
                                 (name, channel, algorithm.input_map[name], details['database'],
@@ -129,47 +208,34 @@ def create_inputs_from_configuration(config, algorithm, prefix, cache_root,
                 if socket is None:
                     raise IOError("No socket provided for remote inputs")
 
-                input = inputs.RemoteInput(name, algorithm.dataformats[algorithm.input_map[name]],
-                                           socket, unpack=unpack)
+                input = RemoteInput(name, algorithm.dataformats[algorithm.input_map[name]],
+                                    socket, unpack=unpack)
 
                 logger.debug("RemoteInput '%s' created: group='%s', dataformat='%s', connected to a database" % \
                                 (name, details['channel'], algorithm.input_map[name]))
 
+
         elif cache_access == AccessMode.LOCAL:
-            data_source = data.CachedDataSource()
-            data_sources.append(data_source)
-
-            filename = os.path.join(cache_root, details['path'] + '.data')
-
-            if details['channel'] == config['channel']: # synchronized
-                status = data_source.setup(
-                          filename=filename,
-                          prefix=prefix,
-                          force_start_index=start_index,
-                          force_end_index=end_index,
-                          unpack=True,
-                         )
-            else:
-                status = data_source.setup(
-                          filename=filename,
-                          prefix=prefix,
-                          unpack=True,
-                         )
 
-            if not status:
-                raise IOError("cannot load cache file `%s'" % details['path'])
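+            # A legacy algorithm receives all its data as inputs; a sequential
+            # one as inputs (synchronized channel) and data loaders (the other
+            # channels); an autonomous one only uses data loaders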
+            if algorithm.type == Algorithm.LEGACY:
+                input = _create_local_input(details)
 
-            input = inputs.Input(name, algorithm.input_map[name], data_source)
+            elif algorithm.type == Algorithm.SEQUENTIAL:
+                if details['channel'] == config['channel']: # synchronized
+                    input = _create_local_input(details)
+                else:
+                    _create_data_loader(details)
+
+            elif algorithm.type == Algorithm.AUTONOMOUS:
+                _create_data_loader(details)
 
-            logger.debug("Input '%s' created: group='%s', dataformat='%s', filename='%s'" % \
-                            (name, details['channel'], algorithm.input_map[name], filename))
 
         elif cache_access == AccessMode.REMOTE:
             if socket is None:
                 raise IOError("No socket provided for remote inputs")
 
-            input = inputs.RemoteInput(name, algorithm.dataformats[algorithm.input_map[name]],
-                                       socket, unpack=unpack)
+            input = RemoteInput(name, algorithm.dataformats[algorithm.input_map[name]],
+                                socket, unpack=unpack)
 
             logger.debug("RemoteInput '%s' created: group='%s', dataformat='%s'" % \
                             (name, details['channel'], algorithm.input_map[name]))
@@ -178,24 +244,24 @@ def create_inputs_from_configuration(config, algorithm, prefix, cache_root,
             continue
 
         # Synchronization bits
-        group = input_list.group(details['channel'])
-        if group is None:
-            synchronization_listener = None
-            if not no_synchronisation_listeners:
-                synchronization_listener = outputs.SynchronizationListener()
-
-            group = inputs.InputGroup(
-                      details['channel'],
-                      synchronization_listener=synchronization_listener,
-                      restricted_access=(details['channel'] == config['channel'])
-                    )
-            input_list.add(group)
-            logger.debug("Group '%s' created" % details['channel'])
+        if input is not None:
+            group = input_list.group(details['channel'])
+            if group is None:
+                synchronization_listener = None
+                if not no_synchronisation_listeners:
+                    synchronization_listener = SynchronizationListener()
 
-        group.add(input)
+                group = InputGroup(
+                          details['channel'],
+                          synchronization_listener=synchronization_listener,
+                          restricted_access=(details['channel'] == config['channel'])
+                        )
+                input_list.add(group)
+                logger.debug("Group '%s' created" % details['channel'])
 
-    return (input_list, data_sources)
+            group.add(input)
 
+    return (input_list, data_loader_list, data_sources)
 
 
 #----------------------------------------------------------
@@ -205,7 +271,7 @@ def create_outputs_from_configuration(config, algorithm, prefix, cache_root, inp
                                       cache_access=AccessMode.NONE, socket=None):
 
     data_sinks = []
-    output_list = outputs.OutputList()
+    output_list = OutputList()
 
     # This is used for parallelization purposes
     start_index, end_index = config.get('range', (None, None))
@@ -254,7 +320,7 @@ def create_outputs_from_configuration(config, algorithm, prefix, cache_root, inp
                         break
 
                 (data_filenames, indices_filenames, data_checksum_filenames, indices_checksum_filenames) = \
-                        data.getAllFilenames(input_path)
+                        getAllFilenames(input_path)
 
                 end_indices = [ int(x.split('.')[-2]) for x in indices_filenames ]
                 end_indices.sort()
@@ -262,7 +328,7 @@ def create_outputs_from_configuration(config, algorithm, prefix, cache_root, inp
                 start_index = 0
                 end_index = end_indices[-1]
 
-            data_sink = data.CachedDataSink()
+            data_sink = CachedDataSink()
             data_sinks.append(data_sink)
 
             status = data_sink.setup(
@@ -276,9 +342,9 @@ def create_outputs_from_configuration(config, algorithm, prefix, cache_root, inp
             if not status:
                 raise IOError("Cannot create cache sink '%s'" % details['path'])
 
-            output_list.add(outputs.Output(name, data_sink,
-                synchronization_listener=synchronization_listener,
-                force_start_index=start_index)
+            output_list.add(Output(name, data_sink,
+                                   synchronization_listener=synchronization_listener,
+                                   force_start_index=start_index)
             )
 
             if 'result' not in config:
@@ -292,9 +358,9 @@ def create_outputs_from_configuration(config, algorithm, prefix, cache_root, inp
             if socket is None:
                 raise IOError("No socket provided for remote outputs")
 
-            output_list.add(outputs.RemoteOutput(name, dataformat, socket,
-                synchronization_listener=synchronization_listener,
-                force_start_index=start_index or 0)
+            output_list.add(RemoteOutput(name, dataformat, socket,
+                                         synchronization_listener=synchronization_listener,
+                                         force_start_index=start_index or 0)
             )
 
             logger.debug("RemoteOutput '%s' created: group='%s', dataformat='%s'" % \
diff --git a/beat/backend/python/test/test_executor.py b/beat/backend/python/test/test_executor.py
new file mode 100644
index 0000000..a155c4a
--- /dev/null
+++ b/beat/backend/python/test/test_executor.py
@@ -0,0 +1,200 @@
+#!/usr/bin/env python
+# vim: set fileencoding=utf-8 :
+
+###############################################################################
+#                                                                             #
+# Copyright (c) 2017 Idiap Research Institute, http://www.idiap.ch/           #
+# Contact: beat.support@idiap.ch                                              #
+#                                                                             #
+# This file is part of the beat.backend.python module of the BEAT platform.   #
+#                                                                             #
+# Commercial License Usage                                                    #
+# Licensees holding valid commercial BEAT licenses may use this file in       #
+# accordance with the terms contained in a written agreement between you      #
+# and Idiap. For further information contact tto@idiap.ch                     #
+#                                                                             #
+# Alternatively, this file may be used under the terms of the GNU Affero      #
+# Public License version 3 as published by the Free Software and appearing    #
+# in the file LICENSE.AGPL included in the packaging of this file.            #
+# The BEAT platform is distributed in the hope that it will be useful, but    #
+# WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY  #
+# or FITNESS FOR A PARTICULAR PURPOSE.                                        #
+#                                                                             #
+# You should have received a copy of the GNU Affero Public License along      #
+# with the BEAT platform. If not, see http://www.gnu.org/licenses/.           #
+#                                                                             #
+###############################################################################
+
+
+import unittest
+import tempfile
+import simplejson
+import os
+import zmq
+import shutil
+import numpy as np
+from copy import deepcopy
+
+from ..executor import Executor
+from ..message_handler import MessageHandler
+from ..inputs import InputList
+from ..algorithm import Algorithm
+from ..dataformat import DataFormat
+from ..data import CachedDataSink
+from ..data import CachedFileLoader
+from ..helpers import convert_experiment_configuration_to_container
+from ..helpers import create_inputs_from_configuration
+from ..helpers import create_outputs_from_configuration
+from ..helpers import AccessMode
+
+from . import prefix
+
+
+CONFIGURATION = {
+    'algorithm': '',
+    'channel': 'main',
+    'parameters': {
+    },
+    'inputs': {
+        'in': {
+            'path': 'INPUT',
+            'channel': 'main',
+        }
+    },
+    'outputs': {
+        'out': {
+            'path': 'OUTPUT',
+            'channel': 'main'
+        }
+    },
+}
+
+
+#----------------------------------------------------------
+
+
+class TestExecutor(unittest.TestCase):
+
+    def setUp(self):
+        self.cache_root = tempfile.mkdtemp(prefix=__name__)
+        self.working_dir = tempfile.mkdtemp(prefix=__name__)
+        self.message_handler = None
+        self.executor_socket = None
+        self.zmq_context = None
+
+
+    def tearDown(self):
+        shutil.rmtree(self.cache_root)
+        shutil.rmtree(self.working_dir)
+
+        if self.message_handler is not None:
+            self.message_handler.kill()
+            self.message_handler.join()
+            self.message_handler.destroy()
+            self.message_handler = None
+
+        if self.executor_socket is not None:
+            self.executor_socket.setsockopt(zmq.LINGER, 0)
+            self.executor_socket.close()
+            self.zmq_context.destroy()
+            self.executor_socket = None
+            self.zmq_context = None
+
+
+    def writeData(self, input_name, indices, start_value):
+        filename = os.path.join(self.cache_root, CONFIGURATION['inputs'][input_name]['path'] + '.data')
+
+        dataformat = DataFormat(prefix, 'user/single_integer/1')
+        self.assertTrue(dataformat.valid)
+
+        data_sink = CachedDataSink()
+        self.assertTrue(data_sink.setup(filename, dataformat, indices[0][0], indices[-1][1]))
+
+        for i in indices:
+            data = dataformat.type()
+            data.value = np.int32(start_value + i[0])
+            data_sink.write(data, i[0], i[1])
+
+        (nb_bytes, duration) = data_sink.statistics()
+        self.assertTrue(nb_bytes > 0)
+        self.assertTrue(duration > 0)
+
+        data_sink.close()
+        del data_sink
+
+
+    def process(self, algorithm_name, proxy_mode=False):
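+        # Writes some data on the input, runs the full executor life-cycle
+        # (setup, prepare, process) and checks that the output contains the
+        # same values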
+        self.writeData('in', [(0, 0), (1, 1), (2, 2), (3, 3)], 1000)
+
+        config = deepcopy(CONFIGURATION)
+        config['algorithm'] = algorithm_name
+        config = convert_experiment_configuration_to_container(config, proxy_mode)
+
+        with open(os.path.join(self.working_dir, 'configuration.json'), 'w') as f:
+            simplejson.dump(config, f, indent=4)
+
+        working_prefix = os.path.join(self.working_dir, 'prefix')
+        if not os.path.exists(working_prefix):
+            os.makedirs(working_prefix)
+
+        algorithm = Algorithm(prefix, algorithm_name)
+        algorithm.export(working_prefix)
+
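+        # In proxy mode, the inputs and outputs are instantiated on this side
+        # and the executor accesses them remotely, through the message handler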
+        if proxy_mode:
+            cache_access = AccessMode.LOCAL
+
+            (input_list, _, data_sources) = create_inputs_from_configuration(
+                                                config, algorithm, prefix, self.cache_root,
+                                                cache_access=cache_access,
+                                                no_synchronisation_listeners=True
+            )
+
+            (output_list, data_sinks) = create_outputs_from_configuration(
+                                                config, algorithm, prefix, self.cache_root,
+                                                input_list, cache_access=cache_access
+            )
+
+            self.message_handler = MessageHandler('127.0.0.1', inputs=input_list, outputs=output_list)
+        else:
+            self.message_handler = MessageHandler('127.0.0.1')
+
+        self.message_handler.start()
+
+        self.zmq_context = zmq.Context()
+        self.executor_socket = self.zmq_context.socket(zmq.PAIR)
+        self.executor_socket.connect(self.message_handler.address)
+
+        executor = Executor(self.executor_socket, self.working_dir, cache_root=self.cache_root)
+
+        self.assertTrue(executor.setup())
+        self.assertTrue(executor.prepare())
+        self.assertTrue(executor.process())
+
+        if proxy_mode:
+            for output in output_list:
+                output.close()
+
+        cached_file = CachedFileLoader()
+        self.assertTrue(cached_file.setup(
+            os.path.join(self.cache_root,
+                         CONFIGURATION['outputs']['out']['path'] + '.data'),
+            prefix
+        ))
+
+        for i in range(len(cached_file)):
+            data, start, end = cached_file[i]
+            self.assertEqual(data.value, 1000 + i)
+            self.assertEqual(start, i)
+            self.assertEqual(end, i)
+
+
+    def test_legacy_echo_1_local(self):
+        self.process('legacy/echo/1')
+
+
+    def test_legacy_echo_1_remote(self):
+        self.process('legacy/echo/1', True)
+
+
+    def test_sequential_echo_1(self):
+        self.process('sequential/echo/1')
+
+
+    def test_autonomous_echo_1(self):
+        self.process('autonomous/echo/1')
-- 
GitLab