Commit a9164acc authored by Philip ABBET's avatar Philip ABBET

Refactoring: the 'Executor' class now supports sequential and autonomous algorithms

parent 078cf0fd
......@@ -37,13 +37,11 @@ import time
import zmq
import simplejson
from . import algorithm
from . import inputs
from . import outputs
from . import stats
from .algorithm import Algorithm
from .helpers import create_inputs_from_configuration
from .helpers import create_outputs_from_configuration
from .helpers import AccessMode
from . import stats
class Executor(object):
......@@ -90,24 +88,25 @@ class Executor(object):
self.prefix = os.path.join(directory, 'prefix')
self.runner = None
# temporary caches, if the user has not set them, for performance
# Temporary caches, if the user has not set them, for performance
database_cache = database_cache if database_cache is not None else {}
dataformat_cache = dataformat_cache if dataformat_cache is not None else {}
library_cache = library_cache if library_cache is not None else {}
self.algorithm = algorithm.Algorithm(self.prefix, self.data['algorithm'],
# Load the algorithm
self.algorithm = Algorithm(self.prefix, self.data['algorithm'],
dataformat_cache, library_cache)
# Use algorithm names for inputs and outputs
main_channel = self.data['channel']
if self.algorithm.type == Algorithm.LEGACY:
# Loads algorithm inputs
if self.data['proxy_mode']:
cache_access = AccessMode.REMOTE
else:
cache_access = AccessMode.LOCAL
(self.input_list, _) = create_inputs_from_configuration(
(self.input_list, self.data_loaders, _) = create_inputs_from_configuration(
self.data, self.algorithm, self.prefix, cache_root,
cache_access=cache_access, db_access=AccessMode.REMOTE,
socket=self.socket
......@@ -119,6 +118,18 @@ class Executor(object):
cache_access=cache_access, socket=self.socket
)
else:
(self.input_list, self.data_loaders, _) = create_inputs_from_configuration(
self.data, self.algorithm, self.prefix, cache_root,
cache_access=AccessMode.LOCAL, db_access=AccessMode.REMOTE
)
# Loads algorithm outputs
(self.output_list, _) = create_outputs_from_configuration(
self.data, self.algorithm, self.prefix, cache_root, self.input_list,
cache_access=AccessMode.LOCAL
)
def setup(self):
"""Sets up the algorithm to start processing"""
......@@ -129,24 +140,53 @@ class Executor(object):
return retval
def prepare(self):
"""Prepare the algorithm"""
self.runner = self.algorithm.runner()
retval = self.runner.prepare(self.data_loaders)
logger.debug("User algorithm is prepared")
return retval
def process(self):
"""Executes the user algorithm code using the current interpreter.
"""
if not self.input_list or not self.output_list:
raise RuntimeError("I/O for execution block has not yet been set up")
if self.algorithm.type == Algorithm.AUTONOMOUS:
if self.analysis:
result = self.runner.process(data_loaders=self.data_loaders,
output=self.output_list[0])
else:
result = self.runner.process(data_loaders=self.data_loaders,
outputs=self.output_list)
if not result:
return False
else:
while self.input_list.hasMoreData():
main_group = self.input_list.main_group
main_group.restricted_access = False
main_group.next()
main_group.restricted_access = True
if self.algorithm.type == Algorithm.LEGACY:
if self.analysis:
result = self.runner.process(inputs=self.input_list, output=self.output_list[0])
else:
result = self.runner.process(inputs=self.input_list, outputs=self.output_list)
elif self.algorithm.type == Algorithm.SEQUENTIAL:
if self.analysis:
result = self.runner.process(inputs=self.input_list,
data_loaders=self.data_loaders,
output=self.output_list[0])
else:
result = self.runner.process(inputs=self.input_list,
data_loaders=self.data_loaders,
outputs=self.output_list)
if not result:
return False
......
This diff is collapsed.
#!/usr/bin/env python
# vim: set fileencoding=utf-8 :
###############################################################################
# #
# Copyright (c) 2017 Idiap Research Institute, http://www.idiap.ch/ #
# Contact: beat.support@idiap.ch #
# #
# This file is part of the beat.backend.python module of the BEAT platform. #
# #
# Commercial License Usage #
# Licensees holding valid commercial BEAT licenses may use this file in #
# accordance with the terms contained in a written agreement between you #
# and Idiap. For further information contact tto@idiap.ch #
# #
# Alternatively, this file may be used under the terms of the GNU Affero #
# Public License version 3 as published by the Free Software and appearing #
# in the file LICENSE.AGPL included in the packaging of this file. #
# The BEAT platform is distributed in the hope that it will be useful, but #
# WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY #
# or FITNESS FOR A PARTICULAR PURPOSE. #
# #
# You should have received a copy of the GNU Affero Public License along #
# with the BEAT platform. If not, see http://www.gnu.org/licenses/. #
# #
###############################################################################
import unittest
import tempfile
import simplejson
import os
import zmq
import shutil
import numpy as np
from copy import deepcopy
from ..executor import Executor
from ..message_handler import MessageHandler
from ..inputs import InputList
from ..algorithm import Algorithm
from ..dataformat import DataFormat
from ..data import CachedDataSink
from ..data import CachedFileLoader
from ..helpers import convert_experiment_configuration_to_container
from ..helpers import create_inputs_from_configuration
from ..helpers import create_outputs_from_configuration
from ..helpers import AccessMode
from . import prefix
CONFIGURATION = {
'algorithm': '',
'channel': 'main',
'parameters': {
},
'inputs': {
'in': {
'path': 'INPUT',
'channel': 'main',
}
},
'outputs': {
'out': {
'path': 'OUTPUT',
'channel': 'main'
}
},
}
#----------------------------------------------------------
class TestExecutor(unittest.TestCase):
def setUp(self):
self.cache_root = tempfile.mkdtemp(prefix=__name__)
self.working_dir = tempfile.mkdtemp(prefix=__name__)
self.message_handler = None
self.executor_socket = None
self.zmq_context = None
def tearDown(self):
shutil.rmtree(self.cache_root)
shutil.rmtree(self.working_dir)
if self.message_handler is not None:
self.message_handler.kill()
self.message_handler.join()
self.message_handler.destroy()
self.message_handler = None
if self.executor_socket is not None:
self.executor_socket.setsockopt(zmq.LINGER, 0)
self.executor_socket.close()
self.zmq_context.destroy()
self.executor_socket = None
self.zmq_context = None
def writeData(self, input_name, indices, start_value):
filename = os.path.join(self.cache_root, CONFIGURATION['inputs'][input_name]['path'] + '.data')
dataformat = DataFormat(prefix, 'user/single_integer/1')
self.assertTrue(dataformat.valid)
data_sink = CachedDataSink()
self.assertTrue(data_sink.setup(filename, dataformat, indices[0][0], indices[-1][1]))
for i in indices:
data = dataformat.type()
data.value = np.int32(start_value + i[0])
data_sink.write(data, i[0], i[1])
(nb_bytes, duration) = data_sink.statistics()
self.assertTrue(nb_bytes > 0)
self.assertTrue(duration > 0)
data_sink.close()
del data_sink
def process(self, algorithm_name, proxy_mode=False):
self.writeData('in', [(0, 0), (1, 1), (2, 2), (3, 3)], 1000)
config = deepcopy(CONFIGURATION)
config['algorithm'] = algorithm_name
config = convert_experiment_configuration_to_container(config, proxy_mode)
with open(os.path.join(self.working_dir, 'configuration.json'), 'wb') as f:
simplejson.dump(config, f, indent=4)
working_prefix = os.path.join(self.working_dir, 'prefix')
if not os.path.exists(working_prefix):
os.makedirs(working_prefix)
algorithm = Algorithm(prefix, algorithm_name)
algorithm.export(working_prefix)
if proxy_mode:
cache_access = AccessMode.LOCAL
(input_list, _, data_sources) = create_inputs_from_configuration(
config, algorithm, prefix, self.cache_root,
cache_access=cache_access,
no_synchronisation_listeners=True
)
(output_list, data_sinks) = create_outputs_from_configuration(
config, algorithm, prefix, self.cache_root,
input_list, cache_access=cache_access
)
self.message_handler = MessageHandler('127.0.0.1', inputs=input_list, outputs=output_list)
else:
self.message_handler = MessageHandler('127.0.0.1')
self.message_handler.start()
self.zmq_context = zmq.Context()
self.executor_socket = self.zmq_context.socket(zmq.PAIR)
self.executor_socket.connect(self.message_handler.address)
executor = Executor(self.executor_socket, self.working_dir, cache_root=self.cache_root)
self.assertTrue(executor.setup())
self.assertTrue(executor.prepare())
self.assertTrue(executor.process())
if proxy_mode:
for output in output_list:
output.close()
cached_file = CachedFileLoader()
self.assertTrue(cached_file.setup(os.path.join(self.cache_root, CONFIGURATION['outputs']['out']['path'] + '.data'), prefix))
for i in range(len(cached_file)):
data, start, end = cached_file[i]
self.assertEqual(data.value, 1000 + i)
self.assertEqual(start, i)
self.assertEqual(end, i)
def test_legacy_echo_1_local(self):
self.process('legacy/echo/1')
def test_legacy_echo_1_remote(self):
self.process('legacy/echo/1', True)
def test_sequential_echo_1(self):
self.process('sequential/echo/1')
def test_autonomous_echo_1(self):
self.process('autonomous/echo/1')
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment