Commit 29e6ab5e authored by Philip ABBET

Refactoring: Databases, DBExecutor, MessageHandler, DataSources

parent dce310ab
......@@ -35,6 +35,7 @@ import six
import simplejson
import itertools
import numpy as np
from collections import namedtuple
from . import loader
from . import utils
......@@ -73,7 +74,7 @@ class Storage(utils.CodeStorage):
#----------------------------------------------------------
class View(object):
class Runner(object):
'''A special loader class for database views, with specialized methods
Parameters:
......@@ -98,9 +99,7 @@ class View(object):
'''
def __init__(self, module, definition, prefix, root_folder, exc=None,
*args, **kwargs):
def __init__(self, module, definition, prefix, root_folder, exc=None):
try:
class_ = getattr(module, definition['view'])
......@@ -109,94 +108,73 @@ class View(object):
type, value, traceback = sys.exc_info()
six.reraise(exc, exc(value), traceback)
else:
raise #just re-raise the user exception
raise # just re-raise the user exception
self.obj = loader.run(class_, '__new__', exc, *args, **kwargs)
self.ready = False
self.prefix = prefix
self.root_folder = root_folder
self.definition = definition
self.exc = exc or RuntimeError
self.outputs = None
self.obj = loader.run(class_, '__new__', exc)
self.ready = False
self.prefix = prefix
self.root_folder = root_folder
self.definition = definition
self.exc = exc or RuntimeError
self.data_sources = None
def prepare_outputs(self):
'''Prepares the outputs of the dataset'''
def index(self, filename):
'''Indexes the content of the view'''
from .outputs import Output, OutputList
from .data import MemoryDataSink
from .dataformat import DataFormat
parameters = self.definition.get('parameters', {})
# create the stock outputs for this dataset, so data is dumped
# on an in-memory sink
self.outputs = OutputList()
for out_name, out_format in self.definition.get('outputs', {}).items():
data_sink = MemoryDataSink()
data_sink.dataformat = DataFormat(self.prefix, out_format)
data_sink.setup([])
self.outputs.add(Output(out_name, data_sink, dataset_output=True))
objs = loader.run(self.obj, 'index', self.exc, self.root_folder, parameters)
if not isinstance(objs, list):
raise self.exc("index() didn't return a list")
def setup(self, *args, **kwargs):
'''Sets up the view'''
if not os.path.exists(os.path.dirname(filename)):
os.makedirs(os.path.dirname(filename))
kwargs.setdefault('root_folder', self.root_folder)
kwargs.setdefault('parameters', self.definition.get('parameters', {}))
with open(filename, 'wb') as f:
simplejson.dump(objs, f)
if 'outputs' not in kwargs:
kwargs['outputs'] = self.outputs
else:
self.outputs = kwargs['outputs'] #record outputs nevertheless
self.ready = loader.run(self.obj, 'setup', self.exc, *args, **kwargs)
def setup(self, filename, start_index=None, end_index=None, pack=True):
'''Sets up the view'''
if not self.ready:
raise self.exc("unknow setup failure")
if self.ready:
return
return self.ready
with open(filename, 'rb') as f:
objs = simplejson.load(f)
Entry = namedtuple('Entry', sorted(objs[0].keys()))
objs = [ Entry(**x) for x in objs ]
def input_group(self, name='default', exclude_outputs=[]):
'''A memory-source input group matching the outputs from the view'''
parameters = self.definition.get('parameters', {})
if not self.ready:
raise self.exc("database view not yet setup")
loader.run(self.obj, 'setup', self.exc, self.root_folder, parameters,
objs, start_index=start_index, end_index=end_index)
from .data import MemoryDataSource
from .outputs import SynchronizationListener
from .inputs import Input, InputGroup
# Setup the inputs
synchronization_listener = SynchronizationListener()
input_group = InputGroup(name,
synchronization_listener=synchronization_listener,
restricted_access=False)
# Create data sources for the outputs
from .data import DatabaseOutputDataSource
from .dataformat import DataFormat
for output in self.outputs:
if output.name in exclude_outputs: continue
data_source = MemoryDataSource(self.done, next_callback=self.next)
output.data_sink.data_sources.append(data_source)
input_group.add(Input(output.name,
output.data_sink.dataformat, data_source))
self.data_sources = {}
for output_name, output_format in self.definition.get('outputs', {}).items():
data_source = DatabaseOutputDataSource()
data_source.setup(self, output_name, output_format, self.prefix,
start_index=start_index, end_index=end_index, pack=pack)
self.data_sources[output_name] = data_source
return input_group
self.ready = True
def done(self, *args, **kwargs):
'''Checks if the view is done'''
def get(self, output, index):
'''Returns the data of the provided output at the provided index'''
if not self.ready:
raise self.exc("database view not yet setup")
return loader.run(self.obj, 'done', self.exc, *args, **kwargs)
raise self.exc("Database view not yet setup")
def next(self, *args, **kwargs):
'''Runs through the next data chunk'''
if not self.ready:
raise self.exc("database view not yet setup")
return loader.run(self.obj, 'next', self.exc, *args, **kwargs)
return loader.run(self.obj, 'get', self.exc, output, index)
def __getattr__(self, key):
......@@ -204,6 +182,10 @@ class View(object):
return getattr(self.obj, key)
def objects(self):
return self.obj.objs
#----------------------------------------------------------
......@@ -368,7 +350,9 @@ class Database(object):
if not self.valid:
message = "cannot load view for set `%s' of protocol `%s' " \
"from invalid database (%s)" % (protocol, name, self.name)
if exc: raise exc(message)
if exc:
raise exc(message)
raise RuntimeError(message)
# loads the module only once through the lifetime of the database object
......@@ -383,8 +367,73 @@ class Database(object):
else:
raise #just re-raise the user exception
return View(self._module, self.set(protocol, name), self.prefix,
self.data['root_folder'], exc)
return Runner(self._module, self.set(protocol, name), self.prefix,
self.data['root_folder'], exc)
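Taken together, Database.view() now hands back a Runner instead of the old in-memory View wrapper: index() asks the user code for the full list of entries and dumps it as JSON, setup() reads that file back into namedtuples and creates one DatabaseOutputDataSource per declared output, and get() simply delegates to the view. A rough sketch of driving it; the database name, prefix, cache location and output name are placeholders, and the module path is an assumption rather than something fixed by this commit:

import os
from beat.backend.python.database import Database   # module path assumed

prefix = '/path/to/prefix'                     # placeholder
cache_root = '/path/to/cache'                  # placeholder

db = Database(prefix, 'atnt/1', {})            # (prefix, database name, dataformat cache)
assert db.valid, db.errors

runner = db.view('Default', 'train')           # Database.view() now returns a Runner

index_file = os.path.join(cache_root, 'atnt_default_train.json')   # illustrative path
runner.index(index_file)                       # dumps the view's list of entries as JSON
runner.setup(index_file)                       # reads it back and builds one
                                               # DatabaseOutputDataSource per output

data_source = runner.data_sources['image']     # 'image' is a hypothetical output name
value = runner.get('image', 0)                 # data of that output at index 0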
#----------------------------------------------------------
class View(object):
def index(self, root_folder, parameters):
"""Returns a list of (named) tuples describing the data provided by the view.
The ordering of values inside the tuples is free, but it is expected
that the list is ordered in a consistent manner (i.e. all train images of
person A, then all train images of person B, ...).
For instance, assuming a view providing that kind of data:
----------- ----------- ----------- ----------- ----------- -----------
| image | | image | | image | | image | | image | | image |
----------- ----------- ----------- ----------- ----------- -----------
----------- ----------- ----------- ----------- ----------- -----------
| file_id | | file_id | | file_id | | file_id | | file_id | | file_id |
----------- ----------- ----------- ----------- ----------- -----------
----------------------------------- -----------------------------------
| client_id | | client_id |
----------------------------------- -----------------------------------
a list like the following should be generated:
[
(client_id=1, file_id=1, image=filename1),
(client_id=1, file_id=2, image=filename2),
(client_id=1, file_id=3, image=filename3),
(client_id=2, file_id=4, image=filename4),
(client_id=2, file_id=5, image=filename5),
(client_id=2, file_id=6, image=filename6),
...
]
DO NOT store images, sound files or data loadable from a file in the list!
Store the path of the file to load instead.
"""
raise NotImplementedError
def setup(self, root_folder, parameters, objs, start_index=None, end_index=None):
# Initialisations
self.root_folder = root_folder
self.parameters = parameters
self.objs = objs
# Determine the range of indices that must be provided
self.start_index = start_index if start_index is not None else 0
self.end_index = end_index if end_index is not None else len(self.objs) - 1
self.objs = self.objs[self.start_index : self.end_index + 1]
def get(self, output, index):
"""Returns the data of the provided output at the provided index in the list
of (named) tuples describing the data provided by the view (accessible at
self.objs)"""
raise NotImplementedError
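A concrete view therefore only has to implement index() and get(); the base setup() already slices self.objs to the requested range. Below is a minimal sketch of such a subclass; the view name, outputs, file layout and the dict-style return values are assumptions made for illustration, and the module path is not fixed by this commit:

import numpy as np
from collections import namedtuple
from beat.backend.python.database import View   # module path assumed

class Train(View):                               # hypothetical view
    def index(self, root_folder, parameters):
        # One (named) tuple per sample; store file paths, never the file contents
        Entry = namedtuple('Entry', ['client_id', 'file_id', 'image'])
        return [Entry(client_id=c, file_id=f,
                      image='%s/img_%d.png' % (root_folder, f))
                for f, c in enumerate([1, 1, 1, 2, 2, 2])]

    def get(self, output, index):
        # index is relative to the slice selected by setup()
        obj = self.objs[index]
        if output == 'client_id':
            return {'value': np.int32(obj.client_id)}   # dict matching the dataformat (assumed)
        elif output == 'file_id':
            return {'value': np.int32(obj.file_id)}
        elif output == 'image':
            # a real view would load the image stored at obj.image here
            return {'value': obj.image}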
#----------------------------------------------------------
......
......@@ -29,11 +29,6 @@
'''Execution utilities'''
import os
import sys
import glob
import errno
import tempfile
import subprocess
import logging
logger = logging.getLogger(__name__)
......@@ -41,11 +36,8 @@ logger = logging.getLogger(__name__)
import simplejson
# from . import schema
from . import database
from . import inputs
from . import outputs
from . import data
from . import message_handler
from .database import Database
from .message_handler import MessageHandler
class DBExecutor(object):
......@@ -102,38 +94,24 @@ class DBExecutor(object):
"""
def __init__(self, prefix, data, dataformat_cache=None, database_cache=None):
def __init__(self, address, prefix, cache_root, data, dataformat_cache=None,
database_cache=None):
# Initialisations
self.prefix = prefix
# some attributes
self.databases = {}
self.views = {}
self.input_list = None
self.data_sources = []
self.handler = None
self.errors = []
self.data = None
self.message_handler = None
self.data_sources = {}
# temporary caches, if the user has not set them, for performance
# Temporary caches, if the user has not set them, for performance
database_cache = database_cache if database_cache is not None else {}
self.dataformat_cache = dataformat_cache if dataformat_cache is not None else {}
self._load(data, database_cache)
def _load(self, data, database_cache):
"""Loads the block execution information"""
# reset
self.data = None
self.errors = []
self.databases = {}
self.views = {}
self.input_list = None
self.data_sources = []
if not isinstance(data, dict): #user has passed a file name
# Load the data
if not isinstance(data, dict): # User has passed a file name
if not os.path.exists(data):
self.errors.append('File not found: %s' % data)
return
......@@ -147,135 +125,79 @@ class DBExecutor(object):
# self.data, self.errors = schema.validate('execution', data)
# if self.errors: return #don't proceed with the rest of validation
# load databases
# Load the databases
for name, details in self.data['inputs'].items():
if 'database' in details:
if details['database'] not in self.databases:
if 'database' not in details:
continue
if details['database'] in database_cache: #reuse
db = database_cache[details['database']]
else: #load it
db = database.Database(self.prefix, details['database'],
self.dataformat_cache)
database_cache[db.name] = db
# Load the database
if details['database'] not in self.databases:
self.databases[details['database']] = db
if details['database'] in database_cache: #reuse
db = database_cache[details['database']]
else: #load it
db = Database(self.prefix, details['database'],
self.dataformat_cache)
database_cache[db.name] = db
if not db.valid:
self.errors += db.errors
continue
self.databases[details['database']] = db
if not db.valid:
# do not add errors again
continue
# create and load the required views
key = (details['database'], details['protocol'], details['set'])
if key not in self.views:
view = self.databases[details['database']].view(details['protocol'],
details['set'])
if details['channel'] == self.data['channel']: #synchronized
start_index, end_index = self.data.get('range', (None, None))
else:
start_index, end_index = (None, None)
view.prepare_outputs()
self.views[key] = (view, start_index, end_index)
def __enter__(self):
"""Prepares inputs and outputs for the processing task
Raises:
IOError: in case something cannot be properly setup
"""
self._prepare_inputs()
# The setup() of a database view may call isConnected() on an input
# to set the index at the right location when parallelization is enabled.
# This is why setup() should be called after the inputs have been initialized.
for key, (view, start_index, end_index) in self.views.items():
if (start_index is None) and (end_index is None):
status = view.setup()
self.errors += db.errors
else:
status = view.setup(force_start_index=start_index,
force_end_index=end_index)
db = self.databases[details['database']]
if not status:
raise RuntimeError("Could not setup database view `%s'" % key)
if not db.valid:
continue
return self
# Create and load the required views
key = (details['database'], details['protocol'], details['set'])
if key not in self.views:
view = db.view(details['protocol'], details['set'])
if details['channel'] == self.data['channel']: #synchronized
start_index, end_index = self.data.get('range', (None, None))
else:
start_index, end_index = (None, None)
def __exit__(self, exc_type, exc_value, traceback):
"""Closes all sinks and disconnects inputs and outputs
"""
self.input_list = None
self.data_sources = []
view.setup(os.path.join(cache_root, details['path']),
start_index=start_index, end_index=end_index)
self.views[key] = view
def _prepare_inputs(self):
"""Prepares all input required by the execution."""
self.input_list = inputs.InputList()
# This is used for parallelization purposes
start_index, end_index = self.data.get('range', (None, None))
# Create the data sources
for name, details in self.data['inputs'].items():
if 'database' not in details:
continue
if 'database' in details: #it is a dataset input
view_key = (details['database'], details['protocol'], details['set'])
view = self.views[view_key]
view_key = (details['database'], details['protocol'], details['set'])
view = self.views[view_key][0]
self.data_sources[name] = view.data_sources[details['output']]
data_source = data.MemoryDataSource(view.done, next_callback=view.next)
self.data_sources.append(data_source)
output = view.outputs[details['output']]
# Create the message handler
self.message_handler = MessageHandler(address, data_sources=self.data_sources)
# if it's a synchronized channel, make the output start at the right
# index; otherwise, it gets lost
if start_index is not None and \
details['channel'] == self.data['channel']:
output.last_written_data_index = start_index - 1
output.data_sink.data_sources.append(data_source)
# Synchronization bits
group = self.input_list.group(details['channel'])
if group is None:
group = inputs.InputGroup(
details['channel'],
synchronization_listener=outputs.SynchronizationListener(),
restricted_access=(details['channel'] == self.data['channel'])
)
self.input_list.add(group)
def process(self):
self.message_handler.start()
input_db = self.databases[details['database']]
input_dataformat_name = input_db.set(details['protocol'], details['set'])['outputs'][details['output']]
group.add(inputs.Input(name, self.dataformat_cache[input_dataformat_name], data_source))
def process(self, address):
self.handler = message_handler.MessageHandler(address, inputs=self.input_list)
self.handler.start()
@property
def address(self):
return self.message_handler.address
@property
def valid(self):
"""A boolean that indicates if this executor is valid or not"""
return not bool(self.errors)
def wait(self):
self.handler.join()
self.handler.destroy()
self.handler = None
self.message_handler.join()
self.message_handler.destroy()
self.message_handler = None
def __str__(self):
......
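Seen from the outside, the refactored DBExecutor is now driven entirely over ZMQ: the constructor loads the databases, sets up each Runner against its JSON index under cache_root and hands the resulting data sources to a MessageHandler bound to the given address; process() starts that handler and wait() tears it down. A rough usage sketch, where the address, prefix, cache root and configuration path are placeholders and the module path is assumed:

from beat.backend.python.execution import DBExecutor   # module path assumed

executor = DBExecutor('tcp://127.0.0.1:5555',           # ZMQ address to serve on
                      '/path/to/prefix',                 # placeholder prefix
                      '/path/to/cache',                  # placeholder cache_root
                      '/path/to/configuration.json')     # or an already-loaded dict
assert executor.valid, executor.errors

executor.process()            # starts the MessageHandler in the background
print(executor.address)       # where the algorithm side should connect
# ... the algorithm-side process requests data over ZMQ ...
executor.wait()               # joins and destroys the message handler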
......@@ -3,7 +3,7 @@
###############################################################################
# #
# Copyright (c) 2016 Idiap Research Institute, http://www.idiap.ch/ #
# Copyright (c) 2017 Idiap Research Institute, http://www.idiap.ch/ #
# Contact: beat.support@idiap.ch #
# #
# This file is part of the beat.backend.python module of the BEAT platform. #
......@@ -37,6 +37,9 @@ import six
import os
#----------------------------------------------------------
def _sha256(s):
"""A python2/3 replacement for :py:func:`haslib.sha256`"""
......@@ -47,6 +50,8 @@ def _sha256(s):
return hashlib.sha256(s).hexdigest()
#----------------------------------------------------------
def _stringify(dictionary):
names = sorted(dictionary.keys())
......@@ -63,12 +68,30 @@ def _stringify(dictionary):
return converted_dictionary
#----------------------------------------------------------
def _compact(text):
return text.replace(' ', '').replace('\n', '')
#----------------------------------------------------------
def toPath(hash, suffix='.data'):
return os.path.join(hash[0:2], hash[2:4], hash[4:6], hash[6:] + suffix)
#----------------------------------------------------------
def toUserPath(username):
hash = _sha256(username)
return os.path.join(hash[0:2], hash[2:4], username)
#----------------------------------------------------------
def hash(dictionary_or_string):
if isinstance(dictionary_or_string, dict):
......@@ -77,6 +100,8 @@ def hash(dictionary_or_string):
return _sha256(dictionary_or_string)
#----------------------------------------------------------
def hashJSON(contents, description):
"""Hashes the pre-loaded JSON object using :py:func:`hashlib.sha256`
......@@ -91,6 +116,8 @@ def hashJSON(contents, description):
return hashlib.sha256(contents).hexdigest()
#----------------------------------------------------------
def hashJSONFile(path, description):
"""Hashes the JSON file contents using :py:func:`hashlib.sha256`
......@@ -107,6 +134,8 @@ def hashJSONFile(path, description):
return hashFileContents(path)
#----------------------------------------------------------
def hashFileContents(path):
"""Hashes the file contents using :py:func:`hashlib.sha256`."""
......@@ -117,3 +146,15 @@ def hashFileContents(path):
sha256.update(chunk)
return sha256.hexdigest()
#----------------------------------------------------------
def hashDataset(database_name, protocol_name, set_name):
s = _compact("""{
"database": "%s",
"protocol": "%s",
"set": "%s"
}""") % (database_name, protocol_name, set_name)
return hash(s)
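hashDataset() gives every (database, protocol, set) triplet a deterministic identifier, and toPath() turns such a digest into a nested cache-relative location, the first three byte pairs becoming directory levels. A small illustration; the names, the '.db' suffix and the import path are placeholders:

from beat.backend.python.hash import hashDataset, toPath   # module path assumed

h = hashDataset('atnt/1', 'Default', 'train')   # 64-character sha256 hex digest
print(toPath(h, suffix='.db'))                  # e.g. 'ab/cd/ef/<remaining 58 hex chars>.db'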
......@@ -32,9 +32,9 @@ import errno
import logging
logger = logging.getLogger(__name__)
from .data import MemoryDataSource
from .data import MemoryLegacyDataSource
from .data import CachedLegacyDataSource
from .data import CachedDataSource
from .data import CachedFileLoader
from .data import CachedDataSink
from .data import getAllFilenames
from .data_loaders import DataLoaderList
......@@ -103,7 +103,7 @@ def create_inputs_from_configuration(config, algorithm, prefix, cache_root,
def _create_local_input(details):
data_source = CachedDataSource()
data_source = CachedLegacyDataSource()
data_sources.append(data_source)
filename = os.path.join(cache_root, details['path'] + '.data')
......@@ -144,7 +144,7 @@ def create_inputs_from_configuration(config, algorithm, prefix, cache_root,
logger.debug("Data loader created: group='%s'" % details['channel'])
cached_file = CachedFileLoader()
cached_file = CachedDataSource()
result = cached_file.setup(
filename=filename,
prefix=prefix,
......@@ -193,7 +193,7 @@ def create_inputs_from_configuration(config, algorithm, prefix, cache_root,
view = views[channel]
# Creation of the input
data_source = MemoryDataSource(view.done, next_callback=view.next)
data_source = MemoryLegacyDataSource(view.done, next_callback=view.next)
output = view.outputs[details['output']]
output.data_sink.data_sources.append(data_source)
......
......@@ -36,6 +36,7 @@ import six
import zmq
from .data import mixDataIndices
from .data import RemoteException
#----------------------------------------------------------
......@@ -136,13 +137,13 @@ class Input(BaseInput):
data_format (str): Data format accepted by the input
data_source (beat.core.platform.data.DataSource): Source of data to be used
data_source (beat.core.platform.data.LegacyDataSource): Source of data to be used
by the input
Attributes:
data_source (beat.core.data.DataSource): Source of data used by the output
data_source (beat.core.data.LegacyDataSource): Source of data used by the output
"""
......@@ -180,22 +181,6 @@ class Input(BaseInput):
#----------------------------------------------------------
class RemoteException(Exception):
def __init__(self, kind, message):
super(RemoteException, self).__init__()
if kind == 'sys':
self.system_error = message
self.user_error = ''
else:
self.system_error = ''
self.user_error = message
#----------------------------------------------------------
def process_error(socket):
kind = socket.recv()
message = socket.recv()
......