Commit 29e6ab5e authored by Philip ABBET

Refactoring: Databases, DBExecutor, MessageHandler, DataSources

parent dce310ab
@@ -35,6 +35,7 @@ import six
import simplejson
import itertools
import numpy as np
+from collections import namedtuple
from . import loader
from . import utils
@@ -73,7 +74,7 @@ class Storage(utils.CodeStorage):

#----------------------------------------------------------

-class View(object):
+class Runner(object):
    '''A special loader class for database views, with specialized methods

    Parameters:
@@ -98,9 +99,7 @@ class View(object):
    '''

-    def __init__(self, module, definition, prefix, root_folder, exc=None,
-                 *args, **kwargs):
+    def __init__(self, module, definition, prefix, root_folder, exc=None):

        try:
            class_ = getattr(module, definition['view'])
@@ -109,94 +108,73 @@ class View(object):
                type, value, traceback = sys.exc_info()
                six.reraise(exc, exc(value), traceback)
            else:
-                raise #just re-raise the user exception
+                raise # just re-raise the user exception

-        self.obj = loader.run(class_, '__new__', exc, *args, **kwargs)
+        self.obj = loader.run(class_, '__new__', exc)
        self.ready = False
        self.prefix = prefix
        self.root_folder = root_folder
        self.definition = definition
        self.exc = exc or RuntimeError
-        self.outputs = None
+        self.data_sources = None
-    def prepare_outputs(self):
-        '''Prepares the outputs of the dataset'''
-
-        from .outputs import Output, OutputList
-        from .data import MemoryDataSink
-        from .dataformat import DataFormat
-
-        # create the stock outputs for this dataset, so data is dumped
-        # on a in-memory sink
-        self.outputs = OutputList()
-        for out_name, out_format in self.definition.get('outputs', {}).items():
-            data_sink = MemoryDataSink()
-            data_sink.dataformat = DataFormat(self.prefix, out_format)
-            data_sink.setup([])
-            self.outputs.add(Output(out_name, data_sink, dataset_output=True))
-
-    def setup(self, *args, **kwargs):
-        '''Sets up the view'''
-
-        kwargs.setdefault('root_folder', self.root_folder)
-        kwargs.setdefault('parameters', self.definition.get('parameters', {}))
-
-        if 'outputs' not in kwargs:
-            kwargs['outputs'] = self.outputs
-        else:
-            self.outputs = kwargs['outputs'] #record outputs nevertheless
-
-        self.ready = loader.run(self.obj, 'setup', self.exc, *args, **kwargs)
-
-        if not self.ready:
-            raise self.exc("unknow setup failure")
-
-        return self.ready
-
-    def input_group(self, name='default', exclude_outputs=[]):
-        '''A memory-source input group matching the outputs from the view'''
-
-        if not self.ready:
-            raise self.exc("database view not yet setup")
-
-        from .data import MemoryDataSource
-        from .outputs import SynchronizationListener
-        from .inputs import Input, InputGroup
-
-        # Setup the inputs
-        synchronization_listener = SynchronizationListener()
-        input_group = InputGroup(name,
-                synchronization_listener=synchronization_listener,
-                restricted_access=False)
-
-        for output in self.outputs:
-            if output.name in exclude_outputs: continue
-            data_source = MemoryDataSource(self.done, next_callback=self.next)
-            output.data_sink.data_sources.append(data_source)
-            input_group.add(Input(output.name,
-                output.data_sink.dataformat, data_source))
-
-        return input_group
-
-    def done(self, *args, **kwargs):
-        '''Checks if the view is done'''
-
-        if not self.ready:
-            raise self.exc("database view not yet setup")
-
-        return loader.run(self.obj, 'done', self.exc, *args, **kwargs)
-
-    def next(self, *args, **kwargs):
-        '''Runs through the next data chunk'''
-
-        if not self.ready:
-            raise self.exc("database view not yet setup")
-
-        return loader.run(self.obj, 'next', self.exc, *args, **kwargs)
+    def index(self, filename):
+        '''Index the content of the view'''
+
+        parameters = self.definition.get('parameters', {})
+
+        objs = loader.run(self.obj, 'index', self.exc, self.root_folder, parameters)
+
+        if not isinstance(objs, list):
+            raise self.exc("index() didn't return a list")
+
+        if not os.path.exists(os.path.dirname(filename)):
+            os.makedirs(os.path.dirname(filename))
+
+        with open(filename, 'wb') as f:
+            simplejson.dump(objs, f)
+
+    def setup(self, filename, start_index=None, end_index=None, pack=True):
+        '''Sets up the view'''
+
+        if self.ready:
+            return
+
+        with open(filename, 'rb') as f:
+            objs = simplejson.load(f)
+
+        Entry = namedtuple('Entry', sorted(objs[0].keys()))
+        objs = [ Entry(**x) for x in objs ]
+
+        parameters = self.definition.get('parameters', {})
+
+        loader.run(self.obj, 'setup', self.exc, self.root_folder, parameters,
+                   objs, start_index=start_index, end_index=end_index)
+
+        # Create data sources for the outputs
+        from .data import DatabaseOutputDataSource
+        from .dataformat import DataFormat
+
+        self.data_sources = {}
+        for output_name, output_format in self.definition.get('outputs', {}).items():
+            data_source = DatabaseOutputDataSource()
+            data_source.setup(self, output_name, output_format, self.prefix,
+                              start_index=start_index, end_index=end_index, pack=pack)
+            self.data_sources[output_name] = data_source
+
+        self.ready = True
+
+    def get(self, output, index):
+        '''Returns the data of the provided output at the provided index'''
+
+        if not self.ready:
+            raise self.exc("Database view not yet setup")
+
+        return loader.run(self.obj, 'get', self.exc, output, index)
    def __getattr__(self, key):

@@ -204,6 +182,10 @@ class View(object):
        return getattr(self.obj, key)

+    def objects(self):
+        return self.obj.objs
#----------------------------------------------------------
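For orientation, a minimal sketch (not part of the diff) of how a Runner returned by Database.view() (see the following hunk) might be driven, following the index/setup/get sequence introduced above. The prefix, cache path, database name 'simple/1', protocol and set names, and the output name 'image' are placeholder assumptions, not values from this commit.

# Illustrative sketch only: every path and name below is a placeholder.
from beat.backend.python.database import Database

db = Database('/path/to/prefix', 'simple/1')   # database declaration loaded from the prefix
assert db.valid, db.errors

runner = db.view('protocol_name', 'set_name')  # returns a Runner wrapping the user view

index_file = '/path/to/cache/example_view.json'
runner.index(index_file)                       # run the view's index() and dump the list as JSON
runner.setup(index_file)                       # reload the index and create the data sources

value = runner.get('image', 0)                 # data of output 'image' at index 0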
@@ -368,7 +350,9 @@ class Database(object):
        if not self.valid:
            message = "cannot load view for set `%s' of protocol `%s' " \
                      "from invalid database (%s)" % (protocol, name, self.name)
-            if exc: raise exc(message)
+            if exc:
+                raise exc(message)
            raise RuntimeError(message)

        # loads the module only once through the lifetime of the database object

@@ -383,8 +367,73 @@ class Database(object):
            else:
                raise #just re-raise the user exception

-        return View(self._module, self.set(protocol, name), self.prefix,
-                    self.data['root_folder'], exc)
+        return Runner(self._module, self.set(protocol, name), self.prefix,
+                      self.data['root_folder'], exc)
+#----------------------------------------------------------
+
+
+class View(object):
+
+    def index(self, root_folder, parameters):
+        """Returns a list of (named) tuples describing the data provided by the view.
+
+        The ordering of values inside the tuples is free, but it is expected
+        that the list is ordered in a consistent manner (ie. all train images of
+        person A, then all train images of person B, ...).
+
+        For instance, assuming a view providing that kind of data:
+
+        ----------- ----------- ----------- ----------- ----------- -----------
+        |  image  | |  image  | |  image  | |  image  | |  image  | |  image  |
+        ----------- ----------- ----------- ----------- ----------- -----------
+        ----------- ----------- ----------- ----------- ----------- -----------
+        | file_id | | file_id | | file_id | | file_id | | file_id | | file_id |
+        ----------- ----------- ----------- ----------- ----------- -----------
+        ----------------------------------- -----------------------------------
+        |             client_id           | |             client_id           |
+        ----------------------------------- -----------------------------------
+
+        a list like the following should be generated:
+
+        [
+            (client_id=1, file_id=1, image=filename1),
+            (client_id=1, file_id=2, image=filename2),
+            (client_id=1, file_id=3, image=filename3),
+            (client_id=2, file_id=4, image=filename4),
+            (client_id=2, file_id=5, image=filename5),
+            (client_id=2, file_id=6, image=filename6),
+            ...
+        ]
+
+        DO NOT store images, sound files or data loadable from a file in the list!
+        Store the path of the file to load instead.
+        """
+
+        raise NotImplementedError
+
+    def setup(self, root_folder, parameters, objs, start_index=None, end_index=None):
+        # Initialisations
+        self.root_folder = root_folder
+        self.parameters = parameters
+        self.objs = objs
+
+        # Determine the range of indices that must be provided
+        self.start_index = start_index if start_index is not None else 0
+        self.end_index = end_index if end_index is not None else len(self.objs) - 1
+
+        self.objs = self.objs[self.start_index : self.end_index + 1]
+
+    def get(self, output, index):
+        """Returns the data of the provided output at the provided index in the list
+        of (named) tuples describing the data provided by the view (accessible at
+        self.objs)"""
+
+        raise NotImplementedError
#----------------------------------------------------------
...
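To make the contract of the new View base class concrete, here is a hypothetical subclass (a sketch, not part of the commit) for the client_id/file_id/image layout used in the index() docstring. Only index() and get() are overridden; setup() is inherited from the base class. The dictionary returned by get() is assumed to map data format fields to values.

import os
from collections import namedtuple

class ExampleView(View):
    """Hypothetical view illustrating the base class introduced above."""

    def index(self, root_folder, parameters):
        # One (named) tuple per sample; store file paths, never the file contents
        Entry = namedtuple('Entry', ['client_id', 'file_id', 'image'])
        return [
            Entry(client_id=1, file_id=1, image=os.path.join(root_folder, 'img_001.png')),
            Entry(client_id=1, file_id=2, image=os.path.join(root_folder, 'img_002.png')),
            Entry(client_id=2, file_id=3, image=os.path.join(root_folder, 'img_003.png')),
        ]

    def get(self, output, index):
        # self.objs was filled (and sliced to [start_index, end_index]) by setup()
        obj = self.objs[index]

        if output == 'client_id':
            return {'value': obj.client_id}
        elif output == 'file_id':
            return {'value': obj.file_id}
        else:  # 'image': only load the file when it is actually requested
            with open(obj.image, 'rb') as f:
                return {'value': f.read()}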
@@ -29,11 +29,6 @@
'''Execution utilities'''

import os
-import sys
-import glob
-import errno
-import tempfile
-import subprocess
import logging
logger = logging.getLogger(__name__)

@@ -41,11 +36,8 @@ logger = logging.getLogger(__name__)
import simplejson

# from . import schema
-from . import database
-from . import inputs
-from . import outputs
-from . import data
-from . import message_handler
+from .database import Database
+from .message_handler import MessageHandler
class DBExecutor(object):

@@ -102,38 +94,24 @@ class DBExecutor(object):
    """
-    def __init__(self, prefix, data, dataformat_cache=None, database_cache=None):
+    def __init__(self, address, prefix, cache_root, data, dataformat_cache=None,
+                 database_cache=None):

+        # Initialisations
        self.prefix = prefix
-
-        # some attributes
        self.databases = {}
        self.views = {}
-        self.input_list = None
-        self.data_sources = []
-        self.handler = None
        self.errors = []
        self.data = None
+        self.message_handler = None
+        self.data_sources = {}

-        # temporary caches, if the user has not set them, for performance
+        # Temporary caches, if the user has not set them, for performance
        database_cache = database_cache if database_cache is not None else {}
        self.dataformat_cache = dataformat_cache if dataformat_cache is not None else {}

-        self._load(data, database_cache)
-
-    def _load(self, data, database_cache):
-        """Loads the block execution information"""
-
-        # reset
-        self.data = None
-        self.errors = []
-        self.databases = {}
-        self.views = {}
-        self.input_list = None
-        self.data_sources = []
-
-        if not isinstance(data, dict): #user has passed a file name
+        # Load the data
+        if not isinstance(data, dict): # User has passed a file name
            if not os.path.exists(data):
                self.errors.append('File not found: %s' % data)
                return
@@ -147,135 +125,79 @@ class DBExecutor(object):
        # self.data, self.errors = schema.validate('execution', data)
        # if self.errors: return #don't proceed with the rest of validation

-        # load databases
-        for name, details in self.data['inputs'].items():
-            if 'database' in details:
-
-                if details['database'] not in self.databases:
-
-                    if details['database'] in database_cache: #reuse
-                        db = database_cache[details['database']]
-                    else: #load it
-                        db = database.Database(self.prefix, details['database'],
-                                               self.dataformat_cache)
-                        database_cache[db.name] = db
-
-                    self.databases[details['database']] = db
-
-                    if not db.valid:
-                        self.errors += db.errors
-                        continue
-
-                if not db.valid:
-                    # do not add errors again
-                    continue
-
-                # create and load the required views
-                key = (details['database'], details['protocol'], details['set'])
-                if key not in self.views:
-                    view = self.databases[details['database']].view(details['protocol'],
-                                                                    details['set'])
-
-                    if details['channel'] == self.data['channel']: #synchronized
-                        start_index, end_index = self.data.get('range', (None, None))
-                    else:
-                        start_index, end_index = (None, None)
-
-                    view.prepare_outputs()
-                    self.views[key] = (view, start_index, end_index)
-
-    def __enter__(self):
-        """Prepares inputs and outputs for the processing task
-
-        Raises:
-
-          IOError: in case something cannot be properly setup
-
-        """
-
-        self._prepare_inputs()
-
-        # The setup() of a database view may call isConnected() on an input
-        # to set the index at the right location when parallelization is enabled.
-        # This is why setup() should be called after initialized the inputs.
-        for key, (view, start_index, end_index) in self.views.items():
-            if (start_index is None) and (end_index is None):
-                status = view.setup()
-            else:
-                status = view.setup(force_start_index=start_index,
-                                    force_end_index=end_index)
-
-            if not status:
-                raise RuntimeError("Could not setup database view `%s'" % key)
-
-        return self
-
-    def __exit__(self, exc_type, exc_value, traceback):
-        """Closes all sinks and disconnects inputs and outputs
-        """
-
-        self.input_list = None
-        self.data_sources = []
-
-    def _prepare_inputs(self):
-        """Prepares all input required by the execution."""
-
-        self.input_list = inputs.InputList()
-
-        # This is used for parallelization purposes
-        start_index, end_index = self.data.get('range', (None, None))
-
-        for name, details in self.data['inputs'].items():
-
-            if 'database' in details: #it is a dataset input
-
-                view_key = (details['database'], details['protocol'], details['set'])
-                view = self.views[view_key][0]
-
-                data_source = data.MemoryDataSource(view.done, next_callback=view.next)
-                self.data_sources.append(data_source)
-                output = view.outputs[details['output']]
-
-                # if it's a synchronized channel, makes the output start at the right
-                # index, otherwise, it gets lost
-                if start_index is not None and \
-                   details['channel'] == self.data['channel']:
-                    output.last_written_data_index = start_index - 1
-
-                output.data_sink.data_sources.append(data_source)
-
-                # Synchronization bits
-                group = self.input_list.group(details['channel'])
-                if group is None:
-                    group = inputs.InputGroup(
-                        details['channel'],
-                        synchronization_listener=outputs.SynchronizationListener(),
-                        restricted_access=(details['channel'] == self.data['channel'])
-                    )
-                    self.input_list.add(group)
-
-                input_db = self.databases[details['database']]
-                input_dataformat_name = input_db.set(details['protocol'], details['set'])['outputs'][details['output']]
-                group.add(inputs.Input(name, self.dataformat_cache[input_dataformat_name], data_source))
-
-    def process(self, address):
-        self.handler = message_handler.MessageHandler(address, inputs=self.input_list)
-        self.handler.start()
+        # Load the databases
+        for name, details in self.data['inputs'].items():
+            if 'database' not in details:
+                continue
+
+            # Load the database
+            if details['database'] not in self.databases:
+
+                if details['database'] in database_cache: #reuse
+                    db = database_cache[details['database']]
+                else: #load it
+                    db = Database(self.prefix, details['database'],
+                                  self.dataformat_cache)
+                    database_cache[db.name] = db
+
+                self.databases[details['database']] = db
+
+                if not db.valid:
+                    self.errors += db.errors
+            else:
+                db = self.databases[details['database']]
+
+            if not db.valid:
+                continue
+
+            # Create and load the required views
+            key = (details['database'], details['protocol'], details['set'])
+            if key not in self.views:
+                view = db.view(details['protocol'], details['set'])
+
+                if details['channel'] == self.data['channel']: #synchronized
+                    start_index, end_index = self.data.get('range', (None, None))
+                else:
+                    start_index, end_index = (None, None)
+
+                view.setup(os.path.join(cache_root, details['path']),
+                           start_index=start_index, end_index=end_index)
+
+                self.views[key] = view
+
+        # Create the data sources
+        for name, details in self.data['inputs'].items():
+            if 'database' not in details:
+                continue
+
+            view_key = (details['database'], details['protocol'], details['set'])
+            view = self.views[view_key]
+
+            self.data_sources[name] = view.data_sources[details['output']]
+
+        # Create the message handler
+        self.message_handler = MessageHandler(address, data_sources=self.data_sources)
+
+    def process(self):
+        self.message_handler.start()
+
+    @property
+    def address(self):
+        return self.message_handler.address

    @property
    def valid(self):
        """A boolean that indicates if this executor is valid or not"""
        return not bool(self.errors)

    def wait(self):
-        self.handler.join()
-        self.handler.destroy()
-        self.handler = None
+        self.message_handler.join()
+        self.message_handler.destroy()
+        self.message_handler = None

    def __str__(self):
...
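Putting the executor changes together, a hypothetical driver (not part of the diff) could look like the sketch below: the refactored DBExecutor now receives the listening address and the cache root up front, exposes one data source per database input, and serves them through the MessageHandler started by process(). The module path, the address, and all file paths are assumptions for illustration.

# Illustrative sketch; module path, address and paths are placeholders.
from beat.backend.python.dbexecutor import DBExecutor

executor = DBExecutor('tcp://127.0.0.1:5555',   # address handed to the MessageHandler
                      '/path/to/prefix',        # toolchain prefix
                      '/path/to/cache',         # cache_root, where the view indices live
                      '/path/to/block.json')    # execution description (file name or dict)

if not executor.valid:
    raise RuntimeError('\n'.join(executor.errors))

executor.process()        # starts the MessageHandler
print(executor.address)   # address a remote client should connect to
# ... a client consumes the data sources over the wire ...
executor.wait()           # joins and destroys the MessageHandler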
@@ -3,7 +3,7 @@
###############################################################################
#                                                                             #
-# Copyright (c) 2016 Idiap Research Institute, http://www.idiap.ch/          #
+# Copyright (c) 2017 Idiap Research Institute, http://www.idiap.ch/          #
# Contact: beat.support@idiap.ch                                             #
#                                                                             #
# This file is part of the beat.backend.python module of the BEAT platform.  #

@@ -37,6 +37,9 @@ import six
import os

+#----------------------------------------------------------

def _sha256(s):
    """A python2/3 replacement for :py:func:`haslib.sha256`"""

@@ -47,6 +50,8 @@ def _sha256(s):
    return hashlib.sha256(s).hexdigest()

+#----------------------------------------------------------

def _stringify(dictionary):
    names = sorted(dictionary.keys())

@@ -63,12 +68,30 @@ def _stringify(dictionary):
    return converted_dictionary

+#----------------------------------------------------------
+
+def _compact(text):
+    return text.replace(' ', '').replace('\n', '')
+
+#----------------------------------------------------------
+
+def toPath(hash, suffix='.data'):
+    return os.path.join(hash[0:2], hash[2:4], hash[4:6], hash[6:] + suffix)
+
+#----------------------------------------------------------

def toUserPath(username):
    hash = _sha256(username)
    return os.path.join(hash[0:2], hash[2:4], username)

+#----------------------------------------------------------

def hash(dictionary_or_string):
    if isinstance(dictionary_or_string, dict):

@@ -77,6 +100,8 @@ def hash(dictionary_or_string):
    return _sha256(dictionary_or_string)

+#----------------------------------------------------------

def hashJSON(contents, description):
    """Hashes the pre-loaded JSON object using :py:func:`hashlib.sha256`

@@ -91,6 +116,8 @@ def hashJSON(contents, description):
    return hashlib.sha256(contents).hexdigest()
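For reference, a small sketch (assuming the helpers live in beat.backend.python.hash) of how the new toPath() spreads a digest produced by hash() over three directory levels, which keeps cache directories shallow:

# Hypothetical usage; the module path is an assumption.
from beat.backend.python.hash import hash, toPath

digest = hash('some-unique-key')        # 64-character SHA-256 hex digest
print(toPath(digest, suffix='.data'))   # e.g. 'ab/cd/ef/<remaining hex chars>.data'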