Commit c0b0eb46 authored by Philip ABBET

Add tests for the 'databases_provider.py' script

parent ab6234d3
......@@ -64,6 +64,12 @@ class RemoteException(Exception):
self.system_error = ''
self.user_error = message
def __str__(self):
    """Render the error, tagged with its origin.

    A non-empty ``system_error`` takes precedence and is prefixed with
    ``(sys)``; otherwise ``user_error`` is rendered with a ``(usr)`` prefix.
    """
    if self.system_error == '':
        return '(usr) ' + self.user_error

    return '(sys) ' + self.system_error
#----------------------------------------------------------
......@@ -195,6 +201,7 @@ class DataSource(object):
self.infos = []
self.read_duration = 0
self.nb_bytes_read = 0
self.ready = False
def close(self):
......@@ -207,10 +214,16 @@ class DataSource(object):
def __len__(self):
    """Return the number of entries in ``self.infos``.

    ``_prepare()`` is invoked first when the source is not flagged as ready.
    """
    if not self.ready:
        self._prepare()

    entries = self.infos
    return len(entries)
def __iter__(self):
    """Yield ``self[i]`` for every position of ``self.infos``, in order.

    ``_prepare()`` is invoked first when the source is not flagged as ready.
    The number of positions is taken once, before the first item is yielded.
    """
    if not self.ready:
        self._prepare()

    total = len(self.infos)
    for position in range(total):
        yield self[position]
......@@ -220,18 +233,30 @@ class DataSource(object):
def first_data_index(self):
    """Return the ``start_index`` of the first entry of ``self.infos``.

    ``_prepare()`` is invoked first when the source is not flagged as ready.
    """
    if not self.ready:
        self._prepare()

    head = self.infos[0]
    return head.start_index
def last_data_index(self):
    """Return the ``end_index`` of the last entry of ``self.infos``.

    ``_prepare()`` is invoked first when the source is not flagged as ready.
    """
    if not self.ready:
        self._prepare()

    tail = self.infos[-1]
    return tail.end_index
def data_indices(self):
    """Return the list of ``(start_index, end_index)`` pairs of ``self.infos``.

    ``_prepare()`` is invoked first when the source is not flagged as ready.
    """
    if not self.ready:
        self._prepare()

    pairs = []
    for entry in self.infos:
        pairs.append((entry.start_index, entry.end_index))
    return pairs
def getAtDataIndex(self, data_index):
if not self.ready:
self._prepare()
for index, infos in enumerate(self.infos):
if (infos.start_index <= data_index) and (data_index <= infos.end_index):
return self[index]
......@@ -244,6 +269,10 @@ class DataSource(object):
return (self.nb_bytes_read, self.read_duration)
def _prepare(self):
    """Flag the data source as ready.

    This implementation loads nothing; it only sets ``self.ready``.
    """
    self.ready = True
#----------------------------------------------------------
......@@ -433,6 +462,9 @@ class CachedDataSource(DataSource):
"""
if not self.ready:
self._prepare()
if (index < 0) or (index >= len(self.infos)):
return (None, None, None)
......@@ -561,6 +593,9 @@ class DatabaseOutputDataSource(DataSource):
"""
if not self.ready:
self._prepare()
if (index < 0) or (index >= len(self.infos)):
return (None, None, None)
......@@ -632,29 +667,6 @@ class RemoteDataSource(DataSource):
if not self.dataformat.valid:
raise RuntimeError("the dataformat `%s' is not valid" % dataformat_name)
# Load the needed infos from the socket
Infos = namedtuple('Infos', ['start_index', 'end_index'])
logger.debug('send: (ifo) infos %s', self.input_name)
self.socket.send('ifo', zmq.SNDMORE)
self.socket.send(self.input_name)
answer = self.socket.recv()
logger.debug('recv: %s', answer)
if answer == 'err':
kind = self.socket.recv()
message = self.socket.recv()
raise RemoteException(kind, message)
nb = int(answer)
for i in range(nb):
start = int(self.socket.recv())
end = int(self.socket.recv())
self.infos.append(Infos(start_index=start, end_index=end))
return True
......@@ -667,6 +679,9 @@ class RemoteDataSource(DataSource):
"""
if not self.ready:
self._prepare()
if (index < 0) or (index >= len(self.infos)):
return (None, None, None)
......@@ -710,6 +725,32 @@ class RemoteDataSource(DataSource):
return (data, infos.start_index, infos.end_index)
def _prepare(self):
    """Fetch the (start, end) index pairs of this input over the socket.

    Sends an ``ifo`` request for ``self.input_name``, appends one ``Infos``
    entry per announced pair to ``self.infos`` and flags the source as ready.

    Raises:

      RemoteException: If the peer answers ``err``; the two frames that
        follow are passed on as the error kind and message.

    """

    Infos = namedtuple('Infos', ['start_index', 'end_index'])

    # Two-part request: the command, then the name of the input
    logger.debug('send: (ifo) infos %s', self.input_name)
    self.socket.send('ifo', zmq.SNDMORE)
    self.socket.send(self.input_name)

    reply = self.socket.recv()
    logger.debug('recv: %s', reply)

    if reply == 'err':
        kind = self.socket.recv()
        message = self.socket.recv()
        raise RemoteException(kind, message)

    # Any other reply is the number of pairs; each pair arrives as two frames
    for _ in range(int(reply)):
        first = int(self.socket.recv())
        last = int(self.socket.recv())
        self.infos.append(Infos(start_index=first, end_index=last))

    self.ready = True
#----------------------------------------------------------
......
......@@ -250,6 +250,9 @@ class Database(object):
with open(json_path, 'rb') as f:
self.data = simplejson.load(f)
self.code_path = self.storage.code.path
self.code = self.storage.code.load()
for protocol in self.data['protocols']:
for _set in protocol['sets']:
......@@ -274,6 +277,58 @@ class Database(object):
return self._name or '__unnamed_database__'
@name.setter
def name(self, value):
self._name = value
self.storage = Storage(self.prefix, value)
@property
def description(self):
    """Short description stored under the ``description`` key of ``self.data``

    ``None`` is returned when the key is absent.
    """
    return self.data.get('description')


@description.setter
def description(self, value):
    """Stores ``value`` under the ``description`` key of ``self.data``"""
    self.data['description'] = value
@property
def documentation(self):
    """Full-length description, loaded from ``self.storage.doc``

    Returns ``None`` when the storage reports that no documentation exists.

    Raises:

      RuntimeError: If the database has no name.

    """

    if not self._name:
        raise RuntimeError("database has no name")

    if not self.storage.doc.exists():
        return None

    return self.storage.doc.load()


@documentation.setter
def documentation(self, value):
    """Saves the full-length description through ``self.storage.doc``

    ``value`` may be the text itself or any object with a ``read()`` method,
    in which case the result of ``read()`` is what gets saved.

    Raises:

      RuntimeError: If the database has no name.

    """

    if not self._name:
        raise RuntimeError("database has no name")

    content = value.read() if hasattr(value, 'read') else value
    self.storage.doc.save(content)
def hash(self):
    """Returns the hash reported by ``self.storage`` for this declaration

    Raises:

      RuntimeError: If the database has no name.

    """

    if self._name:
        return self.storage.hash()

    raise RuntimeError("database has no name")
@property
def schema_version(self):
"""Returns the schema version"""
......@@ -375,6 +430,88 @@ class Database(object):
self.data['root_folder'], exc)
def json_dumps(self, indent=4):
    """Serializes ``self.data`` to a JSON string

    Parameters:

      indent (int): Number of spaces used per indentation level

    Returns:

      str: The JSON text, encoded with ``utils.NumpyJSONEncoder``

    """

    options = dict(indent=indent, cls=utils.NumpyJSONEncoder)
    return simplejson.dumps(self.data, **options)
def __str__(self):
    """Returns the JSON declaration, as produced by ``json_dumps()``"""
    text = self.json_dumps()
    return text
def write(self, storage=None):
    """Saves the declaration, code and description to a storage

    Parameters:

      storage (Storage, optional): Where to save. When omitted, the object's
        own storage is used, which requires the database to have a name.

    Raises:

      RuntimeError: If no storage is given and the database has no name.

    """

    target = storage

    if target is None:
        if not self._name:
            raise RuntimeError("database has no name")
        target = self.storage  # overwrite the default location

    target.save(str(self), self.code, self.description)
def export(self, prefix):
    """Recursively exports itself into another prefix

    Every dataformat in ``self.dataformats`` is exported first, then the
    database itself is written to a storage created under the new prefix.

    Parameters:

      prefix (str): A path to a prefix that must be different from my own.
        A string is wrapped in ``utils.Prefix``; any other value is used as
        is and must provide ``paths``.

    Returns:

      None

    Raises:

      RuntimeError: If the database has no name, is not valid, or if prefix
        and self.prefix point to the same directory.

    """

    if not self._name:
        raise RuntimeError("database has no name")

    if not self.valid:
        raise RuntimeError("database is not valid")

    if isinstance(prefix, six.string_types):
        prefix = utils.Prefix(prefix)

    destination = prefix.paths[0]
    if destination in self.prefix.paths:
        raise RuntimeError("Cannot export database to the same prefix (%s in " \
            "%s)" % (destination, self.prefix.paths))

    for dataformat in self.dataformats.values():
        dataformat.export(prefix)

    self.write(Storage(prefix, self.name))
#----------------------------------------------------------
......
......@@ -95,7 +95,7 @@ class DBExecutor(object):
"""
def __init__(self, address, prefix, cache_root, data, dataformat_cache=None,
def __init__(self, message_handler, prefix, cache_root, data, dataformat_cache=None,
database_cache=None):
# Initialisations
......@@ -106,6 +106,7 @@ class DBExecutor(object):
self.data = None
self.message_handler = None
self.data_sources = {}
self.message_handler = message_handler
# Temporary caches, if the user has not set them, for performance
database_cache = database_cache if database_cache is not None else {}
......@@ -176,8 +177,7 @@ class DBExecutor(object):
self.data_sources[name] = view.data_sources[details['output']]
# Create the message handler
self.message_handler = MessageHandler(address, data_sources=self.data_sources)
self.message_handler.set_data_sources(self.data_sources)
def process(self):
......@@ -197,7 +197,6 @@ class DBExecutor(object):
def wait(self):
self.message_handler.join()
self.message_handler.destroy()
self.message_handler = None
......
......@@ -64,16 +64,16 @@ class MessageHandler(threading.Thread):
if len(self.address.split(':')) == 2:
port = self.socket.bind_to_random_port(self.address, min_port=50000)
self.address += ':%d' % port
logger.debug("zmq server bound to '%s'", self.address)
else:
self.socket.connect(self.address)
logger.debug("connected to '%s'", self.address)
self.socket.bind(self.address)
logger.debug("zmq server bound to '%s'", self.address)
self.poller = zmq.Poller()
self.poller.register(self.socket, zmq.POLLIN)
# Initialisations
self.data_sources = data_sources
self.data_sources = None
self.system_error = ''
self.user_error = ''
self.statistics = {}
......@@ -86,23 +86,30 @@ class MessageHandler(threading.Thread):
err = self.error,
)
if self.data_sources is not None:
self.callbacks.update(dict(
ifo = self.infos,
get = self.get_data,
))
if data_sources is not None:
self.set_data_sources(data_sources)
def destroy(self):
    """Closes the socket without lingering, then destroys the 0MQ context"""
    sock = self.socket

    # LINGER=0 is set before closing the socket
    sock.setsockopt(zmq.LINGER, 0)
    sock.close()

    self.context.destroy()
    logger.debug("0MQ client finished")
def __str__(self):
    """Returns ``MessageHandler(<address>)`` for logging and debugging"""
    template = 'MessageHandler(%s)'
    return template % self.address
def set_data_sources(self, data_sources):
    """Attaches the data sources and registers the callbacks that use them

    Parameters:

      data_sources: Stored as ``self.data_sources``.

    The ``ifo`` and ``get`` entries of ``self.callbacks`` are set to
    ``self.infos`` and ``self.get_data`` respectively.

    """

    self.data_sources = data_sources

    self.callbacks.update(
        ifo=self.infos,
        get=self.get_data,
    )
def run(self):
logger.debug("0MQ server thread started")
......
......@@ -29,14 +29,15 @@
"""Executes some database views. (%(version)s)
usage:
%(prog)s [--debug] <addr> <dir>
%(prog)s [--debug] <addr> <dir> <cache>
%(prog)s (--help)
%(prog)s (--version)
arguments:
<addr> Address of the server for I/O requests
<addr> Listen for incoming request on this address ('host:port')
<dir> Directory containing all configuration required to run the views
<cache> Path to the cache
options:
......@@ -58,6 +59,7 @@ import stat
import zmq
from ..dbexecution import DBExecutor
from ..message_handler import MessageHandler
class UserError(Exception):
......@@ -72,38 +74,6 @@ class UserError(Exception):
#----------------------------------------------------------
def send_error(logger, socket, tp, message):
    """Sends a user (usr) or system (sys) error message to the infrastructure

    The error goes out as three frames (``err``, the error type, the message).
    The socket is then polled for an answer, one second at a time, for at most
    five tries; if nothing arrives the failure is logged and the function
    returns normally.

    Parameters:

      logger: Logger receiving the debug/warning/error traces

      socket: 0MQ socket used to send the error and to receive the answer

      tp (str): Error type, sent as the second frame (callers in this file
        pass ``'usr'`` or ``'sys'``)

      message (str): Error text, sent as the last frame

    """

    logger.debug('send: (err) error')
    socket.send('err', zmq.SNDMORE)
    socket.send(tp, zmq.SNDMORE)
    # Lazy %-args: the message is only formatted if debug logging is enabled
    logger.debug('send: """%s"""', message.rstrip())
    socket.send(message)

    poller = zmq.Poller()
    poller.register(socket, zmq.POLLIN)

    max_tries = 5
    timeout = 1000 #ms

    for this_try in range(1, max_tries + 1):
        socks = dict(poller.poll(timeout)) #blocks here, for 5 seconds at most
        if socket in socks and socks[socket] == zmq.POLLIN:
            answer = socket.recv() #ack
            logger.debug('recv: %s', answer)
            break
        # `Logger.warn` is a deprecated alias of `Logger.warning`
        logger.warning('(try %d) waited %d ms for "ack" from server',
                       this_try, timeout)
    else:
        # Only reached when every try timed out without an answer
        logger.error('could not send error message to server')
        logger.error('stopping 0MQ client anyway')
#----------------------------------------------------------
def process_traceback(tb, prefix):
import traceback
......@@ -122,14 +92,23 @@ def process_traceback(tb, prefix):
#----------------------------------------------------------
def main():
def main(arguments=None):
# Parse the command-line arguments
if arguments is None:
arguments = sys.argv[1:]
package = __name__.rsplit('.', 2)[0]
version = package + ' v' + \
__import__('pkg_resources').require(package)[0].version
prog = os.path.basename(sys.argv[0])
args = docopt.docopt(__doc__ % dict(prog=prog, version=version),
version=version)
args = docopt.docopt(
__doc__ % dict(prog=prog, version=version),
argv=arguments,
version=version
)
# Setup the logging system
......@@ -151,6 +130,10 @@ def main():
logger = logging.getLogger(__name__)
# Create the message handler
message_handler = MessageHandler(args['<addr>'])
# If necessary, change to another user (with less privileges, but has access
# to the databases)
with open(os.path.join(args['<dir>'], 'configuration.json'), 'r') as f:
......@@ -185,18 +168,11 @@ def main():
os.setuid(cfg['datasets_uid'])
except:
import traceback
send_error(logger, socket, 'sys', traceback.format_exc())
message_handler.send_error(traceback.format_exc(), 'sys')
message_handler.destroy()
return 1
# Creates the 0MQ socket for communication with BEAT
context = zmq.Context()
socket = context.socket(zmq.PAIR)
address = args['<addr>']
socket.connect(address)
logger.debug("zmq client connected to `%s'", address)
try:
# Check the dir
......@@ -208,9 +184,12 @@ def main():
database_cache = {}
try:
dbexecutor = DBExecutor(os.path.join(args['<dir>'], 'prefix'),
os.path.join(args['<dir>'], 'configuration.json'),
dataformat_cache, database_cache)
dbexecutor = DBExecutor(message_handler,
os.path.join(args['<dir>'], 'prefix'),
args['<cache>'],
os.path.join(args['<dir>'], 'configuration.json'),
dataformat_cache,
database_cache)
except (MemoryError):
raise
except Exception as e:
......@@ -222,9 +201,8 @@ def main():
# Execute the code
try:
with dbexecutor:
dbexecutor.process(context, socket)
dbexecutor.wait()
dbexecutor.process()
dbexecutor.wait()
except (MemoryError):
raise
except Exception as e:
......@@ -236,28 +214,27 @@ def main():
except UserError as e:
msg = str(e).decode('string_escape').strip("'")
send_error(logger, socket, 'usr', msg)
message_handler.send_error(msg, 'usr')
message_handler.destroy()
return 1
except MemoryError as e:
# Say something meaningful to the user
msg = "The user process for this block ran out of memory. We " \
"suggest you optimise your code to reduce memory usage or, " \
"if this is not an option, choose an appropriate processing " \
"queue with enough memory."
send_error(logger, socket, 'usr', msg)
"suggest you optimise your code to reduce memory usage or, " \
"if this is not an option, choose an appropriate processing " \
"queue with enough memory."
message_handler.send_error(msg, 'usr')
message_handler.destroy()
return 1
except Exception as e:
import traceback
send_error(logger, socket, 'sys', traceback.format_exc())
message_handler.send_error(traceback.format_exc(), 'sys')
message_handler.destroy()
return 1
finally:
socket.setsockopt(zmq.LINGER, 0)
socket.close()
context.term()
logger.debug("0MQ client finished")
message_handler.destroy()
return 0
......
#!/usr/bin/env python
# vim: set fileencoding=utf-8 :
###############################################################################
# #
# Copyright (c) 2017 Idiap Research Institute, http://www.idiap.ch/ #
# Contact: beat.support@idiap.ch #
# #
# This file is part of the beat.core module of the BEAT platform. #
# #
# Commercial License Usage #
# Licensees holding valid commercial BEAT licenses may use this file in #
# accordance with the terms contained in a written agreement between you #
# and Idiap. For further information contact tto@idiap.ch #
# #
# Alternatively, this file may be used under the terms of the GNU Affero #
# Public License version 3 as published by the Free Software and appearing #
# in the file LICENSE.AGPL included in the packaging of this file. #
# The BEAT platform is distributed in the hope that it will be useful, but #
# WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY #
# or FITNESS FOR A PARTICULAR PURPOSE. #
# #
# You should have received a copy of the GNU Affero Public License along #
# with the BEAT platform. If not, see http://www.gnu.org/licenses/. #
# #
###############################################################################
# Tests for the 'databases_provider.py' script
import os
import logging
logger = logging.getLogger(__name__)
import unittest
import simplejson
import multiprocessing
import Queue
import tempfile
import shutil
import zmq
from time import time
from time import sleep
from ..scripts import databases_provider
from ..database import Database
from ..data import RemoteDataSource
from ..data import RemoteException
from . import prefix
from . import tmp_prefix
#----------------------------------------------------------
# Execution configuration used by the tests below: algorithm
# 'user/integers_echo/1' with a single input ('in_data', taken from output 'a'
# of the 'double' set in the 'double' protocol of database 'integers_db/1')
# and a single output ('out_data'), both on the 'integers' channel.
# NOTE(review): presumably written out as the 'configuration.json' read by the
# databases_provider script -- confirm against the test cases.
CONFIGURATION = {
'queue': 'queue',
'inputs': {
'in_data': {
'set': 'double',
'protocol': 'double',
'database': 'integers_db/1',
'output': 'a',
'path': 'ec/89/e5/6e161d2cb012ef6ac8acf59bf453a6328766f90dc9baba9eb14ea23c55',
'endpoint': 'a',
'hash': 'ec89e56e161d2cb012ef6ac8acf59bf453a6328766f90dc9baba9eb14ea23c55',
'channel': 'integers'
}
},
'algorithm': 'user/integers_echo/1',
'parameters': {},
'environment': {
'name': 'Python 2.7',
'version': '1.2.0'
},
'outputs': {
'out_data': {
'path': '20/61/b6/2df3c3bedd5366f4a625c5d87ffbf5a26007c46c456e9abf21b46c6681',
'endpoint': 'out_data',
'hash': '2061b62df3c3bedd5366f4a625c5d87ffbf5a26007c46c456e9abf21b46c6681',
'channel': 'integers'
}
},
'nb_slots': 1,
'channel': 'integers'
}
#----------------------------------------------------------
CONFIGURATION_ERROR = {
'queue': 'queue',
'inputs': {
'in_data': {
'set': 'get_crashes',
'protocol': 'protocol',
'database': 'crash/1',
'output': 'out',
'path': 'ec/89/e5/6e161d2cb012ef6ac8acf59bf453a6328766f90dc9baba9eb14ea23c55',
'endpoint': 'in',
'hash': 'ec89e56e161d2cb012ef6ac8acf59bf453a6328766f90dc9baba9eb14ea23c55',
'channel': 'set'
}
},
'algorithm': 'user/integers_echo/1',
'parameters': {},
'environment': {
'name': 'Python 2.7',
'version': '1.2.0'