Skip to content
Snippets Groups Projects
Commit a9528aff authored by Philip ABBET's avatar Philip ABBET
Browse files

Add helper functions to create input list from a configuration file for various scenarios

parent 52470e1d
No related branches found
No related tags found
No related merge requests found
......@@ -40,6 +40,8 @@ import simplejson
from . import algorithm
from . import inputs
from . import outputs
from .helpers import create_inputs_from_configuration
from .helpers import CacheAccess
class Executor(object):
......@@ -97,17 +99,15 @@ class Executor(object):
main_channel = self.data['channel']
# Loads algorithm inputs
if 'inputs' in self.data:
self.input_list = inputs.InputList()
for name, channel in self.data['inputs'].items():
group = self.input_list.group(channel)
if group is None:
group = inputs.InputGroup(channel, restricted_access=(channel == main_channel))
self.input_list.add(group)
thisformat = self.algorithm.dataformats[self.algorithm.input_map[name]]
group.add(inputs.RemoteInput(name, thisformat, self.socket))
logger.debug("Loaded input list with %d group(s) and %d input(s)",
self.input_list.nbGroups(), len(self.input_list))
if self.data['proxy_mode']:
cache_access = CacheAccess.REMOTE
else:
cache_access = CacheAccess.LOCAL
(self.input_list, _) = create_inputs_from_configuration(
self.data, self.algorithm, self.prefix, '/cache',
cache_access=cache_access, socket=self.socket
)
# Loads outputs
if 'outputs' in self.data:
......
#!/usr/bin/env python
# vim: set fileencoding=utf-8 :
###############################################################################
# #
# Copyright (c) 2017 Idiap Research Institute, http://www.idiap.ch/ #
# Contact: beat.support@idiap.ch #
# #
# This file is part of the beat.backend.python module of the BEAT platform. #
# #
# Commercial License Usage #
# Licensees holding valid commercial BEAT licenses may use this file in #
# accordance with the terms contained in a written agreement between you #
# and Idiap. For further information contact tto@idiap.ch #
# #
# Alternatively, this file may be used under the terms of the GNU Affero #
# Public License version 3 as published by the Free Software and appearing #
# in the file LICENSE.AGPL included in the packaging of this file. #
# The BEAT platform is distributed in the hope that it will be useful, but #
# WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY #
# or FITNESS FOR A PARTICULAR PURPOSE. #
# #
# You should have received a copy of the GNU Affero Public License along #
# with the BEAT platform. If not, see http://www.gnu.org/licenses/. #
# #
###############################################################################
from . import data
from . import inputs
from . import outputs
def convert_experiment_configuration_to_container(config, proxy_mode):
data = {
'proxy_mode': proxy_mode,
'algorithm': config['algorithm'],
'parameters': config['parameters'],
'channel': config['channel'],
}
if 'range' in config:
data['range'] = config['range']
data['inputs'] = \
dict([(k, { 'channel': v['channel'], 'path': v['path'] }) for k,v in config['inputs'].items()])
if 'outputs' in config:
data['outputs'] = \
dict([(k, { 'channel': v['channel'], 'path': v['path'] }) for k,v in config['outputs'].items()])
else:
data['result'] = { 'channel': config['channel'], 'path': config['result']['path'] }
return data
#----------------------------------------------------------
class CacheAccess:
NONE = 0
LOCAL = 1
REMOTE = 2
def create_inputs_from_configuration(config, algorithm, prefix, cache_root,
cache_access=CacheAccess.NONE, unpack=True,
socket=None):
data_sources = []
input_list = inputs.InputList()
# This is used for parallelization purposes
start_index, end_index = config.get('range', (None, None))
for name, details in config['inputs'].items():
if 'database' in details:
# create the remote input
if socket is None:
raise IOError("No socket provided for remote inputs")
input = inputs.RemoteInput(name, algorithm.dataformats[algorithm.input_map[name]],
socket, unpack=unpack)
elif cache_access != CacheAccess.NONE:
data_source = data.CachedDataSource()
data_sources.append(data_source)
if details['channel'] == config['channel']: # synchronized
status = data_source.setup(
filename=os.path.join(cache_root, details['path'] + '.data'),
prefix=prefix,
force_start_index=start_index,
force_end_index=end_index,
unpack=unpack,
)
else:
status = data_source.setup(
filename=os.path.join(cache_root, details['path'] + '.data'),
prefix=prefix,
unpack=unpack,
)
if not status:
raise IOError("cannot load cache file `%s'" % details['path'])
if cache_access == CacheAccess.LOCAL:
input = inputs.Input(name, algorithm.input_map[name], data_source)
else:
input = inputs.RemoteInput(name, algorithm.dataformats[algorithm.input_map[name]],
socket, unpack=unpack)
else:
continue
# Synchronization bits
group = input_list.group(details['channel'])
if group is None:
group = inputs.InputGroup(
details['channel'],
synchronization_listener=outputs.SynchronizationListener(),
restricted_access=(details['channel'] == config['channel'])
)
input_list.add(group)
group.add(input)
return (input_list, data_sources)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment