Commit a9164acc authored by Philip ABBET

Refactoring: the 'Executor' class now supports sequential and autonomous algorithms

parent 078cf0fd
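For context, the refactored Executor.process() below dispatches on the algorithm type (Algorithm.LEGACY, Algorithm.SEQUENTIAL or Algorithm.AUTONOMOUS) and calls the user code with different signatures. The sketch below is illustrative only: the class names and method bodies are hypothetical placeholders, and only the non-analyzer process() signatures follow the diff.

# Hypothetical sketches of the three algorithm flavours the refactored
# Executor can drive; bodies are placeholders, signatures match the diff.

class LegacyEcho:                      # Algorithm.LEGACY
    def process(self, inputs, outputs):
        # one call per synchronized data unit
        outputs['out'].write(inputs['in'].data)
        return True

class SequentialEcho:                  # Algorithm.SEQUENTIAL
    def prepare(self, data_loaders):
        # optional pass over the non-synchronized channels before processing
        return True

    def process(self, inputs, data_loaders, outputs):
        # synchronized inputs plus data loaders for the other channels
        outputs['out'].write(inputs['in'].data)
        return True

class AutonomousEcho:                  # Algorithm.AUTONOMOUS
    def prepare(self, data_loaders):
        return True

    def process(self, data_loaders, outputs):
        # iterates over the cached data on its own, then writes its outputs
        return True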
@@ -37,13 +37,11 @@ import time
import zmq
import simplejson
from . import algorithm
from . import inputs
from . import outputs
from . import stats
from .algorithm import Algorithm
from .helpers import create_inputs_from_configuration
from .helpers import create_outputs_from_configuration
from .helpers import AccessMode
from . import stats
class Executor(object):
@@ -90,34 +88,47 @@ class Executor(object):
self.prefix = os.path.join(directory, 'prefix')
self.runner = None
# temporary caches, if the user has not set them, for performance
# Temporary caches, if the user has not set them, for performance
database_cache = database_cache if database_cache is not None else {}
dataformat_cache = dataformat_cache if dataformat_cache is not None else {}
library_cache = library_cache if library_cache is not None else {}
self.algorithm = algorithm.Algorithm(self.prefix, self.data['algorithm'],
dataformat_cache, library_cache)
# Load the algorithm
self.algorithm = Algorithm(self.prefix, self.data['algorithm'],
dataformat_cache, library_cache)
# Use algorithm names for inputs and outputs
main_channel = self.data['channel']
# Loads algorithm inputs
if self.data['proxy_mode']:
cache_access = AccessMode.REMOTE
else:
cache_access = AccessMode.LOCAL
if self.algorithm.type == Algorithm.LEGACY:
# Loads algorithm inputs
if self.data['proxy_mode']:
cache_access = AccessMode.REMOTE
else:
cache_access = AccessMode.LOCAL
(self.input_list, _) = create_inputs_from_configuration(
self.data, self.algorithm, self.prefix, cache_root,
cache_access=cache_access, db_access=AccessMode.REMOTE,
socket=self.socket
)
(self.input_list, self.data_loaders, _) = create_inputs_from_configuration(
self.data, self.algorithm, self.prefix, cache_root,
cache_access=cache_access, db_access=AccessMode.REMOTE,
socket=self.socket
)
# Loads algorithm outputs
(self.output_list, _) = create_outputs_from_configuration(
self.data, self.algorithm, self.prefix, cache_root, self.input_list,
cache_access=cache_access, socket=self.socket
)
# Loads algorithm outputs
(self.output_list, _) = create_outputs_from_configuration(
self.data, self.algorithm, self.prefix, cache_root, self.input_list,
cache_access=cache_access, socket=self.socket
)
else:
(self.input_list, self.data_loaders, _) = create_inputs_from_configuration(
self.data, self.algorithm, self.prefix, cache_root,
cache_access=AccessMode.LOCAL, db_access=AccessMode.REMOTE
)
# Loads algorithm outputs
(self.output_list, _) = create_outputs_from_configuration(
self.data, self.algorithm, self.prefix, cache_root, self.input_list,
cache_access=AccessMode.LOCAL
)
def setup(self):
@@ -129,27 +140,56 @@ class Executor(object):
return retval
def prepare(self):
"""Prepare the algorithm"""
self.runner = self.algorithm.runner()
retval = self.runner.prepare(self.data_loaders)
logger.debug("User algorithm is prepared")
return retval
def process(self):
"""Executes the user algorithm code using the current interpreter.
"""
if not self.input_list or not self.output_list:
raise RuntimeError("I/O for execution block has not yet been set up")
while self.input_list.hasMoreData():
main_group = self.input_list.main_group
main_group.restricted_access = False
main_group.next()
main_group.restricted_access = True
if self.algorithm.type == Algorithm.AUTONOMOUS:
if self.analysis:
result = self.runner.process(inputs=self.input_list, output=self.output_list[0])
result = self.runner.process(data_loaders=self.data_loaders,
output=self.output_list[0])
else:
result = self.runner.process(inputs=self.input_list, outputs=self.output_list)
result = self.runner.process(data_loaders=self.data_loaders,
outputs=self.output_list)
if not result:
return False
else:
while self.input_list.hasMoreData():
main_group = self.input_list.main_group
main_group.restricted_access = False
main_group.next()
main_group.restricted_access = True
if self.algorithm.type == Algorithm.LEGACY:
if self.analysis:
result = self.runner.process(inputs=self.input_list, output=self.output_list[0])
else:
result = self.runner.process(inputs=self.input_list, outputs=self.output_list)
elif self.algorithm.type == Algorithm.SEQUENTIAL:
if self.analysis:
result = self.runner.process(inputs=self.input_list,
data_loaders=self.data_loaders,
output=self.output_list[0])
else:
result = self.runner.process(inputs=self.input_list,
data_loaders=self.data_loaders,
outputs=self.output_list)
if not result:
return False
for output in self.output_list:
output.close()
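The executor life cycle gains a step: setup() is followed by the new prepare(), which forwards the data loaders to the user code, and only then by process(), which performs the dispatch shown above. A condensed driver sketch, modelled on the unit test added at the end of this commit; 'socket', 'working_dir' and 'cache_root' are assumed to have been set up beforehand:

# Hypothetical driver, following the test case further down.
executor = Executor(socket, working_dir, cache_root=cache_root)

assert executor.setup()
assert executor.prepare()   # new step: forwards the data loaders to the user code
assert executor.process()   # dispatches on the algorithm type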
@@ -32,11 +32,26 @@ import errno
import logging
logger = logging.getLogger(__name__)
from . import data
from . import inputs
from . import outputs
from .data import MemoryDataSource
from .data import CachedDataSource
from .data import CachedFileLoader
from .data import CachedDataSink
from .data import getAllFilenames
from .data_loaders import DataLoaderList
from .data_loaders import DataLoader
from .inputs import InputList
from .inputs import Input
from .inputs import RemoteInput
from .inputs import InputGroup
from .outputs import SynchronizationListener
from .outputs import OutputList
from .outputs import Output
from .outputs import RemoteOutput
from .algorithm import Algorithm
#----------------------------------------------------------
def convert_experiment_configuration_to_container(config, proxy_mode):
data = {
@@ -80,13 +95,77 @@ def create_inputs_from_configuration(config, algorithm, prefix, cache_root,
data_sources = []
views = {}
input_list = inputs.InputList()
input_list = InputList()
data_loader_list = DataLoaderList()
# This is used for parallelization purposes
start_index, end_index = config.get('range', (None, None))
def _create_local_input(details):
data_source = CachedDataSource()
data_sources.append(data_source)
filename = os.path.join(cache_root, details['path'] + '.data')
if details['channel'] == config['channel']: # synchronized
status = data_source.setup(
filename=filename,
prefix=prefix,
force_start_index=start_index,
force_end_index=end_index,
unpack=True,
)
else:
status = data_source.setup(
filename=filename,
prefix=prefix,
unpack=True,
)
if not status:
raise IOError("cannot load cache file `%s'" % details['path'])
input = Input(name, algorithm.input_map[name], data_source)
logger.debug("Input '%s' created: group='%s', dataformat='%s', filename='%s'" % \
(name, details['channel'], algorithm.input_map[name], filename))
return input
def _create_data_loader(details):
filename = os.path.join(cache_root, details['path'] + '.data')
data_loader = data_loader_list[details['channel']]
if data_loader is None:
data_loader = DataLoader(details['channel'])
data_loader_list.add(data_loader)
logger.debug("Data loader created: group='%s'" % details['channel'])
cached_file = CachedFileLoader()
result = cached_file.setup(
filename=filename,
prefix=prefix,
start_index=start_index,
end_index=end_index,
unpack=True,
)
if not result:
raise IOError("cannot load cache file `%s'" % details['path'])
data_loader.add(name, cached_file)
logger.debug("Input '%s' added to data loader: group='%s', dataformat='%s', filename='%s'" % \
(name, details['channel'], algorithm.input_map[name], filename))
for name, details in config['inputs'].items():
input = None
if details.get('database', False):
if db_access == AccessMode.LOCAL:
if databases is None:
@@ -114,12 +193,12 @@ def create_inputs_from_configuration(config, algorithm, prefix, cache_root,
view = views[channel]
# Creation of the input
data_source = data.MemoryDataSource(view.done, next_callback=view.next)
data_source = MemoryDataSource(view.done, next_callback=view.next)
output = view.outputs[details['output']]
output.data_sink.data_sources.append(data_source)
input = inputs.Input(name, algorithm.input_map[name], data_source)
input = Input(name, algorithm.input_map[name], data_source)
logger.debug("Input '%s' created: group='%s', dataformat='%s', database-output='%s/%s/%s:%s'" % \
(name, channel, algorithm.input_map[name], details['database'],
@@ -129,47 +208,34 @@ def create_inputs_from_configuration(config, algorithm, prefix, cache_root,
if socket is None:
raise IOError("No socket provided for remote inputs")
input = inputs.RemoteInput(name, algorithm.dataformats[algorithm.input_map[name]],
socket, unpack=unpack)
input = RemoteInput(name, algorithm.dataformats[algorithm.input_map[name]],
socket, unpack=unpack)
logger.debug("RemoteInput '%s' created: group='%s', dataformat='%s', connected to a database" % \
(name, details['channel'], algorithm.input_map[name]))
elif cache_access == AccessMode.LOCAL:
data_source = data.CachedDataSource()
data_sources.append(data_source)
filename = os.path.join(cache_root, details['path'] + '.data')
if details['channel'] == config['channel']: # synchronized
status = data_source.setup(
filename=filename,
prefix=prefix,
force_start_index=start_index,
force_end_index=end_index,
unpack=True,
)
else:
status = data_source.setup(
filename=filename,
prefix=prefix,
unpack=True,
)
if not status:
raise IOError("cannot load cache file `%s'" % details['path'])
if algorithm.type == Algorithm.LEGACY:
input = _create_local_input(details)
input = inputs.Input(name, algorithm.input_map[name], data_source)
elif algorithm.type == Algorithm.SEQUENTIAL:
if details['channel'] == config['channel']: # synchronized
input = _create_local_input(details)
else:
_create_data_loader(details)
elif algorithm.type == Algorithm.AUTONOMOUS:
_create_data_loader(details)
logger.debug("Input '%s' created: group='%s', dataformat='%s', filename='%s'" % \
(name, details['channel'], algorithm.input_map[name], filename))
elif cache_access == AccessMode.REMOTE:
if socket is None:
raise IOError("No socket provided for remote inputs")
input = inputs.RemoteInput(name, algorithm.dataformats[algorithm.input_map[name]],
socket, unpack=unpack)
input = RemoteInput(name, algorithm.dataformats[algorithm.input_map[name]],
socket, unpack=unpack)
logger.debug("RemoteInput '%s' created: group='%s', dataformat='%s'" % \
(name, details['channel'], algorithm.input_map[name]))
@@ -178,24 +244,24 @@ def create_inputs_from_configuration(config, algorithm, prefix, cache_root,
continue
# Synchronization bits
group = input_list.group(details['channel'])
if group is None:
synchronization_listener = None
if not no_synchronisation_listeners:
synchronization_listener = outputs.SynchronizationListener()
group = inputs.InputGroup(
details['channel'],
synchronization_listener=synchronization_listener,
restricted_access=(details['channel'] == config['channel'])
)
input_list.add(group)
logger.debug("Group '%s' created" % details['channel'])
if input is not None:
group = input_list.group(details['channel'])
if group is None:
synchronization_listener = None
if not no_synchronisation_listeners:
synchronization_listener = SynchronizationListener()
group.add(input)
group = InputGroup(
details['channel'],
synchronization_listener=synchronization_listener,
restricted_access=(details['channel'] == config['channel'])
)
input_list.add(group)
logger.debug("Group '%s' created" % details['channel'])
return (input_list, data_sources)
group.add(input)
return (input_list, data_loader_list, data_sources)
#----------------------------------------------------------
@@ -205,7 +271,7 @@ def create_outputs_from_configuration(config, algorithm, prefix, cache_root, inp
cache_access=AccessMode.NONE, socket=None):
data_sinks = []
output_list = outputs.OutputList()
output_list = OutputList()
# This is used for parallelization purposes
start_index, end_index = config.get('range', (None, None))
@@ -254,7 +320,7 @@ def create_outputs_from_configuration(config, algorithm, prefix, cache_root, inp
break
(data_filenames, indices_filenames, data_checksum_filenames, indices_checksum_filenames) = \
data.getAllFilenames(input_path)
getAllFilenames(input_path)
end_indices = [ int(x.split('.')[-2]) for x in indices_filenames ]
end_indices.sort()
@@ -262,7 +328,7 @@ def create_outputs_from_configuration(config, algorithm, prefix, cache_root, inp
start_index = 0
end_index = end_indices[-1]
data_sink = data.CachedDataSink()
data_sink = CachedDataSink()
data_sinks.append(data_sink)
status = data_sink.setup(
@@ -276,9 +342,9 @@ def create_outputs_from_configuration(config, algorithm, prefix, cache_root, inp
if not status:
raise IOError("Cannot create cache sink '%s'" % details['path'])
output_list.add(outputs.Output(name, data_sink,
synchronization_listener=synchronization_listener,
force_start_index=start_index)
output_list.add(Output(name, data_sink,
synchronization_listener=synchronization_listener,
force_start_index=start_index)
)
if 'result' not in config:
@@ -292,9 +358,9 @@ def create_outputs_from_configuration(config, algorithm, prefix, cache_root, inp
if socket is None:
raise IOError("No socket provided for remote outputs")
output_list.add(outputs.RemoteOutput(name, dataformat, socket,
synchronization_listener=synchronization_listener,
force_start_index=start_index or 0)
output_list.add(RemoteOutput(name, dataformat, socket,
synchronization_listener=synchronization_listener,
force_start_index=start_index or 0)
)
logger.debug("RemoteOutput '%s' created: group='%s', dataformat='%s'" % \
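The new _create_data_loader() helper above builds one DataLoader per data channel and registers a CachedFileLoader for every input on that channel. A minimal standalone sketch of the same wiring; the absolute import paths, the channel name and the file locations are assumptions for illustration:

# Minimal sketch of what _create_data_loader() does for a single input.
from beat.backend.python.data import CachedFileLoader
from beat.backend.python.data_loaders import DataLoader, DataLoaderList

data_loader_list = DataLoaderList()

data_loader = data_loader_list['train']        # one DataLoader per channel
if data_loader is None:
    data_loader = DataLoader('train')
    data_loader_list.add(data_loader)

cached_file = CachedFileLoader()
if not cached_file.setup(filename='/cache/xx/yy.data', prefix='/prefix',
                         unpack=True):
    raise IOError("cannot load cache file")

data_loader.add('in', cached_file)             # exposed to the algorithm by name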
#!/usr/bin/env python
# vim: set fileencoding=utf-8 :
###############################################################################
# #
# Copyright (c) 2017 Idiap Research Institute, http://www.idiap.ch/ #
# Contact: beat.support@idiap.ch #
# #
# This file is part of the beat.backend.python module of the BEAT platform. #
# #
# Commercial License Usage #
# Licensees holding valid commercial BEAT licenses may use this file in #
# accordance with the terms contained in a written agreement between you #
# and Idiap. For further information contact tto@idiap.ch #
# #
# Alternatively, this file may be used under the terms of the GNU Affero #
# Public License version 3 as published by the Free Software and appearing #
# in the file LICENSE.AGPL included in the packaging of this file. #
# The BEAT platform is distributed in the hope that it will be useful, but #
# WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY #
# or FITNESS FOR A PARTICULAR PURPOSE. #
# #
# You should have received a copy of the GNU Affero Public License along #
# with the BEAT platform. If not, see http://www.gnu.org/licenses/. #
# #
###############################################################################
import unittest
import tempfile
import simplejson
import os
import zmq
import shutil
import numpy as np
from copy import deepcopy
from ..executor import Executor
from ..message_handler import MessageHandler
from ..inputs import InputList
from ..algorithm import Algorithm
from ..dataformat import DataFormat
from ..data import CachedDataSink
from ..data import CachedFileLoader
from ..helpers import convert_experiment_configuration_to_container
from ..helpers import create_inputs_from_configuration
from ..helpers import create_outputs_from_configuration
from ..helpers import AccessMode
from . import prefix
CONFIGURATION = {
'algorithm': '',
'channel': 'main',
'parameters': {
},
'inputs': {
'in': {
'path': 'INPUT',
'channel': 'main',
}
},
'outputs': {
'out': {
'path': 'OUTPUT',
'channel': 'main'
}
},
}
#----------------------------------------------------------
class TestExecutor(unittest.TestCase):
def setUp(self):
self.cache_root = tempfile.mkdtemp(prefix=__name__)
self.working_dir = tempfile.mkdtemp(prefix=__name__)
self.message_handler = None
self.executor_socket = None
self.zmq_context = None
def tearDown(self):
shutil.rmtree(self.cache_root)
shutil.rmtree(self.working_dir)
if self.message_handler is not None:
self.message_handler.kill()
self.message_handler.join()
self.message_handler.destroy()
self.message_handler = None
if self.executor_socket is not None:
self.executor_socket.setsockopt(zmq.LINGER, 0)
self.executor_socket.close()
self.zmq_context.destroy()
self.executor_socket = None
self.zmq_context = None
def writeData(self, input_name, indices, start_value):
filename = os.path.join(self.cache_root, CONFIGURATION['inputs'][input_name]['path'] + '.data')
dataformat = DataFormat(prefix, 'user/single_integer/1')
self.assertTrue(dataformat.valid)
data_sink = CachedDataSink()
self.assertTrue(data_sink.setup(filename, dataformat, indices[0][0], indices[-1][1]))
for i in indices:
data = dataformat.type()
data.value = np.int32(start_value + i[0])
data_sink.write(data, i[0], i[1])
(nb_bytes, duration) = data_sink.statistics()
self.assertTrue(nb_bytes > 0)
self.assertTrue(duration > 0)
data_sink.close()
del data_sink
def process(self, algorithm_name, proxy_mode=False):
self.writeData('in', [(0, 0), (1, 1), (2, 2), (3, 3)], 1000)
config = deepcopy(CONFIGURATION)
config['algorithm'] = algorithm_name
config = convert_experiment_configuration_to_container(config, proxy_mode)
with open(os.path.join(self.working_dir, 'configuration.json'), 'wb') as f:
simplejson.dump(config, f, indent=4)
working_prefix = os.path.join(self.working_dir, 'prefix')
if not os.path.exists(working_prefix):
os.makedirs(working_prefix)
algorithm = Algorithm(prefix, algorithm_name)
algorithm.export(working_prefix)
if proxy_mode:
cache_access = AccessMode.LOCAL
(input_list, _, data_sources) = create_inputs_from_configuration(
config, algorithm, prefix, self.cache_root,
cache_access=cache_access,
no_synchronisation_listeners=True
)
(output_list, data_sinks) = create_outputs_from_configuration(
config, algorithm, prefix, self.cache_root,
input_list, cache_access=cache_access
)
self.message_handler = MessageHandler('127.0.0.1', inputs=input_list, outputs=output_list)
else:
self.message_handler = MessageHandler('127.0.0.1')
self.message_handler.start()
self.zmq_context = zmq.Context()
self.executor_socket = self.zmq_context.socket(zmq.PAIR)
self.executor_socket.connect(self.message_handler.address)
executor = Executor(self.executor_socket, self.working_dir, cache_root=self.cache_root)
self.assertTrue(executor.setup())
self.assertTrue(executor.prepare())
self.assertTrue(executor.process())
if proxy_mode:
for output in output_list:
output.close()
cached_file = CachedFileLoader()
self.assertTrue(cached_file.setup(os.path.join(self.cache_root, CONFIGURATION['outputs']['out']['path'] + '.data'), prefix))
for i in range(len(cached_file)):
data, start, end = cached_file[i]
self.assertEqual(data.value, 1000 + i)
self.assertEqual(start, i)
self.assertEqual(end, i)
def test_legacy_echo_1_local(self):
self.process('legacy/echo/1')
def test_legacy_echo_1_remote(self):
self.process('legacy/echo/1', True)
def test_sequential_echo_1(self):
self.process('sequential/echo/1')
def test_autonomous_echo_1(self):
self.process('autonomous/echo/1')