Commit 8fcb4366 authored by Philip ABBET

Add support for the 'no proxy' mode (for inputs) in containers

parent a9528aff
......@@ -40,6 +40,7 @@ import simplejson
from . import algorithm
from . import inputs
from . import outputs
from . import stats
from .helpers import create_inputs_from_configuration
from .helpers import CacheAccess
......@@ -80,10 +81,11 @@ class Executor(object):
database_cache=None, library_cache=None):
self.socket = socket
self.comm_time = 0. # total communication time
self.configuration = os.path.join(directory, 'configuration.json')
with open(self.configuration, 'rb') as f: self.data = simplejson.load(f)
with open(self.configuration, 'rb') as f:
self.data = simplejson.load(f)
self.prefix = os.path.join(directory, 'prefix')
self.runner = None
......@@ -147,45 +149,37 @@ class Executor(object):
using_output = self.output_list[0] if self.analysis else self.output_list
_start = time.time()
while self.input_list.hasMoreData():
main_group = self.input_list.main_group
main_group.restricted_access = False
main_group.next()
main_group.restricted_access = True
if not self.runner.process(self.input_list, using_output): return False
if not self.runner.process(self.input_list, using_output):
return False
missing_data_outputs = [x for x in self.output_list if x.isDataMissing()]
proc_time = time.time() - _start
if missing_data_outputs:
raise RuntimeError("Missing data on the following output(s): %s" % \
', '.join([x.name for x in missing_data_outputs]))
self.comm_time = sum([x.comm_time for x in self.input_list]) + \
sum([x.comm_time for x in self.output_list])
self.comm_time += sum([self.input_list[k].comm_time for k in range(self.input_list.nbGroups())])
# Send the done command
statistics = stats.io_statistics(self.data, self.input_list, self.output_list)
# some local information
logger.debug("Total processing time was %.3f seconds" , proc_time)
logger.debug("Time spent in I/O was %.3f seconds" , self.comm_time)
logger.debug("I/O/Processing ratio is %d%%",
100*self.comm_time/proc_time)
logger.debug("Statistics: " + simplejson.dumps(statistics, indent=4))
# Handle the done command
self.done()
self.done(statistics)
return True
def done(self):
def done(self, statistics):
"""Indicates the infrastructure the execution is done"""
logger.debug('send: (don) done')
self.socket.send('don', zmq.SNDMORE)
self.socket.send('%.6e' % self.comm_time)
self.socket.send(simplejson.dumps(statistics))
answer = self.socket.recv() #ack
logger.debug('recv: %s', answer)
......
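For reference, the `don` message now carries the full JSON-serialized statistics dictionary instead of a single wait-time float. A minimal sketch of the exchange on the executor side, assuming a connected zmq socket (names mirror the diff; `send_done` and its payload are illustrative, not part of the codebase):

import zmq
import simplejson

def send_done(socket, statistics):
    # Send the 'don' command frame, then the serialized statistics,
    # then wait for the infrastructure's acknowledgement
    socket.send('don', zmq.SNDMORE)
    socket.send(simplejson.dumps(statistics))
    return socket.recv()  # expected answer: 'ack'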
......@@ -26,6 +26,11 @@
###############################################################################
import os
import logging
logger = logging.getLogger(__name__)
from . import data
from . import inputs
from . import outputs
......@@ -44,7 +49,7 @@ def convert_experiment_configuration_to_container(config, proxy_mode):
data['range'] = config['range']
data['inputs'] = \
dict([(k, { 'channel': v['channel'], 'path': v['path'] }) for k,v in config['inputs'].items()])
dict([(k, { 'channel': v['channel'], 'path': v['path'], 'database': 'database' in v }) for k,v in config['inputs'].items()])
if 'outputs' in config:
data['outputs'] = \
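After this change, each input serialized into the container configuration carries a boolean `database` flag telling the container whether the input is backed by a database. A hypothetical resulting fragment (input name and cache path are made up):

data['inputs'] = {
    'image': {
        'channel': 'main',
        'path': '0a/1b/2c/3d',  # made-up cache path
        'database': False,      # True when the experiment input has a 'database' key
    },
}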
......@@ -76,21 +81,25 @@ def create_inputs_from_configuration(config, algorithm, prefix, cache_root,
for name, details in config['inputs'].items():
if 'database' in details:
# create the remote input
if details.get('database', False):
if socket is None:
raise IOError("No socket provided for remote inputs")
input = inputs.RemoteInput(name, algorithm.dataformats[algorithm.input_map[name]],
socket, unpack=unpack)
elif cache_access != CacheAccess.NONE:
logger.debug("RemoteInput '%s' created: group='%s', dataformat='%s', connected to a database" % \
(name, details['channel'], algorithm.input_map[name]))
elif cache_access == CacheAccess.LOCAL:
data_source = data.CachedDataSource()
data_sources.append(data_source)
filename = os.path.join(cache_root, details['path'] + '.data')
if details['channel'] == config['channel']: # synchronized
status = data_source.setup(
filename=os.path.join(cache_root, details['path'] + '.data'),
filename=filename,
prefix=prefix,
force_start_index=start_index,
force_end_index=end_index,
......@@ -98,7 +107,7 @@ def create_inputs_from_configuration(config, algorithm, prefix, cache_root,
)
else:
status = data_source.setup(
filename=os.path.join(cache_root, details['path'] + '.data'),
filename=filename,
prefix=prefix,
unpack=unpack,
)
......@@ -106,11 +115,20 @@ def create_inputs_from_configuration(config, algorithm, prefix, cache_root,
if not status:
raise IOError("cannot load cache file `%s'" % details['path'])
if cache_access == CacheAccess.LOCAL:
input = inputs.Input(name, algorithm.input_map[name], data_source)
else:
input = inputs.RemoteInput(name, algorithm.dataformats[algorithm.input_map[name]],
socket, unpack=unpack)
input = inputs.Input(name, algorithm.input_map[name], data_source)
logger.debug("Input '%s' created: group='%s', dataformat='%s', filename='%s'" % \
(name, details['channel'], algorithm.input_map[name], filename))
elif cache_access == CacheAccess.REMOTE:
if socket is None:
raise IOError("No socket provided for remote inputs")
input = inputs.RemoteInput(name, algorithm.dataformats[algorithm.input_map[name]],
socket, unpack=unpack)
logger.debug("RemoteInput '%s' created: group='%s', dataformat='%s'" % \
(name, details['channel'], algorithm.input_map[name]))
else:
continue
......@@ -124,6 +142,7 @@ def create_inputs_from_configuration(config, algorithm, prefix, cache_root,
restricted_access=(details['channel'] == config['channel'])
)
input_list.add(group)
logger.debug("Group '%s' created" % details['channel'])
group.add(input)
......
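Taken together, the rewritten branches dispatch on the new `database` flag and on `cache_access`. A condensed, self-contained recap of the decision logic (the `CacheAccess` stub and `input_kind` helper below are illustrations, standing in for the real enum and inline code in `helpers`):

class CacheAccess:  # stand-in for beat.backend.python.helpers.CacheAccess
    NONE, LOCAL, REMOTE = range(3)

def input_kind(details, cache_access):
    # Database-backed inputs are always served through the socket
    if details.get('database', False):
        return 'remote'
    # 'no proxy' mode: read the cache file directly via CachedDataSource
    elif cache_access == CacheAccess.LOCAL:
        return 'local'
    # proxy mode: cached data is also fetched through the socket
    elif cache_access == CacheAccess.REMOTE:
        return 'remote'
    # CacheAccess.NONE: the input is skipped entirely
    return None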
......@@ -105,6 +105,14 @@ class Input:
"""Retrieves the next block of data"""
(self.data, self.data_index, self.data_index_end) = self.data_source.next()
if self.data is None:
message = "User algorithm asked for more data for channel " \
"`%s' on input `%s', but it is over (no more data). This " \
"normally indicates a programming error on the user " \
"side." % (self.group.channel, self.name)
raise RuntimeError(message)
self.data_same_as_previous = False
self.nb_data_blocks_read += 1
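Since `next()` now raises instead of silently leaving `data` as None, consumers should keep testing `hasMoreData()` before advancing. A hedged usage sketch, assuming `input` is one of the objects built by create_inputs_from_configuration and `process_block` is a made-up user callback:

while input.hasMoreData():
    input.next()               # safe: a block is known to be available
    process_block(input.data)
# one extra call to input.next() would now raise RuntimeError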
......@@ -590,8 +598,13 @@ class InputList:
return bool([x for x in self._groups if x.hasMoreData()])
def group(self, name):
try:
return [x for x in self._groups if x.channel == name][0]
except:
def group(self, name_or_index):
if isinstance(name_or_index, six.string_types):
try:
return [x for x in self._groups if x.channel == name_or_index][0]
except IndexError:
return None
elif isinstance(name_or_index, int):
return self._groups[name_or_index]
else:
return None
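With this extension, `group()` resolves both channel names and positional indices, which the statistics code relies on (see `io_statistics` below). Usage, assuming an `input_list` whose first group sits on channel 'main':

g_by_name = input_list.group('main')   # lookup by channel name, as before
g_by_index = input_list.group(0)       # new: lookup by position
assert g_by_name is g_by_index
assert input_list.group(3.14) is None  # any other type yields None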
......@@ -31,6 +31,7 @@ logger = logging.getLogger(__name__)
import gevent
import zmq.green as zmq
import simplejson
import requests
from gevent import monkey
......@@ -272,17 +273,16 @@ class MessageHandler(gevent.Greenlet):
self.stop.set()
def done(self, wait_time=None):
def done(self, statistics=None):
"""Syntax: don"""
logger.debug('recv: don %s', wait_time)
logger.debug('recv: don %s', statistics)
if wait_time is not None:
if statistics is not None:
self._collect_statistics()
# collect I/O stats from client
wait_time = float(wait_time)
self.last_statistics['data'] = dict(network=dict(wait_time=wait_time))
self.last_statistics['data'] = simplejson.loads(statistics)
self._acknowledge()
......
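The two `done()` methods are the ends of the same wire format: whatever `Executor.done()` serializes, `MessageHandler.done()` deserializes into `last_statistics['data']`. A round-trip illustration with made-up numbers:

import simplejson

payload = simplejson.dumps(dict(network=dict(wait_time=0.125)))
assert simplejson.loads(payload)['network']['wait_time'] == 0.125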
#!/usr/bin/env python
# vim: set fileencoding=utf-8 :
###############################################################################
# #
# Copyright (c) 2017 Idiap Research Institute, http://www.idiap.ch/ #
# Contact: beat.support@idiap.ch #
# #
# This file is part of the beat.backend.python module of the BEAT platform. #
# #
# Commercial License Usage #
# Licensees holding valid commercial BEAT licenses may use this file in #
# accordance with the terms contained in a written agreement between you #
# and Idiap. For further information contact tto@idiap.ch #
# #
# Alternatively, this file may be used under the terms of the GNU Affero #
# Public License version 3 as published by the Free Software and appearing #
# in the file LICENSE.AGPL included in the packaging of this file. #
# The BEAT platform is distributed in the hope that it will be useful, but #
# WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY #
# or FITNESS FOR A PARTICULAR PURPOSE. #
# #
# You should have received a copy of the GNU Affero Public License along #
# with the BEAT platform. If not, see http://www.gnu.org/licenses/. #
# #
###############################################################################
from .inputs import RemoteInput
from .outputs import RemoteOutput
def io_statistics(configuration, input_list=None, output_list=None):
"""Summarize current I/O statistics looking at data sources and sinks, inputs and outputs
Returns:
dict: A dictionary summarizing current I/O statistics
"""
network_time = 0.0
# Data reading
bytes_read = 0
blocks_read = 0
read_time = 0.0
if input_list is not None:
for input in input_list:
if isinstance(input, RemoteInput):
network_time += input.comm_time
else:
size, duration = input.data_source.statistics()
bytes_read += size
read_time += duration
blocks_read += input.nb_data_blocks_read
network_time += sum([input_list.group(k).comm_time for k in range(input_list.nbGroups())])
# Data writing
bytes_written = 0
blocks_written = 0
write_time = 0.0
files = []
if output_list is not None:
for output in output_list:
if isinstance(output, RemoteOutput):
network_time += output.comm_time
else:
size, duration = output.data_sink.statistics()
bytes_written += size
write_time += duration
blocks_written += output.nb_data_blocks_written
if 'result' in configuration:
hash = configuration['result']['hash']
else:
hash = configuration['outputs'][output.name]['hash']
files.append(dict(
hash=hash,
size=size,
blocks=output.nb_data_blocks_written,
))
# Result
return dict(
volume = dict(read=bytes_read, write=bytes_written),
blocks = dict(read=blocks_read, write=blocks_written),
time = dict(read=read_time, write=write_time),
network = dict(wait_time=network_time),
files = files,
)
#----------------------------------------------------------
def update(statistics, additional_statistics):
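"""Updates `statistics` in-place: sums every numeric counter of
`additional_statistics` into it and appends its list of written files"""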
for k in statistics.keys():
if k == 'files':
continue
for k2 in statistics[k].keys():
statistics[k][k2] += additional_statistics[k][k2]
if 'files' in statistics:
statistics['files'].extend(additional_statistics.get('files', []))
else:
statistics['files'] = additional_statistics.get('files', [])
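A usage sketch of the two helpers with made-up numbers, showing how per-block statistics in the shape produced by io_statistics() accumulate across successive calls to update():

total = dict(
    volume=dict(read=1024, write=512),
    blocks=dict(read=4, write=2),
    time=dict(read=0.5, write=0.25),
    network=dict(wait_time=0.5),
    files=[dict(hash='deadbeef', size=512, blocks=2)],  # fake hash
)
extra = dict(
    volume=dict(read=2048, write=0),
    blocks=dict(read=8, write=0),
    time=dict(read=1.0, write=0.0),
    network=dict(wait_time=0.25),
    files=[],
)
update(total, extra)
assert total['volume']['read'] == 3072
assert total['network']['wait_time'] == 0.75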