Commit a9164acc authored by Philip ABBET

Refactoring: the 'Executor' class now supports sequential and autonomous algorithms

parent 078cf0fd
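For context, the three algorithm types this commit introduces into the Executor differ in how their process() method is driven. Below is a minimal sketch of the user-level interfaces implied by the Executor changes in this diff; the class names are hypothetical, and for analyzer blocks the keyword becomes `output` (singular) instead of `outputs`:

    class MyLegacyAlgorithm(object):            # Algorithm.LEGACY
        def process(self, inputs, outputs):
            # called once per synchronized unit of data
            return True

    class MySequentialAlgorithm(object):        # Algorithm.SEQUENTIAL
        def prepare(self, data_loaders):
            # receives the data loaders before processing starts
            return True

        def process(self, inputs, data_loaders, outputs):
            # still called once per synchronized unit, but can also read
            # the other channels through the data loaders
            return True

    class MyAutonomousAlgorithm(object):        # Algorithm.AUTONOMOUS
        def prepare(self, data_loaders):
            return True

        def process(self, data_loaders, outputs):
            # called exactly once; the algorithm iterates over the data itself
            return True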
executor.py

@@ -37,13 +37,11 @@ import time
 import zmq
 import simplejson

-from . import algorithm
-from . import inputs
-from . import outputs
-from . import stats
+from .algorithm import Algorithm
 from .helpers import create_inputs_from_configuration
 from .helpers import create_outputs_from_configuration
 from .helpers import AccessMode
+from . import stats


 class Executor(object):
@@ -90,34 +88,47 @@ class Executor(object):
         self.prefix = os.path.join(directory, 'prefix')
         self.runner = None

-        # temporary caches, if the user has not set them, for performance
+        # Temporary caches, if the user has not set them, for performance
         database_cache = database_cache if database_cache is not None else {}
         dataformat_cache = dataformat_cache if dataformat_cache is not None else {}
         library_cache = library_cache if library_cache is not None else {}

-        self.algorithm = algorithm.Algorithm(self.prefix, self.data['algorithm'],
-                                             dataformat_cache, library_cache)
+        # Load the algorithm
+        self.algorithm = Algorithm(self.prefix, self.data['algorithm'],
+                                   dataformat_cache, library_cache)

-        # Use algorithm names for inputs and outputs
         main_channel = self.data['channel']

-        # Loads algorithm inputs
-        if self.data['proxy_mode']:
-            cache_access = AccessMode.REMOTE
-        else:
-            cache_access = AccessMode.LOCAL
-
-        (self.input_list, _) = create_inputs_from_configuration(
-            self.data, self.algorithm, self.prefix, cache_root,
-            cache_access=cache_access, db_access=AccessMode.REMOTE,
-            socket=self.socket
-        )
-
-        # Loads algorithm outputs
-        (self.output_list, _) = create_outputs_from_configuration(
-            self.data, self.algorithm, self.prefix, cache_root, self.input_list,
-            cache_access=cache_access, socket=self.socket
-        )
+        if self.algorithm.type == Algorithm.LEGACY:
+            # Loads algorithm inputs
+            if self.data['proxy_mode']:
+                cache_access = AccessMode.REMOTE
+            else:
+                cache_access = AccessMode.LOCAL
+
+            (self.input_list, self.data_loaders, _) = create_inputs_from_configuration(
+                self.data, self.algorithm, self.prefix, cache_root,
+                cache_access=cache_access, db_access=AccessMode.REMOTE,
+                socket=self.socket
+            )
+
+            # Loads algorithm outputs
+            (self.output_list, _) = create_outputs_from_configuration(
+                self.data, self.algorithm, self.prefix, cache_root, self.input_list,
+                cache_access=cache_access, socket=self.socket
+            )
+        else:
+            (self.input_list, self.data_loaders, _) = create_inputs_from_configuration(
+                self.data, self.algorithm, self.prefix, cache_root,
+                cache_access=AccessMode.LOCAL, db_access=AccessMode.REMOTE
+            )
+
+            # Loads algorithm outputs
+            (self.output_list, _) = create_outputs_from_configuration(
+                self.data, self.algorithm, self.prefix, cache_root, self.input_list,
+                cache_access=AccessMode.LOCAL
+            )

     def setup(self):
@@ -129,27 +140,56 @@ class Executor(object):
         return retval

+    def prepare(self):
+        """Prepare the algorithm"""
+
+        self.runner = self.algorithm.runner()
+
+        retval = self.runner.prepare(self.data_loaders)
+        logger.debug("User algorithm is prepared")
+        return retval
+
     def process(self):
         """Executes the user algorithm code using the current interpreter.
         """

-        if not self.input_list or not self.output_list:
-            raise RuntimeError("I/O for execution block has not yet been set up")
-
-        while self.input_list.hasMoreData():
-            main_group = self.input_list.main_group
-            main_group.restricted_access = False
-            main_group.next()
-            main_group.restricted_access = True
-
+        if self.algorithm.type == Algorithm.AUTONOMOUS:
             if self.analysis:
-                result = self.runner.process(inputs=self.input_list, output=self.output_list[0])
+                result = self.runner.process(data_loaders=self.data_loaders,
+                                             output=self.output_list[0])
             else:
-                result = self.runner.process(inputs=self.input_list, outputs=self.output_list)
+                result = self.runner.process(data_loaders=self.data_loaders,
+                                             outputs=self.output_list)

             if not result:
                 return False
+        else:
+            while self.input_list.hasMoreData():
+                main_group = self.input_list.main_group
+                main_group.restricted_access = False
+                main_group.next()
+                main_group.restricted_access = True
+
+                if self.algorithm.type == Algorithm.LEGACY:
+                    if self.analysis:
+                        result = self.runner.process(inputs=self.input_list, output=self.output_list[0])
+                    else:
+                        result = self.runner.process(inputs=self.input_list, outputs=self.output_list)
+
+                elif self.algorithm.type == Algorithm.SEQUENTIAL:
+                    if self.analysis:
+                        result = self.runner.process(inputs=self.input_list,
+                                                     data_loaders=self.data_loaders,
+                                                     output=self.output_list[0])
+                    else:
+                        result = self.runner.process(inputs=self.input_list,
+                                                     data_loaders=self.data_loaders,
+                                                     outputs=self.output_list)
+
+                if not result:
+                    return False

         for output in self.output_list:
             output.close()
...
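The net effect of the process() refactoring above, summarized directly from the branches (`output` vs `outputs` depends on whether the block is an analyzer):

    # LEGACY:      runner.process(inputs=input_list, outputs=output_list)
    # SEQUENTIAL:  runner.process(inputs=input_list, data_loaders=data_loaders,
    #                             outputs=output_list)
    # AUTONOMOUS:  runner.process(data_loaders=data_loaders, outputs=output_list)
    #
    # LEGACY and SEQUENTIAL calls happen inside the input_list.hasMoreData()
    # loop; the AUTONOMOUS call happens exactly once.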
helpers.py

@@ -32,11 +32,26 @@ import errno
 import logging
 logger = logging.getLogger(__name__)

-from . import data
-from . import inputs
-from . import outputs
+from .data import MemoryDataSource
+from .data import CachedDataSource
+from .data import CachedFileLoader
+from .data import CachedDataSink
+from .data import getAllFilenames
+from .data_loaders import DataLoaderList
+from .data_loaders import DataLoader
+from .inputs import InputList
+from .inputs import Input
+from .inputs import RemoteInput
+from .inputs import InputGroup
+from .outputs import SynchronizationListener
+from .outputs import OutputList
+from .outputs import Output
+from .outputs import RemoteOutput
+from .algorithm import Algorithm
+
+
+#----------------------------------------------------------

 def convert_experiment_configuration_to_container(config, proxy_mode):
     data = {
@@ -80,13 +95,77 @@ def create_inputs_from_configuration(config, algorithm, prefix, cache_root,
     data_sources = []
     views = {}

-    input_list = inputs.InputList()
+    input_list = InputList()
+    data_loader_list = DataLoaderList()

     # This is used for parallelization purposes
     start_index, end_index = config.get('range', (None, None))

+    def _create_local_input(details):
+        data_source = CachedDataSource()
+        data_sources.append(data_source)
+
+        filename = os.path.join(cache_root, details['path'] + '.data')
+
+        if details['channel'] == config['channel']: # synchronized
+            status = data_source.setup(
+                filename=filename,
+                prefix=prefix,
+                force_start_index=start_index,
+                force_end_index=end_index,
+                unpack=True,
+            )
+        else:
+            status = data_source.setup(
+                filename=filename,
+                prefix=prefix,
+                unpack=True,
+            )
+
+        if not status:
+            raise IOError("cannot load cache file `%s'" % details['path'])
+
+        input = Input(name, algorithm.input_map[name], data_source)
+
+        logger.debug("Input '%s' created: group='%s', dataformat='%s', filename='%s'" % \
+                     (name, details['channel'], algorithm.input_map[name], filename))
+
+        return input
+
+    def _create_data_loader(details):
+        filename = os.path.join(cache_root, details['path'] + '.data')
+
+        data_loader = data_loader_list[details['channel']]
+        if data_loader is None:
+            data_loader = DataLoader(details['channel'])
+            data_loader_list.add(data_loader)
+            logger.debug("Data loader created: group='%s'" % details['channel'])
+
+        cached_file = CachedFileLoader()
+        result = cached_file.setup(
+            filename=filename,
+            prefix=prefix,
+            start_index=start_index,
+            end_index=end_index,
+            unpack=True,
+        )
+
+        if not result:
+            raise IOError("cannot load cache file `%s'" % details['path'])
+
+        data_loader.add(name, cached_file)
+
+        logger.debug("Input '%s' added to data loader: group='%s', dataformat='%s', filename='%s'" % \
+                     (name, details['channel'], algorithm.input_map[name], filename))
+
     for name, details in config['inputs'].items():
+        input = None
+
         if details.get('database', False):
             if db_access == AccessMode.LOCAL:
                 if databases is None:
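The new `_create_data_loader` helper groups cached files by channel: one DataLoader per channel, kept in a shared DataLoaderList, each holding one cached file per input name. A condensed, self-contained sketch of the calls it makes for a single cache-backed input (the channel name, input name, and paths below are made up; the setup() arguments mirror the helper above):

    import os
    from beat.backend.python.data import CachedFileLoader
    from beat.backend.python.data_loaders import DataLoader, DataLoaderList

    cache_root = '/tmp/cache'    # hypothetical
    prefix = '/tmp/prefix'       # hypothetical

    data_loader_list = DataLoaderList()

    data_loader = data_loader_list['train']    # lookup by channel name
    if data_loader is None:                    # first input on this channel
        data_loader = DataLoader('train')
        data_loader_list.add(data_loader)

    cached_file = CachedFileLoader()
    if not cached_file.setup(filename=os.path.join(cache_root, 'INPUT.data'),
                             prefix=prefix, start_index=None, end_index=None,
                             unpack=True):
        raise IOError("cannot load cache file")

    data_loader.add('in', cached_file)         # exposed under the input name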
@@ -114,12 +193,12 @@ def create_inputs_from_configuration(config, algorithm, prefix, cache_root,
                 view = views[channel]

             # Creation of the input
-            data_source = data.MemoryDataSource(view.done, next_callback=view.next)
+            data_source = MemoryDataSource(view.done, next_callback=view.next)
             output = view.outputs[details['output']]
             output.data_sink.data_sources.append(data_source)

-            input = inputs.Input(name, algorithm.input_map[name], data_source)
+            input = Input(name, algorithm.input_map[name], data_source)

             logger.debug("Input '%s' created: group='%s', dataformat='%s', database-output='%s/%s/%s:%s'" % \
                          (name, channel, algorithm.input_map[name], details['database'],
@@ -129,47 +208,34 @@ def create_inputs_from_configuration(config, algorithm, prefix, cache_root,
             if socket is None:
                 raise IOError("No socket provided for remote inputs")

-            input = inputs.RemoteInput(name, algorithm.dataformats[algorithm.input_map[name]],
-                                       socket, unpack=unpack)
+            input = RemoteInput(name, algorithm.dataformats[algorithm.input_map[name]],
+                                socket, unpack=unpack)

             logger.debug("RemoteInput '%s' created: group='%s', dataformat='%s', connected to a database" % \
                          (name, details['channel'], algorithm.input_map[name]))

         elif cache_access == AccessMode.LOCAL:
-            data_source = data.CachedDataSource()
-            data_sources.append(data_source)
-
-            filename = os.path.join(cache_root, details['path'] + '.data')
-
-            if details['channel'] == config['channel']: # synchronized
-                status = data_source.setup(
-                    filename=filename,
-                    prefix=prefix,
-                    force_start_index=start_index,
-                    force_end_index=end_index,
-                    unpack=True,
-                )
-            else:
-                status = data_source.setup(
-                    filename=filename,
-                    prefix=prefix,
-                    unpack=True,
-                )
-
-            if not status:
-                raise IOError("cannot load cache file `%s'" % details['path'])
-
-            input = inputs.Input(name, algorithm.input_map[name], data_source)
-
-            logger.debug("Input '%s' created: group='%s', dataformat='%s', filename='%s'" % \
-                         (name, details['channel'], algorithm.input_map[name], filename))
+            if algorithm.type == Algorithm.LEGACY:
+                input = _create_local_input(details)
+
+            elif algorithm.type == Algorithm.SEQUENTIAL:
+                if details['channel'] == config['channel']: # synchronized
+                    input = _create_local_input(details)
+                else:
+                    _create_data_loader(details)
+
+            elif algorithm.type == Algorithm.AUTONOMOUS:
+                _create_data_loader(details)

         elif cache_access == AccessMode.REMOTE:
             if socket is None:
                 raise IOError("No socket provided for remote inputs")

-            input = inputs.RemoteInput(name, algorithm.dataformats[algorithm.input_map[name]],
-                                       socket, unpack=unpack)
+            input = RemoteInput(name, algorithm.dataformats[algorithm.input_map[name]],
+                                socket, unpack=unpack)

             logger.debug("RemoteInput '%s' created: group='%s', dataformat='%s'" % \
                          (name, details['channel'], algorithm.input_map[name]))
@@ -178,24 +244,24 @@ def create_inputs_from_configuration(config, algorithm, prefix, cache_root,
             continue

         # Synchronization bits
-        group = input_list.group(details['channel'])
-        if group is None:
-            synchronization_listener = None
-            if not no_synchronisation_listeners:
-                synchronization_listener = outputs.SynchronizationListener()
-
-            group = inputs.InputGroup(
-                details['channel'],
-                synchronization_listener=synchronization_listener,
-                restricted_access=(details['channel'] == config['channel'])
-            )
-            input_list.add(group)
-            logger.debug("Group '%s' created" % details['channel'])
-
-        group.add(input)
+        if input is not None:
+            group = input_list.group(details['channel'])
+            if group is None:
+                synchronization_listener = None
+                if not no_synchronisation_listeners:
+                    synchronization_listener = SynchronizationListener()
+
+                group = InputGroup(
+                    details['channel'],
+                    synchronization_listener=synchronization_listener,
+                    restricted_access=(details['channel'] == config['channel'])
+                )
+                input_list.add(group)
+                logger.debug("Group '%s' created" % details['channel'])
+
+            group.add(input)

-    return (input_list, data_sources)
+    return (input_list, data_loader_list, data_sources)


 #----------------------------------------------------------
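Note the changed return value: create_inputs_from_configuration now returns three values instead of two. Callers unpack it as in the updated Executor.__init__ shown earlier:

    # as called from Executor.__init__ (config, algorithm, prefix and
    # cache_root come from the surrounding executor state):
    (input_list, data_loaders, data_sources) = create_inputs_from_configuration(
        config, algorithm, prefix, cache_root,
        cache_access=AccessMode.LOCAL, db_access=AccessMode.REMOTE
    )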
@@ -205,7 +271,7 @@ def create_outputs_from_configuration(config, algorithm, prefix, cache_root, inp
                                       cache_access=AccessMode.NONE, socket=None):

     data_sinks = []
-    output_list = outputs.OutputList()
+    output_list = OutputList()

     # This is used for parallelization purposes
     start_index, end_index = config.get('range', (None, None))
@@ -254,7 +320,7 @@ def create_outputs_from_configuration(config, algorithm, prefix, cache_root, inp
                     break

             (data_filenames, indices_filenames, data_checksum_filenames, indices_checksum_filenames) = \
-                data.getAllFilenames(input_path)
+                getAllFilenames(input_path)

             end_indices = [ int(x.split('.')[-2]) for x in indices_filenames ]
             end_indices.sort()
@@ -262,7 +328,7 @@ def create_outputs_from_configuration(config, algorithm, prefix, cache_root, inp
                 start_index = 0
                 end_index = end_indices[-1]

-            data_sink = data.CachedDataSink()
+            data_sink = CachedDataSink()
             data_sinks.append(data_sink)

             status = data_sink.setup(
@@ -276,9 +342,9 @@ def create_outputs_from_configuration(config, algorithm, prefix, cache_root, inp
             if not status:
                 raise IOError("Cannot create cache sink '%s'" % details['path'])

-            output_list.add(outputs.Output(name, data_sink,
+            output_list.add(Output(name, data_sink,
                 synchronization_listener=synchronization_listener,
                 force_start_index=start_index)
             )

             if 'result' not in config:
@@ -292,9 +358,9 @@ def create_outputs_from_configuration(config, algorithm, prefix, cache_root, inp
             if socket is None:
                 raise IOError("No socket provided for remote outputs")

-            output_list.add(outputs.RemoteOutput(name, dataformat, socket,
+            output_list.add(RemoteOutput(name, dataformat, socket,
                 synchronization_listener=synchronization_listener,
                 force_start_index=start_index or 0)
             )

             logger.debug("RemoteOutput '%s' created: group='%s', dataformat='%s'" % \
...

New file: unit tests for the Executor class
#!/usr/bin/env python
# vim: set fileencoding=utf-8 :
###############################################################################
# #
# Copyright (c) 2017 Idiap Research Institute, http://www.idiap.ch/ #
# Contact: beat.support@idiap.ch #
# #
# This file is part of the beat.backend.python module of the BEAT platform. #
# #
# Commercial License Usage #
# Licensees holding valid commercial BEAT licenses may use this file in #
# accordance with the terms contained in a written agreement between you #
# and Idiap. For further information contact tto@idiap.ch #
# #
# Alternatively, this file may be used under the terms of the GNU Affero #
# Public License version 3 as published by the Free Software and appearing #
# in the file LICENSE.AGPL included in the packaging of this file. #
# The BEAT platform is distributed in the hope that it will be useful, but #
# WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY #
# or FITNESS FOR A PARTICULAR PURPOSE. #
# #
# You should have received a copy of the GNU Affero Public License along #
# with the BEAT platform. If not, see http://www.gnu.org/licenses/. #
# #
###############################################################################
import unittest
import tempfile
import simplejson
import os
import zmq
import shutil
import numpy as np
from copy import deepcopy
from ..executor import Executor
from ..message_handler import MessageHandler
from ..inputs import InputList
from ..algorithm import Algorithm
from ..dataformat import DataFormat
from ..data import CachedDataSink
from ..data import CachedFileLoader
from ..helpers import convert_experiment_configuration_to_container
from ..helpers import create_inputs_from_configuration
from ..helpers import create_outputs_from_configuration
from ..helpers import AccessMode
from . import prefix
CONFIGURATION = {
'algorithm': '',
'channel': 'main',
'parameters': {
},
'inputs': {
'in': {
'path': 'INPUT',
'channel': 'main',
}
},
'outputs': {
'out': {
'path': 'OUTPUT',
'channel': 'main'
}
},
}
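Each test presumably fills in the empty 'algorithm' field and wraps this dictionary the way the platform does before handing it to an Executor. A hedged sketch using the converter imported above (the algorithm name below is hypothetical; the proxy_mode argument matches the signature shown in helpers.py):

    config = deepcopy(CONFIGURATION)
    config['algorithm'] = 'user/some_algorithm/1'    # hypothetical name
    config = convert_experiment_configuration_to_container(config, False)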
#----------------------------------------------------------
class TestExecutor(unittest.TestCase):
def setUp(self):
self.cache_root = tempfile.mkdtemp(prefix=__name__)
self.working_dir = tempfile.mkdtemp(prefix=__name__)
self.message_handler = None
self.executor_socket = None
self.zmq_context = None
def tearDown(self):
shutil.rmtree(self.cache_root)
shutil.rmtree(self.working_dir)