Commit 1819d452 authored by Philip ABBET's avatar Philip ABBET

Refactoring: No remote input / output anymore

parent 29e6ab5e
......@@ -680,7 +680,7 @@ class RemoteDataSource(DataSource):
answer = self.socket.recv()
if answer == 'err':
self.read_duration += time.time() - _start
self.read_duration += time.time() - t1
kind = self.socket.recv()
message = self.socket.recv()
raise RemoteException(kind, message)
......@@ -710,36 +710,6 @@ class RemoteDataSource(DataSource):
#----------------------------------------------------------
class LegacyDataSource(object):
    """Interface of all the (legacy) Data Sources

    Data Sources are used to provide data to the inputs of an algorithm.
    Concrete implementations must supply both abstract methods below.
    """

    __metaclass__ = abc.ABCMeta  # Python 2 style ABC declaration

    @abc.abstractmethod
    def next(self, load=True):
        """Retrieves the next block of data

        Returns:

          A tuple (*data*, *start_index*, *end_index*)
        """

    @abc.abstractmethod
    def hasMoreData(self):
        """Indicates if there is more data to process on some of the inputs"""
#----------------------------------------------------------
class DataSink(object):
"""Interface of all the Data Sinks
......@@ -780,7 +750,7 @@ class StdoutDataSink(DataSink):
"""Data Sink that prints informations about the written data on stdout
Note: The written data is lost! Use ii for debugging purposes
Note: The written data is lost! Use this class for debugging purposes
"""
def __init__(self):
......@@ -824,93 +794,6 @@ class StdoutDataSink(DataSink):
#----------------------------------------------------------
class CachedLegacyDataSource(LegacyDataSource):
    """Data Source that loads its blocks of data from the cache"""

    def __init__(self):
        self.cached_file = None      # underlying CachedDataSource, set by setup()
        self.dataformat = None       # data format of the cached file, set by setup()
        self.next_data_index = 0     # index of the next block to hand out

    def setup(self, filename, prefix, force_start_index=None, force_end_index=None,
              unpack=True):
        """Configures the data source

        Parameters:

          filename (str): Name of the file to read the data from

          prefix (str, path): Path to the prefix where the dataformats are stored.

          force_start_index (int): The starting index (if not set or set to
            ``None``, the default, read data from the begin of file)

          force_end_index (int): The end index (if not set or set to ``None``, the
            default, reads the data until the end)

          unpack (bool): Indicates if the data must be unpacked or not

        Returns:

          ``True``, if successful, or ``False`` otherwise.
        """
        self.cached_file = CachedDataSource()

        ok = self.cached_file.setup(filename, prefix,
                                    start_index=force_start_index,
                                    end_index=force_end_index,
                                    unpack=unpack)
        if not ok:
            return False

        self.dataformat = self.cached_file.dataformat
        return True

    def close(self):
        """Closes the data source"""
        if self.cached_file is None:
            return
        self.cached_file.close()
        self.cached_file = None

    def __del__(self):
        """Makes sure the underlying file is closed when the object is deleted"""
        self.close()

    def next(self):
        """Retrieves the next block of data

        Returns:

          A tuple (data, start_index, end_index), or (None, None, None) when
          all the blocks have already been consumed
        """
        if self.next_data_index >= len(self.cached_file):
            return (None, None, None)

        block = self.cached_file[self.next_data_index]
        self.next_data_index += 1
        return block

    def hasMoreData(self):
        """Indicates if there is more data to process on some of the inputs"""
        return self.next_data_index < len(self.cached_file)

    def statistics(self):
        """Returns the statistics about the number of bytes read from the cache"""
        return self.cached_file.statistics()
#----------------------------------------------------------
class CachedDataSink(DataSink):
"""Data Sink that save data in the Cache
......@@ -1093,92 +976,6 @@ class CachedDataSink(DataSink):
#----------------------------------------------------------
class MemoryLegacyDataSource(LegacyDataSource):
    """In-memory legacy Data Source, fed by a MemoryDataSink

    Data Sources are used to provide data to the inputs of an algorithm.
    """

    def __init__(self, done_callback, next_callback=None, index=None):
        self.data = []                         # queued (data, start, end) tuples
        self._done_callback = done_callback    # tells whether production is over
        self._next_callback = next_callback    # asks the producer for more data
        self._last_data_index = -1             # end index of the last block queued

    def add(self, data, start_data_index, end_data_index):
        """Queues a new block of data"""
        self.data.append((data, start_data_index, end_data_index))
        self._last_data_index = end_data_index

    def next(self):
        """Retrieves the next block of data

        :return: A tuple (*data*, *start_index*, *end_index*)
        """
        # Empty queue: ask the producer for more data, unless it is already done
        if not self.data and self._next_callback is not None:
            if not self._done_callback(self._last_data_index):
                self._next_callback()

        if not self.data:
            return (None, None, None)

        return self.data.pop(0)

    def hasMoreData(self):
        """Indicates if there is more data to process on some of the inputs"""
        return bool(self.data) or not self._done_callback(self._last_data_index)

    def statistics(self):
        """Returns the statistics about the number of bytes read from the cache"""
        return (0, 0)
