Commit 2b9921e6 authored by Philip ABBET

Merge optimization-related changes

parents 70b2d3ac 86ed17a3
......@@ -123,6 +123,7 @@ class CachedDataSource(DataSource):
self._cache = six.b('')
self._nb_bytes_read = 0
self._read_duration = 0
self._unpack = True
def _readHeader(self):
"""Read the header of the current file"""
......@@ -155,7 +156,7 @@ class CachedDataSource(DataSource):
return True
def setup(self, filename, prefix, force_start_index=None,
force_end_index=None):
force_end_index=None, unpack=True):
"""Configures the data source
......@@ -171,6 +172,8 @@ class CachedDataSource(DataSource):
force_end_index (int): The end index (if not set or set to ``None``, the
default, reads the data until the end)
unpack (bool): Indicates whether the data must be unpacked (the default) or kept as raw packed bytes
Returns:
......@@ -263,6 +266,8 @@ class CachedDataSource(DataSource):
self.filenames = trim_filename(data_filenames, force_start_index,
force_end_index)
self._unpack = unpack
# Read the first file to process
self.cur_file_index = 0
try:
......@@ -349,8 +354,11 @@ class CachedDataSource(DataSource):
encoded_data = self._cache + data
self._cache = six.b('')
data = self.dataformat.type()
data.unpack(encoded_data) #checks validity
if self._unpack:
data = self.dataformat.type()
data.unpack(encoded_data) #checks validity
else:
data = encoded_data
result = (data, self.next_start_index, self.next_end_index)
......
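The new 'unpack' flag lets a CachedDataSource return the raw packed bytes instead of decoding them, which saves a decode/encode round-trip when the data is only forwarded (for instance by a proxy). A minimal usage sketch, assuming the usual DataSource interface (hasMoreData()/next()); the import path and file names are assumptions, not part of this diff:

from beat.backend.python.data import CachedDataSource   # import path assumed from the module layout

source = CachedDataSource()
ok = source.setup(filename='/cache/0a/1b/2c/data.data',  # illustrative path
                  prefix='/path/to/prefix',
                  unpack=False)                          # keep packed bytes
if ok:
    while source.hasMoreData():
        packed, start, end = source.next()
        # with unpack=True (the default) 'packed' would instead be a decoded
        # dataformat object validated by dataformat.type().unpack()
        handle(packed, start, end)                       # placeholder for the caller's logic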
......@@ -40,6 +40,10 @@ import simplejson
from . import algorithm
from . import inputs
from . import outputs
from . import stats
from .helpers import create_inputs_from_configuration
from .helpers import create_outputs_from_configuration
from .helpers import CacheAccess
class Executor(object):
......@@ -78,10 +82,11 @@ class Executor(object):
database_cache=None, library_cache=None):
self.socket = socket
self.comm_time = 0. #total communication time
self.configuration = os.path.join(directory, 'configuration.json')
with open(self.configuration, 'rb') as f: self.data = simplejson.load(f)
with open(self.configuration, 'rb') as f:
self.data = simplejson.load(f)
self.prefix = os.path.join(directory, 'prefix')
self.runner = None
......@@ -97,36 +102,21 @@ class Executor(object):
main_channel = self.data['channel']
# Loads algorithm inputs
if 'inputs' in self.data:
self.input_list = inputs.InputList()
for name, channel in self.data['inputs'].items():
group = self.input_list.group(channel)
if group is None:
group = inputs.InputGroup(channel, restricted_access=(channel == main_channel))
self.input_list.add(group)
thisformat = self.algorithm.dataformats[self.algorithm.input_map[name]]
group.add(inputs.RemoteInput(name, thisformat, self.socket))
logger.debug("Loaded input list with %d group(s) and %d input(s)",
self.input_list.nbGroups(), len(self.input_list))
# Loads outputs
if 'outputs' in self.data:
self.output_list = outputs.OutputList()
for name, channel in self.data['outputs'].items():
thisformat = self.algorithm.dataformats[self.algorithm.output_map[name]]
self.output_list.add(outputs.RemoteOutput(name, thisformat, self.socket))
logger.debug("Loaded output list with %d output(s)",
len(self.output_list))
# Loads results if it is an analyzer
if 'result' in self.data:
self.output_list = outputs.OutputList()
name = 'result'
# Retrieve dataformats in the JSON of the algorithm
analysis_format = self.algorithm.result_dataformat()
analysis_format.name = 'analysis:' + self.algorithm.name
self.output_list.add(outputs.RemoteOutput(name, analysis_format, self.socket))
logger.debug("Loaded output list for analyzer (1 single output)")
if self.data['proxy_mode']:
cache_access = CacheAccess.REMOTE
else:
cache_access = CacheAccess.LOCAL
(self.input_list, _) = create_inputs_from_configuration(
self.data, self.algorithm, self.prefix, '/cache',
cache_access=cache_access, socket=self.socket
)
# Loads algorithm outputs
(self.output_list, _) = create_outputs_from_configuration(
self.data, self.algorithm, self.prefix, '/cache', self.input_list,
cache_access=cache_access, socket=self.socket
)
def setup(self):
......@@ -147,45 +137,37 @@ class Executor(object):
using_output = self.output_list[0] if self.analysis else self.output_list
_start = time.time()
while self.input_list.hasMoreData():
main_group = self.input_list.main_group
main_group.restricted_access = False
main_group.next()
main_group.restricted_access = True
if not self.runner.process(self.input_list, using_output): return False
if not self.runner.process(self.input_list, using_output):
return False
missing_data_outputs = [x for x in self.output_list if x.isDataMissing()]
proc_time = time.time() - _start
if missing_data_outputs:
raise RuntimeError("Missing data on the following output(s): %s" % \
', '.join([x.name for x in missing_data_outputs]))
self.comm_time = sum([x.comm_time for x in self.input_list]) + \
sum([x.comm_time for x in self.output_list])
self.comm_time += sum([self.input_list[k].comm_time for k in range(self.input_list.nbGroups())])
# Send the done command
statistics = stats.io_statistics(self.data, self.input_list, self.output_list)
# some local information
logger.debug("Total processing time was %.3f seconds" , proc_time)
logger.debug("Time spent in I/O was %.3f seconds" , self.comm_time)
logger.debug("I/O/Processing ratio is %d%%",
100*self.comm_time/proc_time)
logger.debug("Statistics: " + simplejson.dumps(statistics, indent=4))
# Handle the done command
self.done()
self.done(statistics)
return True
def done(self):
def done(self, statistics):
"""Indicates to the infrastructure that the execution is done"""
logger.debug('send: (don) done')
self.socket.send('don', zmq.SNDMORE)
self.socket.send('%.6e' % self.comm_time)
self.socket.send(simplejson.dumps(statistics))
answer = self.socket.recv() #ack
logger.debug('recv: %s', answer)
......
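With this change the 'don' message carries the full JSON-encoded I/O statistics instead of a single wait-time float. A hedged sketch pairing the two sides of the exchange (illustrative function names; assumes already-connected zmq PAIR sockets, as in the code above):

import simplejson
import zmq

def executor_side_done(socket, statistics):
    # executor: 'don' frame, then the statistics dictionary as JSON
    socket.send('don', zmq.SNDMORE)
    socket.send(simplejson.dumps(statistics))
    return socket.recv()                     # acknowledgement from the handler

def handler_side_done(socket, last_statistics, statistics=None):
    # message handler: decode the payload back into a dictionary
    if statistics is not None:
        last_statistics['data'] = simplejson.loads(statistics)
    socket.send('ack')                       # acknowledgement frame (content assumed here)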
......@@ -112,4 +112,8 @@ def hashFileContents(path):
"""Hashes the file contents using :py:func:`hashlib.sha256`."""
with open(path, 'rb') as f:
return hashlib.sha256(f.read()).hexdigest()
sha256 = hashlib.sha256()
for chunk in iter(lambda: f.read(sha256.block_size * 1000), b''):
sha256.update(chunk)
return sha256.hexdigest()
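Hashing in fixed-size chunks keeps memory usage constant for arbitrarily large cache files while producing exactly the same digest as hashing the whole content in one call. A standalone check of that equivalence (not part of the module):

import hashlib

data = b'x' * (10 * 1024 * 1024)            # 10 MB of sample data
whole = hashlib.sha256(data).hexdigest()

sha256 = hashlib.sha256()
chunk_size = sha256.block_size * 1000       # same chunk size as above (64000 bytes)
for i in range(0, len(data), chunk_size):
    sha256.update(data[i:i + chunk_size])

assert sha256.hexdigest() == whole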
#!/usr/bin/env python
# vim: set fileencoding=utf-8 :
###############################################################################
# #
# Copyright (c) 2017 Idiap Research Institute, http://www.idiap.ch/ #
# Contact: beat.support@idiap.ch #
# #
# This file is part of the beat.backend.python module of the BEAT platform. #
# #
# Commercial License Usage #
# Licensees holding valid commercial BEAT licenses may use this file in #
# accordance with the terms contained in a written agreement between you #
# and Idiap. For further information contact tto@idiap.ch #
# #
# Alternatively, this file may be used under the terms of the GNU Affero #
# Public License version 3 as published by the Free Software and appearing #
# in the file LICENSE.AGPL included in the packaging of this file. #
# The BEAT platform is distributed in the hope that it will be useful, but #
# WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY #
# or FITNESS FOR A PARTICULAR PURPOSE. #
# #
# You should have received a copy of the GNU Affero Public License along #
# with the BEAT platform. If not, see http://www.gnu.org/licenses/. #
# #
###############################################################################
import os
import errno
import logging
logger = logging.getLogger(__name__)
from . import data
from . import inputs
from . import outputs
def convert_experiment_configuration_to_container(config, proxy_mode):
data = {
'proxy_mode': proxy_mode,
'algorithm': config['algorithm'],
'parameters': config['parameters'],
'channel': config['channel'],
}
if 'range' in config:
data['range'] = config['range']
data['inputs'] = \
dict([(k, { 'channel': v['channel'], 'path': v['path'], 'database': ('database' in v) }) for k,v in config['inputs'].items()])
if 'outputs' in config:
data['outputs'] = \
dict([(k, { 'channel': v['channel'], 'path': v['path'] }) for k,v in config['outputs'].items()])
else:
data['result'] = { 'channel': config['channel'], 'path': config['result']['path'] }
return data
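For reference, a hedged example of the mapping performed by this function; the field names follow the code above, but the algorithm name, channel and cache paths are made up:

config = {
    'algorithm': 'user/my_algorithm/1',
    'parameters': {},
    'channel': 'main',
    'inputs': {'in': {'channel': 'main', 'path': '0a/1b/2c/3d'}},
    'outputs': {'out': {'channel': 'main', 'path': '4e/5f/6a/7b'}},
}

container = convert_experiment_configuration_to_container(config, proxy_mode=True)
# container == {
#     'proxy_mode': True,
#     'algorithm': 'user/my_algorithm/1',
#     'parameters': {},
#     'channel': 'main',
#     'inputs': {'in': {'channel': 'main', 'path': '0a/1b/2c/3d', 'database': False}},
#     'outputs': {'out': {'channel': 'main', 'path': '4e/5f/6a/7b'}},
# }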
#----------------------------------------------------------
class CacheAccess:
NONE = 0
LOCAL = 1
REMOTE = 2
def create_inputs_from_configuration(config, algorithm, prefix, cache_root,
cache_access=CacheAccess.NONE, unpack=True,
socket=None):
data_sources = []
input_list = inputs.InputList()
# This is used for parallelization purposes
start_index, end_index = config.get('range', (None, None))
for name, details in config['inputs'].items():
if details.get('database', False):
if socket is None:
raise IOError("No socket provided for remote inputs")
input = inputs.RemoteInput(name, algorithm.dataformats[algorithm.input_map[name]],
socket, unpack=unpack)
logger.debug("RemoteInput '%s' created: group='%s', dataformat='%s', connected to a database" % \
(name, details['channel'], algorithm.input_map[name]))
elif cache_access == CacheAccess.LOCAL:
data_source = data.CachedDataSource()
data_sources.append(data_source)
filename = os.path.join(cache_root, details['path'] + '.data')
if details['channel'] == config['channel']: # synchronized
status = data_source.setup(
filename=filename,
prefix=prefix,
force_start_index=start_index,
force_end_index=end_index,
unpack=unpack,
)
else:
status = data_source.setup(
filename=filename,
prefix=prefix,
unpack=unpack,
)
if not status:
raise IOError("cannot load cache file `%s'" % details['path'])
input = inputs.Input(name, algorithm.input_map[name], data_source)
logger.debug("Input '%s' created: group='%s', dataformat='%s', filename='%s'" % \
(name, details['channel'], algorithm.input_map[name], filename))
elif cache_access == CacheAccess.REMOTE:
if socket is None:
raise IOError("No socket provided for remote inputs")
input = inputs.RemoteInput(name, algorithm.dataformats[algorithm.input_map[name]],
socket, unpack=unpack)
logger.debug("RemoteInput '%s' created: group='%s', dataformat='%s'" % \
(name, details['channel'], algorithm.input_map[name]))
else:
continue
# Synchronization bits
group = input_list.group(details['channel'])
if group is None:
group = inputs.InputGroup(
details['channel'],
synchronization_listener=outputs.SynchronizationListener(),
restricted_access=(details['channel'] == config['channel'])
)
input_list.add(group)
logger.debug("Group '%s' created" % details['channel'])
group.add(input)
return (input_list, data_sources)
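A hedged usage sketch of the helper above, building a local (non-proxied) input list; 'container' is a configuration dictionary such as the one produced by convert_experiment_configuration_to_container() and 'algorithm' is an already-loaded algorithm object (both assumed here), while the prefix and cache paths are illustrative:

input_list, data_sources = create_inputs_from_configuration(
    container, algorithm, '/path/to/prefix', '/cache',
    cache_access=CacheAccess.LOCAL,
)

group = input_list.group(container['channel'])   # the synchronized group
while input_list.hasMoreData():
    group.restricted_access = False
    group.next()
    group.restricted_access = True
    # ... hand input_list over to the algorithm's process() here ...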
#----------------------------------------------------------
def create_outputs_from_configuration(config, algorithm, prefix, cache_root, input_list,
cache_access=CacheAccess.NONE, socket=None):
data_sinks = []
output_list = outputs.OutputList()
# This is used for parallelization purposes
start_index, end_index = config.get('range', (None, None))
# If the algorithm is an analyser
if 'result' in config:
output_config = {
'result': config['result']
}
else:
output_config = config['outputs']
for name, details in output_config.items():
if 'result' in config:
dataformat_name = 'analysis:' + algorithm.name
dataformat = algorithm.result_dataformat()
else:
dataformat_name = algorithm.output_map[name]
dataformat = algorithm.dataformats[dataformat_name]
if cache_access == CacheAccess.LOCAL:
path = os.path.join(cache_root, details['path'] + '.data')
dirname = os.path.dirname(path)
# Make sure that the directory exists while taking care of race
# conditions. see: http://stackoverflow.com/questions/273192/check-if-a-directory-exists-and-create-it-if-necessary
try:
if (len(dirname) > 0):
os.makedirs(dirname)
except OSError as exception:
if exception.errno != errno.EEXIST:
raise
data_sink = data.CachedDataSink()
data_sinks.append(data_sink)
status = data_sink.setup(
filename=path,
dataformat=dataformat,
encoding='binary',
max_size=0, # in bytes, for individual file chunks
)
if not status:
raise IOError("Cannot create cache sink '%s'" % details['path'])
synchronization_listener = None
if 'result' not in config:
input_group = input_list.group(details['channel'])
if (input_group is not None) and hasattr(input_group, 'synchronization_listener'):
synchronization_listener = input_group.synchronization_listener
output_list.add(outputs.Output(name, data_sink,
synchronization_listener=synchronization_listener,
force_start_index=start_index or 0)
)
if 'result' not in config:
logger.debug("Output '%s' created: group='%s', dataformat='%s', filename='%s'" % \
(name, details['channel'], dataformat_name, path))
else:
logger.debug("Output '%s' created: dataformat='%s', filename='%s'" % \
(name, dataformat_name, path))
elif cache_access == CacheAccess.REMOTE:
if socket is None:
raise IOError("No socket provided for remote outputs")
output_list.add(outputs.RemoteOutput(name, dataformat, socket))
logger.debug("RemoteOutput '%s' created: group='%s', dataformat='%s'" % \
(name, details['channel'], dataformat_name))
else:
continue
return (output_list, data_sinks)
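And the matching output side, reusing the same configuration and the input_list created above (hedged sketch, illustrative paths):

output_list, data_sinks = create_outputs_from_configuration(
    container, algorithm, '/path/to/prefix', '/cache', input_list,
    cache_access=CacheAccess.LOCAL,
)

# ... run the algorithm, writing onto output_list ...

missing = [x for x in output_list if x.isDataMissing()]
if missing:
    raise RuntimeError("Missing data on output(s): %s" %
                       ', '.join(x.name for x in missing))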
......@@ -105,6 +105,14 @@ class Input:
"""Retrieves the next block of data"""
(self.data, self.data_index, self.data_index_end) = self.data_source.next()
if self.data is None:
message = "User algorithm asked for more data for channel " \
"`%s' on input `%s', but it is over (no more data). This " \
"normally indicates a programming error on the user " \
"side." % (self.group.channel, self.name)
raise RuntimeError(message)
self.data_same_as_previous = False
self.nb_data_blocks_read += 1
......@@ -166,7 +174,7 @@ class RemoteInput:
"""
def __init__(self, name, data_format, socket):
def __init__(self, name, data_format, socket, unpack=True):
self.name = str(name)
self.data_format = data_format
......@@ -177,6 +185,7 @@ class RemoteInput:
self.group = None
self.comm_time = 0. #total time spent on communication
self.nb_data_blocks_read = 0
self._unpack = unpack
def isDataUnitDone(self):
......@@ -248,10 +257,14 @@ class RemoteInput:
def unpack(self, packed):
"""Receives data through socket"""
self.data = self.data_format.type()
logger.debug('recv: <bin> (size=%d), indexes=(%d, %d)', len(packed),
self.data_index, self.data_index_end)
self.data.unpack(packed)
if self._unpack:
self.data = self.data_format.type()
self.data.unpack(packed)
else:
self.data = packed
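With unpack=False the remote input keeps the packed payload untouched, so a proxy can forward it without a decode/encode round-trip. A small sketch of both behaviours ('fmt', 'socket' and 'packed' stand for a loaded dataformat, a connected zmq socket and an encoded payload; they are assumptions, not part of this diff):

decoded = RemoteInput('measurements', fmt, socket)                # default: unpack=True
raw     = RemoteInput('measurements', fmt, socket, unpack=False)  # keep packed bytes

decoded.unpack(packed)   # decoded.data is a validated fmt.type() instance
raw.unpack(packed)       # raw.data is the packed byte string, unchanged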
#----------------------------------------------------------
......@@ -585,8 +598,13 @@ class InputList:
return bool([x for x in self._groups if x.hasMoreData()])
def group(self, name):
try:
return [x for x in self._groups if x.channel == name][0]
except:
def group(self, name_or_index):
if isinstance(name_or_index, six.string_types):
try:
return [x for x in self._groups if x.channel == name_or_index][0]
except:
return None
elif isinstance(name_or_index, int):
return self._groups[name_or_index]
else:
return None
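group() now accepts either a channel name or a positional index, which is what allows loops such as the comm_time accumulation in the executor and in stats.io_statistics(). For example (the channel name is illustrative):

main_group  = input_list.group('main')    # look-up by channel name, None if absent
first_group = input_list.group(0)         # look-up by position
total_wait  = sum(input_list.group(k).comm_time
                  for k in range(input_list.nbGroups()))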
......@@ -31,6 +31,7 @@ logger = logging.getLogger(__name__)
import gevent
import zmq.green as zmq
import simplejson
import requests
from gevent import monkey
......@@ -128,6 +129,13 @@ class MessageHandler(gevent.Greenlet):
self.process.kill()
self.stop.set()
break
except RuntimeError as e:
self.send_error(str(e), kind='usr')
self.user_error = str(e)
if self.process is not None:
self.process.kill()
self.stop.set()
break
except:
import traceback
parser = lambda s: s if len(s)<20 else s[:20] + '...'
......@@ -272,17 +280,16 @@ class MessageHandler(gevent.Greenlet):
self.stop.set()
def done(self, wait_time=None):
def done(self, statistics=None):
"""Syntax: don"""
logger.debug('recv: don %s', wait_time)
logger.debug('recv: don %s', statistics)
if wait_time is not None:
if statistics is not None:
self._collect_statistics()
# collect I/O stats from client
wait_time = float(wait_time)
self.last_statistics['data'] = dict(network=dict(wait_time=wait_time))
self.last_statistics['data'] = simplejson.loads(statistics)
self._acknowledge()
......
......@@ -52,6 +52,8 @@ import logging
import os
import sys
import docopt
import pwd
import stat
import zmq
......@@ -121,6 +123,28 @@ def main():
logger = logging.getLogger(__name__)
# Attempt to change to an user with less privileges
try:
# First determine if the 'beat-nobody' user exists; if it does not,
# getpwnam() raises and the rest of this block is skipped
newuid = pwd.getpwnam('beat-nobody').pw_uid
# Next, ensure that the needed files are readable by the 'beat-nobody' user
access = stat.S_IRUSR | stat.S_IWUSR | stat.S_IXUSR | stat.S_IRGRP | stat.S_IXGRP | stat.S_IROTH | stat.S_IXOTH
os.chmod(args['<dir>'], access)
for root, dirs, files in os.walk(args['<dir>']):
for d in dirs:
os.chmod(os.path.join(root, d), access)
for f in files:
os.chmod(os.path.join(root, f), access)
# Change the user
os.setuid(newuid)
except:
# The 'beat-nobody' user may not exist, or we may lack the privileges to
# change permissions or the uid; in that case keep running as the current user
pass
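The permission mask assembled from the stat constants is the conventional 0755 mode (owner read/write/execute, group and others read/execute), applied recursively so the unprivileged 'beat-nobody' user can still traverse and read the working directory:

import stat

access = (stat.S_IRUSR | stat.S_IWUSR | stat.S_IXUSR |
          stat.S_IRGRP | stat.S_IXGRP |
          stat.S_IROTH | stat.S_IXOTH)
assert access == 0o755   # rwxr-xr-x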
# Creates the 0MQ socket for communication with BEAT
context = zmq.Context()
socket = context.socket(zmq.PAIR)
......
#!/usr/bin/env python
# vim: set fileencoding=utf-8 :
###############################################################################
# #
# Copyright (c) 2017 Idiap Research Institute, http://www.idiap.ch/ #
# Contact: beat.support@idiap.ch #
# #
# This file is part of the beat.backend.python module of the BEAT platform. #
# #
# Commercial License Usage #
# Licensees holding valid commercial BEAT licenses may use this file in #
# accordance with the terms contained in a written agreement between you #
# and Idiap. For further information contact tto@idiap.ch #
# #
# Alternatively, this file may be used under the terms of the GNU Affero #
# Public License version 3 as published by the Free Software and appearing #
# in the file LICENSE.AGPL included in the packaging of this file. #
# The BEAT platform is distributed in the hope that it will be useful, but #
# WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY #
# or FITNESS FOR A PARTICULAR PURPOSE. #
# #
# You should have received a copy of the GNU Affero Public License along #
# with the BEAT platform. If not, see http://www.gnu.org/licenses/. #
# #
###############################################################################
from .inputs import RemoteInput
from .outputs import RemoteOutput
def io_statistics(configuration, input_list=None, output_list=None):
"""Summarizes the current I/O statistics by looking at the data sources and sinks attached to the given inputs and outputs
Returns:
dict: A dictionary summarizing current I/O statistics
"""
network_time = 0.0
# Data reading
bytes_read = 0
blocks_read = 0
read_time = 0.0
if input_list is not None:
for input in input_list:
if isinstance(input, RemoteInput):
network_time += input.comm_time
else:
size, duration = input.data_source.statistics()
bytes_read += size
read_time += duration
blocks_read += input.nb_data_blocks_read
network_time += sum([input_list.group(k).comm_time for k in range(input_list.nbGroups())])
# Data writing
bytes_written = 0
blocks_written = 0
write_time = 0.0
files = []
if output_list is not None:
for output in output_list:
if isinstance(output, RemoteOutput):
network_time += output.comm_time
else:
size, duration = output.data_sink.statistics()
bytes_written += size
write_time += duration
blocks_written += output.nb_data_blocks_written
if 'result' in configuration:
hash = configuration['result']['path'].replace('/', '')
else:
hash = configuration['outputs'][output.name]['path'].replace('/', '')
files.append(dict(
hash=hash,
size=size,
blocks=output.nb_data_blocks_written,
))
# Result
return dict(
volume = dict(read=bytes_read, write=bytes_written),
blocks = dict(read=blocks_read, write=blocks_written),
time = dict(read=read_time, write=write_time),
network = dict(wait_time=network_time),
files = files,
)
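The dictionary returned by io_statistics() therefore always has the same shape; a hedged example with made-up numbers:

{
    'volume':  {'read': 45060, 'write': 16384},   # bytes
    'blocks':  {'read': 90,    'write': 30},      # data units
    'time':    {'read': 0.042, 'write': 0.017},   # seconds
    'network': {'wait_time': 1.605},              # seconds spent waiting on sockets
    'files': [
        {'hash': '0a1b2c3d', 'size': 16384, 'blocks': 30},
    ],
}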
#----------------------------------------------------------
def update(statistics, additional_statistics):
for k in statistics.keys():
if k == 'files':
continue
for k2 in statistics[k].keys():
statistics[k][k2] += additional_statistics[k][k2]
if 'files' in statistics:
statistics['files'].extend(additional_statistics.get('files', []))
else:
statistics['files'] = additional_statistics.get('files', [])
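update() accumulates the numeric counters in place and concatenates the file lists, which lets the platform merge the statistics of several blocks into a single report. For instance (made-up numbers):

total = dict(volume=dict(read=100, write=10),
             blocks=dict(read=4, write=1),
             time=dict(read=0.5, write=0.1),
             network=dict(wait_time=2.0),
             files=[])

update(total, dict(volume=dict(read=50, write=0),
                   blocks=dict(read=2, write=0),
                   time=dict(read=0.2, write=0.0),
                   network=dict(wait_time=1.0),
                   files=[dict(hash='0a1b', size=2048, blocks=4)]))

# total['volume'] == {'read': 150, 'write': 10}
# total['files']  == [{'hash': '0a1b', 'size': 2048, 'blocks': 4}]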
......@@ -6,7 +6,7 @@
# Copyright (c) 2016 Idiap Research Institute, http://www.idiap.ch/ #
# Contact: beat.support@idiap.ch #
# #
# This file is part of the beat.core module of the BEAT platform. #
# This file is part of the beat.backend.python module of the BEAT platform. #
# #
# Commercial License Usage #
# Licensees holding valid commercial BEAT licenses may use this file in #
......@@ -51,6 +51,8 @@ def hashed_or_simple(prefix, what, path, suffix='.json'):
return os.path.join(prefix, what, path)
#----------------------------------------------------------
def safe_rmfile(f):