Commit 29e6ab5e authored by Philip ABBET's avatar Philip ABBET

Refactoring: Databases, DBExecutor, MessageHandler, DataSources

parent dce310ab
......@@ -36,6 +36,7 @@ import select
import time
import tempfile
import abc
import zmq
from functools import reduce
from collections import namedtuple
......@@ -51,6 +52,22 @@ from .algorithm import Algorithm
#----------------------------------------------------------
class RemoteException(Exception):
    """Exception carrying an error reported by a remote peer.

    Depending on ``kind``, the message is stored either as a system error
    (``kind == 'sys'``) or as a user error (any other value); the unused
    slot is left as an empty string.
    """

    def __init__(self, kind, message):
        super(RemoteException, self).__init__()

        is_system = (kind == 'sys')
        self.system_error = message if is_system else ''
        self.user_error = '' if is_system else message
#----------------------------------------------------------
def mixDataIndices(list_of_data_indices):
"""Given a collection of lists of data indices (belonging to separate
but synchronized files/inputs), returns the most granular list of
......@@ -171,20 +188,78 @@ def getAllFilenames(filename, start_index=None, end_index=None):
#----------------------------------------------------------
class CachedFileLoader(object):
class DataSource(object):
    """Base class to load data from some source.

    Subclasses fill ``self.infos`` with named tuples exposing at least
    ``start_index`` and ``end_index`` attributes (one entry per block of
    data) and implement ``__getitem__`` to actually retrieve a block as a
    ``(data, start_index, end_index)`` tuple.
    """

    def __init__(self):
        self.infos = []          # one entry per available block of data
        self.read_duration = 0   # cumulated time (seconds) spent reading
        self.nb_bytes_read = 0   # cumulated number of bytes read

    def close(self):
        """Releases the resources held by the data source"""
        self.infos = []

    def __del__(self):
        """Makes sure all resources are released when the object is deleted"""
        self.close()

    def __len__(self):
        return len(self.infos)

    def __iter__(self):
        for i in range(0, len(self.infos)):
            yield self[i]

    def __getitem__(self, index):
        # Fixed: the original raised `NotImplemented()`, which is not an
        # exception type (calling it raises TypeError); the intended
        # exception for an abstract method is NotImplementedError.
        raise NotImplementedError()

    def first_data_index(self):
        """Index of the first data unit provided by this source"""
        return self.infos[0].start_index

    def last_data_index(self):
        """Index of the last data unit provided by this source"""
        return self.infos[-1].end_index

    def data_indices(self):
        """List of ``(start_index, end_index)`` tuples, one per block"""
        return [ (x.start_index, x.end_index) for x in self.infos ]

    def getAtDataIndex(self, data_index):
        """Returns the block covering ``data_index``, or ``(None, None, None)``
        if no block covers it"""
        for index, infos in enumerate(self.infos):
            if (infos.start_index <= data_index) and (data_index <= infos.end_index):
                return self[index]

        return (None, None, None)

    def statistics(self):
        """Return the statistics about the number of bytes read"""
        return (self.nb_bytes_read, self.read_duration)
#----------------------------------------------------------
class CachedDataSource(DataSource):
"""Utility class to load data from a file in the cache"""
def __init__(self):
    """Initializes the source; the actual cache file is attached via setup()"""
    super(CachedDataSource, self).__init__()

    self.filenames = None
    self.encoding = None  # Must be 'binary' or 'json'
    self.prefix = None
    self.dataformat = None
    self.current_file = None
    self.current_file_index = None
    self.unpack = True
    # Fixed: removed the redundant re-initialization of `infos`,
    # `read_duration` and `nb_bytes_read` — they are already set (to the
    # same values) by DataSource.__init__ via the super() call above,
    # consistently with the other DataSource subclasses.
def _readHeader(self, file):
......@@ -232,10 +307,10 @@ class CachedFileLoader(object):
prefix (str, path): Path to the prefix where the dataformats are stored.
force_start_index (int): The starting index (if not set or set to
start_index (int): The starting index (if not set or set to
``None``, the default, read data from the begin of file)
force_end_index (int): The end index (if not set or set to ``None``, the
end_index (int): The end index (if not set or set to ``None``, the
default, reads the data until the end)
unpack (bool): Indicates if the data must be unpacked or not
......@@ -345,19 +420,7 @@ class CachedFileLoader(object):
if self.current_file is not None:
self.current_file.close()
def __del__(self):
"""Makes sure the files are close when the object is deleted"""
self.close()
def __len__(self):
return len(self.infos)
def __iter__(self):
for i in range(0, len(self.infos)):
yield self[i]
super(CachedDataSource, self).close()
def __getitem__(self, index):
......@@ -403,35 +466,251 @@ class CachedFileLoader(object):
return (data, infos.start_index, infos.end_index)
def first_data_index(self):
return self.infos[0].start_index
#----------------------------------------------------------
def last_data_index(self):
return self.infos[-1].end_index
class DatabaseOutputDataSource(DataSource):
    """Utility class to load data from an output of a database view"""

    def __init__(self):
        super(DatabaseOutputDataSource, self).__init__()

        self.prefix = None
        self.dataformat = None
        self.view = None
        self.output_name = None
        self.pack = True

    def setup(self, view, output_name, dataformat_name, prefix, start_index=None,
              end_index=None, pack=False):
        """Configures the data source

        Parameters:

          view: The database view providing the data.

          output_name (str): Name of the view output to read.

          dataformat_name (str): Name of the data format of the output.

          prefix (str, path): Path to the prefix where the dataformats are stored.

          start_index (int): The starting index (if not set or set to
            ``None``, the default, read data from the begin of file)

          end_index (int): The end index (if not set or set to ``None``, the
            default, reads the data until the end)

          pack (bool): Indicates if the data must be packed or not

        Returns:

          ``True``, if successful, or ``False`` otherwise.
        """
        self.prefix = prefix
        self.view = view
        self.output_name = output_name
        self.pack = pack

        self.dataformat = DataFormat(self.prefix, dataformat_name)

        if not self.dataformat.valid:
            raise RuntimeError("the dataformat `%s' is not valid" % dataformat_name)

        # Group the view objects into blocks of consecutive identical values
        # for the requested output, keeping only the blocks that fall inside
        # the requested [start_index, end_index] range
        Infos = namedtuple('Infos', ['start_index', 'end_index'])

        objects = self.view.objects()

        start = None
        end = None
        previous_value = None

        for index, obj in enumerate(objects):
            if start is None:
                start = index
                previous_value = getattr(obj, output_name)
            elif getattr(obj, output_name) != previous_value:
                end = index - 1
                previous_value = None

                if ((start_index is None) or (start >= start_index)) and \
                   ((end_index is None) or (end <= end_index)):
                    self.infos.append(Infos(start_index=start, end_index=end))

                start = index
                previous_value = getattr(obj, output_name)

        # Fixed: guard against an empty view — the original unconditionally
        # read `index` after the loop, raising NameError when `objects`
        # yielded nothing
        if start is not None:
            end = index

            if ((start_index is None) or (start >= start_index)) and \
               ((end_index is None) or (end <= end_index)):
                self.infos.append(Infos(start_index=start, end_index=end))

        return True

    def __getitem__(self, index):
        """Retrieve a block of data

        Returns:

          A tuple (data, start_index, end_index)
        """
        if (index < 0) or (index >= len(self.infos)):
            return (None, None, None)

        infos = self.infos[index]

        t1 = time.time()
        data = self.view.get(self.output_name, infos.start_index)
        t2 = time.time()

        self.read_duration += t2 - t1

        # A plain dict from the view is converted to the proper dataformat type
        if isinstance(data, dict):
            d = self.dataformat.type()
            d.from_dict(data, casting='safe', add_defaults=False)
            data = d

        if self.pack:
            data = data.pack()

        # NOTE(review): assumes `data` supports len() even when not packed —
        # confirm against the dataformat base type
        self.nb_bytes_read += len(data)

        return (data, infos.start_index, infos.end_index)
#----------------------------------------------------------
class DataSource(object):
class RemoteDataSource(DataSource):
    """Utility class to load data from a data source accessible via a socket"""

    def __init__(self):
        super(RemoteDataSource, self).__init__()

        self.socket = None
        self.input_name = None
        self.dataformat = None
        self.unpack = True

    def setup(self, socket, input_name, dataformat_name, prefix, unpack=True):
        """Configures the data source

        Parameters:

          socket (socket): The socket to use to access the data.

          input_name (str): Name of the input corresponding to the data source.

          dataformat_name (str): Name of the data format.

          prefix (str, path): Path to the prefix where the dataformats are stored.

          unpack (bool): Indicates if the data must be unpacked or not

        Returns:

          ``True``, if successful, or ``False`` otherwise.
        """
        self.socket = socket
        self.input_name = input_name
        self.unpack = unpack

        self.dataformat = DataFormat(prefix, dataformat_name)

        if not self.dataformat.valid:
            raise RuntimeError("the dataformat `%s' is not valid" % dataformat_name)

        # Load the needed infos from the socket
        Infos = namedtuple('Infos', ['start_index', 'end_index'])

        logger.debug('send: (ifo) infos %s', self.input_name)

        self.socket.send('ifo', zmq.SNDMORE)
        self.socket.send(self.input_name)

        answer = self.socket.recv()
        logger.debug('recv: %s', answer)

        if answer == 'err':
            kind = self.socket.recv()
            message = self.socket.recv()
            raise RemoteException(kind, message)

        nb = int(answer)
        for i in range(nb):
            start = int(self.socket.recv())
            end = int(self.socket.recv())
            self.infos.append(Infos(start_index=start, end_index=end))

        return True

    def __getitem__(self, index):
        """Retrieve a block of data

        Returns:

          A tuple (data, start_index, end_index)
        """
        if (index < 0) or (index >= len(self.infos)):
            return (None, None, None)

        infos = self.infos[index]

        logger.debug('send: (get) get %s %d', self.input_name, index)

        t1 = time.time()

        self.socket.send('get', zmq.SNDMORE)
        self.socket.send(self.input_name, zmq.SNDMORE)
        self.socket.send('%d' % index)

        answer = self.socket.recv()

        if answer == 'err':
            # Fixed: the original used an undefined name `_start` here
            # (NameError); the read started at `t1`
            self.read_duration += time.time() - t1
            kind = self.socket.recv()
            message = self.socket.recv()
            raise RemoteException(kind, message)

        start = int(answer)
        end = int(self.socket.recv())
        packed = self.socket.recv()

        t2 = time.time()

        logger.debug('recv: <bin> (size=%d), indexes=(%d, %d)', len(packed), start, end)

        self.nb_bytes_read += len(packed)

        if self.unpack:
            data = self.dataformat.type()
            data.unpack(packed)
        else:
            data = packed

        self.read_duration += t2 - t1

        return (data, infos.start_index, infos.end_index)
#----------------------------------------------------------
class LegacyDataSource(object):
"""Interface of all the Data Sources
......@@ -545,7 +824,7 @@ class StdoutDataSink(DataSink):
#----------------------------------------------------------
class CachedDataSource(DataSource):
class CachedLegacyDataSource(LegacyDataSource):
"""Data Source that load data from the Cache"""
def __init__(self):
......@@ -580,7 +859,7 @@ class CachedDataSource(DataSource):
"""
self.cached_file = CachedFileLoader()
self.cached_file = CachedDataSource()
if self.cached_file.setup(filename, prefix, start_index=force_start_index,
end_index=force_end_index, unpack=unpack):
self.dataformat = self.cached_file.dataformat
......@@ -814,7 +1093,7 @@ class CachedDataSink(DataSink):
#----------------------------------------------------------
class MemoryDataSource(DataSource):
class MemoryLegacyDataSource(LegacyDataSource):
"""Interface of all the Data Sources
......@@ -864,7 +1143,7 @@ class MemoryDataSource(DataSource):
class MemoryDataSink(DataSink):
"""Data Sink that directly transmit data to associated MemoryDataSource
"""Data Sink that directly transmit data to associated MemoryLegacyDataSource
objects.
"""
......@@ -874,7 +1153,7 @@ class MemoryDataSink(DataSink):
def setup(self, data_sources):
"""Configure the data sink
:param list data_sources: The MemoryDataSource objects to use
:param list data_sources: The MemoryLegacyDataSource objects to use
"""
self.data_sources = data_sources
......
......@@ -35,6 +35,7 @@ import six
import simplejson
import itertools
import numpy as np
from collections import namedtuple
from . import loader
from . import utils
......@@ -73,7 +74,7 @@ class Storage(utils.CodeStorage):
#----------------------------------------------------------
class View(object):
class Runner(object):
'''A special loader class for database views, with specialized methods
Parameters:
......@@ -98,9 +99,7 @@ class View(object):
'''
def __init__(self, module, definition, prefix, root_folder, exc=None,
*args, **kwargs):
def __init__(self, module, definition, prefix, root_folder, exc=None):
try:
class_ = getattr(module, definition['view'])
......@@ -109,94 +108,73 @@ class View(object):
type, value, traceback = sys.exc_info()
six.reraise(exc, exc(value), traceback)
else:
raise #just re-raise the user exception
raise # just re-raise the user exception
self.obj = loader.run(class_, '__new__', exc, *args, **kwargs)
self.ready = False
self.prefix = prefix
self.root_folder = root_folder
self.definition = definition
self.exc = exc or RuntimeError
self.outputs = None
self.obj = loader.run(class_, '__new__', exc)
self.ready = False
self.prefix = prefix
self.root_folder = root_folder
self.definition = definition
self.exc = exc or RuntimeError
self.data_sources = None
def prepare_outputs(self):
'''Prepares the outputs of the dataset'''
def index(self, filename):
    '''Index the content of the view and save the result to ``filename``.

    Runs the user-provided view's ``index()`` implementation and dumps the
    returned list of entries as JSON. Raises ``self.exc`` if ``index()``
    does not return a list.
    '''
    # Reconstructed from the diff residue: the removed `prepare_outputs`
    # lines interleaved here belonged to the pre-commit implementation.
    parameters = self.definition.get('parameters', {})

    objs = loader.run(self.obj, 'index', self.exc, self.root_folder, parameters)

    if not isinstance(objs, list):
        raise self.exc("index() didn't return a list")

    if not os.path.exists(os.path.dirname(filename)):
        os.makedirs(os.path.dirname(filename))

    # NOTE(review): 'wb' with simplejson.dump is Python-2 era; under
    # Python 3 this requires text mode — confirm target interpreter
    with open(filename, 'wb') as f:
        simplejson.dump(objs, f)
if 'outputs' not in kwargs:
kwargs['outputs'] = self.outputs
else:
self.outputs = kwargs['outputs'] #record outputs nevertheless
self.ready = loader.run(self.obj, 'setup', self.exc, *args, **kwargs)
def setup(self, filename, start_index=None, end_index=None, pack=True):
'''Sets up the view'''
if not self.ready:
raise self.exc("unknow setup failure")
if self.ready:
return
return self.ready
with open(filename, 'rb') as f:
objs = simplejson.load(f)
Entry = namedtuple('Entry', sorted(objs[0].keys()))
objs = [ Entry(**x) for x in objs ]
def input_group(self, name='default', exclude_outputs=[]):
'''A memory-source input group matching the outputs from the view'''
parameters = self.definition.get('parameters', {})
if not self.ready:
raise self.exc("database view not yet setup")
loader.run(self.obj, 'setup', self.exc, self.root_folder, parameters,
objs, start_index=start_index, end_index=end_index)
from .data import MemoryDataSource
from .outputs import SynchronizationListener
from .inputs import Input, InputGroup
# Setup the inputs
synchronization_listener = SynchronizationListener()
input_group = InputGroup(name,
synchronization_listener=synchronization_listener,
restricted_access=False)
# Create data sources for the outputs
from .data import DatabaseOutputDataSource
from .dataformat import DataFormat
for output in self.outputs:
if output.name in exclude_outputs: continue
data_source = MemoryDataSource(self.done, next_callback=self.next)
output.data_sink.data_sources.append(data_source)
input_group.add(Input(output.name,
output.data_sink.dataformat, data_source))
self.data_sources = {}
for output_name, output_format in self.definition.get('outputs', {}).items():
data_source = DatabaseOutputDataSource()
data_source.setup(self, output_name, output_format, self.prefix,
start_index=start_index, end_index=end_index, pack=pack)
self.data_sources[output_name] = data_source
return input_group
self.ready = True
def done(self, *args, **kwargs):
'''Checks if the view is done'''
def get(self, output, index):
    '''Returns the data of the provided output at the provided index'''
    # Reconstructed from the diff residue: the removed `done`/`next`
    # method lines interleaved here belonged to the pre-commit class.
    if not self.ready:
        raise self.exc("Database view not yet setup")

    return loader.run(self.obj, 'get', self.exc, output, index)
def __getattr__(self, key):
......@@ -204,6 +182,10 @@ class View(object):
return getattr(self.obj, key)
def objects(self):
return self.obj.objs
#----------------------------------------------------------
......@@ -368,7 +350,9 @@ class Database(object):
if not self.valid:
message = "cannot load view for set `%s' of protocol `%s' " \
"from invalid database (%s)" % (protocol, name, self.name)
if exc: raise exc(message)
if exc:
raise exc(message)
raise RuntimeError(message)
# loads the module only once through the lifetime of the database object
......@@ -383,8 +367,73 @@ class Database(object):
else:
raise #just re-raise the user exception
return View(self._module, self.set(protocol, name), self.prefix,
self.data['root_folder'], exc)
return Runner(self._module, self.set(protocol, name), self.prefix,
self.data['root_folder'], exc)
#----------------------------------------------------------
class View(object):
def index(self, root_folder, parameters):
    """Returns a list of (named) tuples describing the data provided by the view.

    The ordering of values inside the tuples is free, but it is expected
    that the list is ordered in a consistent manner (i.e. all train images of
    person A, then all train images of person B, ...).

    For instance, assuming a view providing that kind of data:

    ----------- ----------- ----------- ----------- ----------- -----------
    |  image  | |  image  | |  image  | |  image  | |  image  | |  image  |
    ----------- ----------- ----------- ----------- ----------- -----------
    ----------- ----------- ----------- ----------- ----------- -----------
    | file_id | | file_id | | file_id | | file_id | | file_id | | file_id |
    ----------- ----------- ----------- ----------- ----------- -----------
    ----------------------------------- -----------------------------------
    |            client_id            | |            client_id            |
    ----------------------------------- -----------------------------------

    a list like the following should be generated:

    [
        (client_id=1, file_id=1, image=filename1),
        (client_id=1, file_id=2, image=filename2),
        (client_id=1, file_id=3, image=filename3),
        (client_id=2, file_id=4, image=filename4),
        (client_id=2, file_id=5, image=filename5),
        (client_id=2, file_id=6, image=filename6),
        ...
    ]

    DO NOT store images, sound files or data loadable from a file in the list!
    Store the path of the file to load instead.
    """
    # Abstract method: concrete views must override this
    raise NotImplementedError
def setup(self, root_folder, parameters, objs, start_index=None, end_index=None):
# Initialisations
self.root_folder = root_folder
self.parameters = parameters
self.objs = objs
# Determine the range of indices that must be provided