#----------------------------------------------------------
class MemoryDataSink(DataSink):
    """Data Sink that directly transmits data to the associated
    MemoryLegacyDataSource objects.
    """

    def __init__(self):
        self.data_sources = None  # list of MemoryLegacyDataSource, set by setup()

    def setup(self, data_sources):
        """Configures the data sink

        :param list data_sources: The MemoryLegacyDataSource objects to use
        """
        self.data_sources = data_sources

    def write(self, data, start_data_index, end_data_index):
        """Writes a block of data

        Parameters:

          data (beat.core.baseformat.baseformat) The block of data to write

          start_data_index (int): Start index of the written data

          end_data_index (int): End index of the written data
        """
        # Fan the block out to every registered in-memory source
        for destination in self.data_sources:
            destination.add(data, start_data_index, end_data_index)

    def isConnected(self):
        # NOTE(review): raises TypeError if setup() was never called
        # (data_sources is still None) — presumably setup() is always
        # called first; confirm against callers
        return len(self.data_sources) > 0
#----------------------------------------------------------
def load_data_index(cache_prefix, hash_path):
"""Loads a cached-data index if it exists. Returns empty otherwise.
......
......@@ -100,7 +100,7 @@ class DataView(object):
input_data_indices.append( (current_start, self.data_index_end) )
self.infos[input_name] = dict(
cached_file = infos['cached_file'],
data_source = infos['data_source'],
data_indices = input_data_indices,
data = None,
start_index = -1,
......@@ -132,7 +132,7 @@ class DataView(object):
for input_name, infos in self.infos.items():
if (indices[0] < infos['start_index']) or (infos['end_index'] < indices[0]):
(infos['data'], infos['start_index'], infos['end_index']) = \
infos['cached_file'].getAtDataIndex(indices[0])
infos['data_source'].getAtDataIndex(indices[0])
result[input_name] = infos['data']
......@@ -189,10 +189,10 @@ class DataLoader(object):
self.data_index_end = -1 # Bigger index across all inputs
def add(self, input_name, cached_file):
def add(self, input_name, data_source):
self.infos[input_name] = dict(
cached_file = cached_file,
data_indices = cached_file.data_indices(),
data_source = data_source,
data_indices = data_source.data_indices(),
data = None,
start_index = -1,
end_index = -1,
......@@ -247,7 +247,7 @@ class DataLoader(object):
for input_name, infos in self.infos.items():
if (indices[0] < infos['start_index']) or (infos['end_index'] < indices[0]):
(infos['data'], infos['start_index'], infos['end_index']) = \
infos['cached_file'].getAtDataIndex(indices[0])
infos['data_source'].getAtDataIndex(indices[0])
result[input_name] = infos['data']
......
......@@ -101,21 +101,15 @@ class Executor(object):
if self.algorithm.type == Algorithm.LEGACY:
# Loads algorithm inputs
if self.data['proxy_mode']:
cache_access = AccessMode.REMOTE
else:
cache_access = AccessMode.LOCAL
(self.input_list, self.data_loaders, _) = create_inputs_from_configuration(
self.data, self.algorithm, self.prefix, cache_root,
cache_access=cache_access, db_access=AccessMode.REMOTE,
cache_access=AccessMode.LOCAL, db_access=AccessMode.REMOTE,
socket=self.socket
)
# Loads algorithm outputs
(self.output_list, _) = create_outputs_from_configuration(
self.data, self.algorithm, self.prefix, cache_root, self.input_list,
cache_access=cache_access, socket=self.socket
self.data, self.algorithm, self.prefix, cache_root, self.input_list
)
else:
......@@ -126,8 +120,7 @@ class Executor(object):
# Loads algorithm outputs
(self.output_list, _) = create_outputs_from_configuration(
self.data, self.algorithm, self.prefix, cache_root, self.input_list,
cache_access=AccessMode.LOCAL
self.data, self.algorithm, self.prefix, cache_root, self.input_list
)
......
This diff is collapsed.
......@@ -42,8 +42,9 @@ from .data import RemoteException
#----------------------------------------------------------
class BaseInput(object):
"""Base class for all the kinds of input of a processing block
class Input(object):
"""Represents an input of a processing block that receive data from a
(legacy) data source
A list of those inputs must be provided to the algorithms (see
:py:class:`beat.backend.python.inputs.InputList`)
......@@ -80,16 +81,18 @@ class BaseInput(object):
"""
def __init__(self, name, data_format):
def __init__(self, name, data_format, data_source):
self.group = None
self.name = str(name)
self.data = None
self.data_index = -1
self.data_index_end = -1
self.next_data_index = 0
self.data_same_as_previous = True
self.data_format = data_format
self.nb_data_blocks_read = 0
self.data_source = data_source
def isDataUnitDone(self):
......@@ -104,7 +107,7 @@ class BaseInput(object):
def hasMoreData(self):
"""Indicates if there is more data to process on the input"""
raise NotImplemented()
return (self.next_data_index < len(self.data_source))
def hasDataChanged(self):
......@@ -114,184 +117,24 @@ class BaseInput(object):
return not self.data_same_as_previous
def next(self):
"""Retrieves the next block of data"""
raise NotImplemented()
#----------------------------------------------------------
class Input(BaseInput):
"""Represents an input of a processing block that receive data from a
data source
A list of those inputs must be provided to the algorithms (see
:py:class:`beat.backend.python.inputs.InputList`)
Parameters:
name (str): Name of the input
data_format (str): Data format accepted by the input
data_source (beat.core.platform.data.LegacyDataSource): Source of data to be used
by the input
Attributes:
data_source (beat.core.data.LegacyDataSource): Source of data used by the output
"""
def __init__(self, name, data_format, data_source):
super(Input, self).__init__(name, data_format)
self.data_source = data_source
def hasMoreData(self):
"""Indicates if there is more data to process on the input"""
return self.data_source.hasMoreData()
def next(self):
"""Retrieves the next block of data"""
if self.group.restricted_access:
raise RuntimeError('Not authorized')
(self.data, self.data_index, self.data_index_end) = self.data_source.next()
if self.data is None:
if self.next_data_index >= len(self.data_source):
message = "User algorithm asked for more data for channel " \
"`%s' on input `%s', but it is over (no more data). This " \
"normally indicates a programming error on the user " \
"side." % (self.group.channel, self.name)
raise RuntimeError(message)
self.data_same_as_previous = False
self.nb_data_blocks_read += 1
#----------------------------------------------------------
def process_error(socket):
kind = socket.recv()
message = socket.recv()
raise RemoteException(kind, message)
#----------------------------------------------------------
class RemoteInput(BaseInput):
"""Allows to access the input of a processing block, via a socket.
The other end of the socket must be managed by a message handler (see
:py:class:`beat.backend.python.message_handler.MessageHandler`)
A list of those inputs must be provided to the algorithms (see
:py:class:`beat.backend.python.inputs.InputList`)
Parameters:
name (str): Name of the input
data_format (object): An object with the preloaded data format for this
input (see :py:class:`beat.backend.python.dataformat.DataFormat`).
socket (object): A 0MQ socket for writing the data to the server process
Attributes:
group (beat.backend.python.inputs.InputGroup): Group containing this input
data (beat.core.baseformat.baseformat): The last block of data received on
the input
"""
def __init__(self, name, data_format, socket, unpack=True):
super(RemoteInput, self).__init__(name, data_format)
self.socket = socket
self.comm_time = 0.0 # Total time spent on communication
self._unpack = unpack
self._has_more_data = None # To avoid repetitive requests
def hasMoreData(self):
"""Indicates if there is more data to process on the input"""
if self._has_more_data is None:
logger.debug('send: (hmd) has-more-data %s %s', self.group.channel, self.name)
_start = time.time()
self.socket.send('hmd', zmq.SNDMORE)
self.socket.send(self.group.channel, zmq.SNDMORE)
self.socket.send(self.name)
answer = self.socket.recv()
self.comm_time += time.time() - _start
logger.debug('recv: %s', answer)
if answer == 'err':
process_error(self.socket)
self._has_more_data = (answer == 'tru')
return self._has_more_data
def next(self):
"""Retrieves the next block of data"""
logger.debug('send: (nxt) next %s %s', self.group.channel, self.name)
_start = time.time()
self.socket.send('nxt', zmq.SNDMORE)
self.socket.send(self.group.channel, zmq.SNDMORE)
self.socket.send(self.name)
answer = self.socket.recv()
if answer == 'err':
self.comm_time += time.time() - _start
process_error(self.socket)
self.data_index = int(answer)
self.data_index_end = int(self.socket.recv())
self.unpack(self.socket.recv())
self.comm_time += time.time() - _start
self.nb_data_blocks_read += 1
(self.data, self.data_index, self.data_index_end) = self.data_source[self.next_data_index]
self.data_same_as_previous = False
self._has_more_data = None
def unpack(self, packed):
"""Receives data through socket"""
logger.debug('recv: <bin> (size=%d), indexes=(%d, %d)', len(packed),
self.data_index, self.data_index_end)
if self.unpack:
self.data = self.data_format.type()
self.data.unpack(packed)
else:
self.data = packed
self.next_data_index += 1
self.nb_data_blocks_read += 1
#----------------------------------------------------------
......@@ -300,13 +143,6 @@ class RemoteInput(BaseInput):
class InputGroup:
"""Represents a group of inputs synchronized together
The inputs can be either "local" ones (reading data from the cache) or
"remote" ones (using a socket to communicate with a database view output
located inside a docker container).
The other end of the socket must be managed by a message handler (see
:py:class:`beat.backend.python.message_handler.MessageHandler`)
A group implementing this interface is provided to the algorithms (see
:py:class:`beat.backend.python.inputs.InputList`).
......@@ -357,7 +193,7 @@ class InputGroup:
"""
def __init__(self, channel, synchronization_listener=None,
restricted_access=True):
restricted_access=True):
self._inputs = []
self.data_index = -1 # Lower index across all inputs
......@@ -367,8 +203,6 @@ class InputGroup:
self.channel = str(channel)
self.synchronization_listener = synchronization_listener
self.restricted_access = restricted_access
self.socket = None
self.comm_time = 0.
def __getitem__(self, index):
......@@ -402,38 +236,15 @@ class InputGroup:
"""
if isinstance(input, RemoteInput) and (self.socket is None):
self.socket = input.socket
input.group = self
self._inputs.append(input)
def localInputs(self):
for k in [ x for x in self._inputs if isinstance(x, Input) ]: yield k
def remoteInputs(self):
for k in [ x for x in self._inputs if isinstance(x, RemoteInput) ]: yield k
def hasMoreData(self):
"""Indicates if there is more data to process in the group"""
# First process the local inputs
res = bool([x for x in self.localInputs() if x.hasMoreData()])
if res:
return True
return bool([x for x in self._inputs if x.hasMoreData()])
# Next process the remote inputs
if self.socket is None:
return False
for x in self.remoteInputs():
if x.hasMoreData():
return True
return False
def next(self):
......@@ -450,13 +261,9 @@ class InputGroup:
if x.data_index_end == lower_end_index]
inputs_up_to_date = [x for x in self._inputs if x not in inputs_to_update]
for input in [ x for x in inputs_to_update if isinstance(x, Input) ]:
for input in inputs_to_update:
input.next()
remote_inputs_to_update = list([ x for x in inputs_to_update if isinstance(x, RemoteInput) ])
for remote_input in remote_inputs_to_update:
remote_input.next()
for input in inputs_up_to_date:
input.data_same_as_previous = True
......
......@@ -41,8 +41,7 @@ from .data import RemoteException
class MessageHandler(threading.Thread):
'''A 0MQ message handler for our communication with other processes'''
def __init__(self, host_address, inputs=None, outputs=None, data_sources=None,
kill_callback=None):
def __init__(self, host_address, data_sources=None, kill_callback=None):
super(MessageHandler, self).__init__()
......@@ -74,10 +73,7 @@ class MessageHandler(threading.Thread):
self.poller.register(self.socket, zmq.POLLIN)
# Initialisations
self.input_list = inputs
self.output_list = outputs
self.data_sources = data_sources
self.system_error = ''
self.user_error = ''
self.statistics = {}
......@@ -90,18 +86,6 @@ class MessageHandler(threading.Thread):
err = self.error,
)
if self.input_list is not None:
self.callbacks.update(dict(
nxt = self.next,
hmd = self.has_more_data,
))
if self.output_list is not None:
self.callbacks.update(dict(
wrt = self.write,
oic = self.output_is_connected,
))
if self.data_sources is not None:
self.callbacks.update(dict(
ifo = self.infos,
......@@ -247,101 +231,6 @@ class MessageHandler(threading.Thread):
self._acknowledge()
def next(self, channel, name):
    """Syntax: nxt channel name ...

    Handles a 'nxt' request: advances the named input on the given channel
    and sends its new block back as a multi-part 0MQ answer (start index,
    end index, packed data).

    Raises a RemoteException when this handler has no inputs configured, and
    a RuntimeError when the input has no more data to deliver.
    """

    logger.debug('recv: nxt %s %s', channel, name)

    if self.input_list is None:
        # Protocol violation: this handler was not configured with inputs
        message = 'Unexpected message received: nxt %s %s' % (channel, name)
        raise RemoteException('sys', message)

    channel_group = self.input_list.group(channel)

    # Temporarily lift the group's access restriction so the input candidate
    # is allowed to advance, then restore the previous setting afterwards
    restricted = channel_group.restricted_access
    channel_group.restricted_access = False

    input_candidate = self._get_input_candidate(channel, name)

    input_candidate.next()

    channel_group.restricted_access = restricted

    if input_candidate.data is None: #error
        message = "User algorithm asked for more data for channel " \
                  "`%s' on input `%s', but it is over (no more data). This " \
                  "normally indicates a programming error on the user " \
                  "side." % (channel, name)
        raise RuntimeError(message)

    # Serialize the block unless it is already in packed (raw) form
    if isinstance(input_candidate.data, baseformat.baseformat):
        packed = input_candidate.data.pack()
    else:
        packed = input_candidate.data

    logger.debug('send: <bin> (size=%d), indexes=(%d, %d)', len(packed),
                 input_candidate.data_index, input_candidate.data_index_end)

    # Multi-part answer: start index, end index, then the payload
    self.socket.send('%d' % input_candidate.data_index, zmq.SNDMORE)
    self.socket.send('%d' % input_candidate.data_index_end, zmq.SNDMORE)
    self.socket.send(packed)
def has_more_data(self, channel, name):
"""Syntax: hmd channel name"""
logger.debug('recv: hmd %s %s', channel, name)