Commit 8577d403 authored by Philip ABBET's avatar Philip ABBET
Browse files

Refactoring: reassign some classes from beat.core

parent 14bb7f2d
This diff is collapsed.
#!/usr/bin/env python
# vim: set fileencoding=utf-8 :
###############################################################################
# #
# Copyright (c) 2016 Idiap Research Institute, http://www.idiap.ch/ #
# Contact: beat.support@idiap.ch #
# #
# This file is part of the beat.backend.python module of the BEAT platform. #
# #
# Commercial License Usage #
# Licensees holding valid commercial BEAT licenses may use this file in #
# accordance with the terms contained in a written agreement between you #
# and Idiap. For further information contact tto@idiap.ch #
# #
# Alternatively, this file may be used under the terms of the GNU Affero #
# Public License version 3 as published by the Free Software and appearing #
# in the file LICENSE.AGPL included in the packaging of this file. #
# The BEAT platform is distributed in the hope that it will be useful, but #
# WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY #
# or FITNESS FOR A PARTICULAR PURPOSE. #
# #
# You should have received a copy of the GNU Affero Public License along #
# with the BEAT platform. If not, see http://www.gnu.org/licenses/. #
# #
###############################################################################
"""Validation of databases"""
import os
import sys
# import collections
import six
import simplejson
from . import loader
# from . import dataformat
# from . import hash
# from . import utils
# from . import prototypes
class View(object):
'''A special loader class for database views, with specialized methods
Parameters:
db_name (str): The full name of the database object for this view
module (module): The preloaded module containing the database views as
returned by :py:func:`beat.core.loader.load_module`.
prefix (str, path): The prefix path for the current installation
root_folder (str, path): The path pointing to the root folder of this
database
exc (class): The class to use as base exception when translating the
exception from the user code. Read the documention of :py:func:`run`
for more details.
*args: Constructor parameters for the database view. Normally, none.
**kwargs: Constructor parameters for the database view. Normally, none.
'''
def __init__(self, module, definition, prefix, root_folder, exc=None,
*args, **kwargs):
try:
class_ = getattr(module, definition['view'])
except Exception as e:
if exc is not None:
type, value, traceback = sys.exc_info()
six.reraise(exc, exc(value), traceback)
else:
raise #just re-raise the user exception
self.obj = loader.run(class_, '__new__', exc, *args, **kwargs)
self.ready = False
self.prefix = prefix
self.root_folder = root_folder
self.definition = definition
self.exc = exc or RuntimeError
self.outputs = None
def prepare_outputs(self):
'''Prepares the outputs of the dataset'''
from .outputs import Output, OutputList
from .data import MemoryDataSink
from .dataformat import DataFormat
# create the stock outputs for this dataset, so data is dumped
# on a in-memory sink
self.outputs = OutputList()
for out_name, out_format in self.definition.get('outputs', {}).items():
data_sink = MemoryDataSink()
data_sink.dataformat = DataFormat(self.prefix, out_format)
data_sink.setup([])
self.outputs.add(Output(out_name, data_sink, dataset_output=True))
def setup(self, *args, **kwargs):
'''Sets up the view'''
kwargs.setdefault('root_folder', self.root_folder)
kwargs.setdefault('parameters', self.definition.get('parameters', {}))
if 'outputs' not in kwargs:
kwargs['outputs'] = self.outputs
else:
self.outputs = kwargs['outputs'] #record outputs nevertheless
self.ready = loader.run(self.obj, 'setup', self.exc, *args, **kwargs)
if not self.ready:
raise self.exc("unknow setup failure")
return self.ready
def input_group(self, name='default', exclude_outputs=[]):
'''A memory-source input group matching the outputs from the view'''
if not self.ready:
raise self.exc("database view not yet setup")
from .data import MemoryDataSource
from .outputs import SynchronizationListener
from .inputs import Input, InputGroup
# Setup the inputs
synchronization_listener = SynchronizationListener()
input_group = InputGroup(name,
synchronization_listener=synchronization_listener,
restricted_access=False)
for output in self.outputs:
if output.name in exclude_outputs: continue
data_source = MemoryDataSource(self.done, next_callback=self.next)
output.data_sink.data_sources.append(data_source)
input_group.add(Input(output.name,
output.data_sink.dataformat, data_source))
return input_group
def done(self, *args, **kwargs):
'''Checks if the view is done'''
if not self.ready:
raise self.exc("database view not yet setup")
return loader.run(self.obj, 'done', self.exc, *args, **kwargs)
def next(self, *args, **kwargs):
'''Runs through the next data chunk'''
if not self.ready:
raise self.exc("database view not yet setup")
return loader.run(self.obj, 'next', self.exc, *args, **kwargs)
def __getattr__(self, key):
'''Returns an attribute of the view - only called at last resort'''
return getattr(self.obj, key)
class Database(object):
"""Databases define the start point of the dataflow in an experiment.
Parameters:
prefix (str): Establishes the prefix of your installation.
name (str): The fully qualified database name (e.g. ``db/1``)
dataformat_cache (dict, optional): A dictionary mapping dataformat names
to loaded dataformats. This parameter is optional and, if passed, may
greatly speed-up database loading times as dataformats that are already
loaded may be re-used. If you use this parameter, you must guarantee
that the cache is refreshed as appropriate in case the underlying
dataformats change.
Attributes:
name (str): The full, valid name of this database
data (dict): The original data for this database, as loaded by our JSON
decoder.
"""
def __init__(self, prefix, name, dataformat_cache=None):
self._name = None
self.prefix = prefix
self.dataformats = {} # preloaded dataformats
self.data = None
# if the user has not provided a cache, still use one for performance
dataformat_cache = dataformat_cache if dataformat_cache is not None else {}
self._load(name, dataformat_cache)
def _load(self, data, dataformat_cache):
"""Loads the database"""
self._name = data
json_path = os.path.join(prefix, 'databases', name + '.json')
with open(json_path, 'rb') as f: self.data = simplejson.load(f)
@property
def name(self):
"""Returns the name of this object
"""
return self._name or '__unnamed_database__'
@property
def schema_version(self):
"""Returns the schema version"""
return self.data.get('schema_version', 1)
@property
def protocols(self):
"""The declaration of all the protocols of the database"""
data = self.data['protocols']
return dict(zip([k['name'] for k in data], data))
def protocol(self, name):
"""The declaration of a specific protocol in the database"""
return self.protocols[name]
@property
def protocol_names(self):
"""Names of protocols declared for this database"""
data = self.data['protocols']
return [k['name'] for k in data]
def sets(self, protocol):
"""The declaration of a specific set in the database protocol"""
data = self.protocol(protocol)['sets']
return dict(zip([k['name'] for k in data], data))
def set(self, protocol, name):
"""The declaration of all the protocols of the database"""
return self.sets(protocol)[name]
def set_names(self, protocol):
"""The names of sets in a given protocol for this database"""
data = self.protocol(protocol)['sets']
return [k['name'] for k in data]
def view(self, protocol, name, exc=None):
"""Returns the database view, given the protocol and the set name
Parameters:
protocol (str): The name of the protocol where to retrieve the view from
name (str): The name of the set in the protocol where to retrieve the
view from
exc (class): If passed, must be a valid exception class that will be
used to report errors in the read-out of this database's view.
Returns:
The database view, which will be constructed, but not setup. You
**must** set it up before using methods ``done`` or ``next``.
"""
if not self._name:
exc = exc or RuntimeError
raise exc("database has no name")
if not self.valid:
message = "cannot load view for set `%s' of protocol `%s' " \
"from invalid database (%s)" % (protocol, name, self.name)
if exc: raise exc(message)
raise RuntimeError(message)
# loads the module only once through the lifetime of the database object
try:
if not hasattr(self, '_module'):
self._module = loader.load_module(self.name.replace(os.sep, '_'),
self.storage.code.path, {})
except Exception as e:
if exc is not None:
type, value, traceback = sys.exc_info()
six.reraise(exc, exc(value), traceback)
else:
raise #just re-raise the user exception
return View(self._module, self.set(protocol, name), self.prefix,
self.data['root_folder'], exc)
#!/usr/bin/env python
# vim: set fileencoding=utf-8 :
###############################################################################
# #
# Copyright (c) 2016 Idiap Research Institute, http://www.idiap.ch/ #
# Contact: beat.support@idiap.ch #
# #
# This file is part of the beat.backend.python module of the BEAT platform. #
# #
# Commercial License Usage #
# Licensees holding valid commercial BEAT licenses may use this file in #
# accordance with the terms contained in a written agreement between you #
# and Idiap. For further information contact tto@idiap.ch #
# #
# Alternatively, this file may be used under the terms of the GNU Affero #
# Public License version 3 as published by the Free Software and appearing #
# in the file LICENSE.AGPL included in the packaging of this file. #
# The BEAT platform is distributed in the hope that it will be useful, but #
# WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY #
# or FITNESS FOR A PARTICULAR PURPOSE. #
# #
# You should have received a copy of the GNU Affero Public License along #
# with the BEAT platform. If not, see http://www.gnu.org/licenses/. #
# #
###############################################################################
'''Execution utilities'''
import os
import sys
import glob
import errno
import tempfile
import subprocess
import logging
logger = logging.getLogger(__name__)
import simplejson
# from . import schema
from . import database
from . import inputs
from . import outputs
from . import data
from . import message_handler
class DBExecutor(object):
"""Executor specialised in database views
Parameters:
prefix (str): Establishes the prefix of your installation.
data (dict, str): The piece of data representing the block to be executed.
It must validate against the schema defined for execution blocks. If a
string is passed, it is supposed to be a fully qualified absolute path to
a JSON file containing the block execution information.
dataformat_cache (dict, optional): A dictionary mapping dataformat names to
loaded dataformats. This parameter is optional and, if passed, may
greatly speed-up database loading times as dataformats that are already
loaded may be re-used. If you use this parameter, you must guarantee that
the cache is refreshed as appropriate in case the underlying dataformats
change.
database_cache (dict, optional): A dictionary mapping database names to
loaded databases. This parameter is optional and, if passed, may
greatly speed-up database loading times as databases that are already
loaded may be re-used. If you use this parameter, you must guarantee that
the cache is refreshed as appropriate in case the underlying databases
change.
Attributes:
errors (list): A list containing errors found while loading this execution
block.
data (dict): The original data for this executor, as loaded by our JSON
decoder.
databases (dict): A dictionary in which keys are strings with database
names and values are :py:class:`database.Database`, representing the
databases required for running this block. The dictionary may be empty
in case all inputs are taken from the file cache.
views (dict): A dictionary in which the keys are tuples pointing to the
``(<database-name>, <protocol>, <set>)`` and the value is a setup view
for that particular combination of details. The dictionary may be empty
in case all inputs are taken from the file cache.
input_list (beat.core.inputs.InputList): A list of inputs that will be
served to the algorithm.
data_sources (list): A list with all data-sources created by our execution
loader.
"""
def __init__(self, prefix, data, dataformat_cache=None, database_cache=None):
self.prefix = prefix
# some attributes
self.databases = {}
self.views = {}
self.input_list = None
self.data_sources = []
self.handler = None
self.errors = []
self.data = None
# temporary caches, if the user has not set them, for performance
database_cache = database_cache if database_cache is not None else {}
self.dataformat_cache = dataformat_cache if dataformat_cache is not None else {}
self._load(data, self.dataformat_cache, database_cache)
def _load(self, data, dataformat_cache, database_cache):
"""Loads the block execution information"""
# reset
self.data = None
self.errors = []
self.databases = {}
self.views = {}
self.input_list = None
self.data_sources = []
if not isinstance(data, dict): #user has passed a file name
if not os.path.exists(data):
self.errors.append('File not found: %s' % data)
return
with open(data) as f:
self.data = simplejson.load(f)
else:
self.data = data
# this runs basic validation, including JSON loading if required
# self.data, self.errors = schema.validate('execution', data)
# if self.errors: return #don't proceed with the rest of validation
# load databases
for name, details in self.data['inputs'].items():
if 'database' in details:
if details['database'] not in self.databases:
if details['database'] in database_cache: #reuse
db = database_cache[details['database']]
else: #load it
db = database.Database(self.prefix, details['database'],
dataformat_cache)
database_cache[db.name] = db
self.databases[details['database']] = db
if not db.valid:
self.errors += db.errors
continue
if not db.valid:
# do not add errors again
continue
# create and load the required views
key = (details['database'], details['protocol'], details['set'])
if key not in self.views:
view = self.databases[details['database']].view(details['protocol'],
details['set'])
if details['channel'] == self.data['channel']: #synchronized
start_index, end_index = self.data.get('range', (None, None))
else:
start_index, end_index = (None, None)
view.prepare_outputs()
self.views[key] = (view, start_index, end_index)
def __enter__(self):
"""Prepares inputs and outputs for the processing task
Raises:
IOError: in case something cannot be properly setup
"""
self._prepare_inputs()
# The setup() of a database view may call isConnected() on an input
# to set the index at the right location when parallelization is enabled.
# This is why setup() should be called after initialized the inputs.
for key, (view, start_index, end_index) in self.views.items():
if (start_index is None) and (end_index is None):
status = view.setup()
else:
status = view.setup(force_start_index=start_index,
force_end_index=end_index)
if not status:
raise RuntimeError("Could not setup database view `%s'" % key)
return self
def __exit__(self, exc_type, exc_value, traceback):
"""Closes all sinks and disconnects inputs and outputs
"""
self.input_list = None
self.data_sources = []
def _prepare_inputs(self):
"""Prepares all input required by the execution."""
self.input_list = inputs.InputList()
# This is used for parallelization purposes
start_index, end_index = self.data.get('range', (None, None))
for name, details in self.data['inputs'].items():
if 'database' in details: #it is a dataset input
view_key = (details['database'], details['protocol'], details['set'])
view = self.views[view_key][0]
data_source = data.MemoryDataSource(view.done, next_callback=view.next)
self.data_sources.append(data_source)
output = view.outputs[details['output']]
# if it's a synchronized channel, makes the output start at the right
# index, otherwise, it gets lost
if start_index is not None and \
details['channel'] == self.data['channel']:
output.last_written_data_index = start_index - 1
output.data_sink.data_sources.append(data_source)
# Synchronization bits
group = self.input_list.group(details['channel'])
if group is None:
group = inputs.InputGroup(
details['channel'],
synchronization_listener=outputs.SynchronizationListener(),
restricted_access=(details['channel'] == self.data['channel'])
)
self.input_list.add(group)
input_db = self.databases[details['database']]
input_dataformat_name = input_db.set(details['protocol'], details['set'])['outputs'][details['output']]
group.add(inputs.Input(name, self.dataformat_cache[input_dataformat_name], data_source))
def process(self, zmq_context, zmq_socket):
self.handler = message_handler.MessageHandler(self.input_list, zmq_context, zmq_socket)
self.handler.start()
@property
def valid(self):
"""A boolean that indicates if this executor is valid or not"""
return not bool(self.errors)
def wait(self):
self.handler.join()
self.handler = None
def __str__(self):
return simplejson.dumps(self.data, indent=4)
#!/usr/bin/env python
# vim: set fileencoding=utf-8 :
###############################################################################
# #
# Copyright (c) 2016 Idiap Research Institute, http://www.idiap.ch/ #