Commit a00a9e77 authored by Philip ABBET's avatar Philip ABBET
Browse files

Refactoring: one InputGroup for both local and remote inputs

parent f6b6db58
......@@ -102,8 +102,7 @@ class Executor(object):
for name, channel in self.data['inputs'].items():
group = self.input_list.group(channel)
if group is None:
group = inputs.RemoteInputGroup(channel, (channel == main_channel),
socket=self.socket)
group = inputs.InputGroup(channel, restricted_access=(channel == main_channel))
self.input_list.add(group)
thisformat = self.algorithm.dataformats[self.algorithm.input_map[name]]
group.add(inputs.RemoteInput(name, thisformat, self.socket))
......
......@@ -112,145 +112,6 @@ class Input:
#----------------------------------------------------------
class InputGroup:
"""Represents a group of inputs synchronized together
A group implementing this interface is provided to the algorithms (see
:py:class:`beat.backend.python.inputs.InputList`).
See :py:class:`beat.core.inputs.Input`
Example:
.. code-block:: python
inputs = InputList()
print inputs['labels'].data_format
for index in range(0, len(inputs)):
print inputs[index].data_format
for input in inputs:
print input.data_format
for input in inputs[0:2]:
print input.data_format
Parameters:
channel (str): Name of the data channel of the group
synchronization_listener (beat.core.outputs.SynchronizationListener):
Synchronization listener to use
restricted_access (bool): Indicates if the algorithm can freely use the
inputs
Atttributes:
data_index (int): Index of the last block of data received on the inputs
(see the section *Inputs synchronization* of the User's Guide)
data_index_end (int): End index of the last block of data received on the
inputs (see the section *Inputs synchronization* of the User's Guide)
channel (str): Name of the data channel of the group
synchronization_listener (beat.core.outputs.SynchronizationListener):
Synchronization listener used
"""
def __init__(self, channel, synchronization_listener=None,
restricted_access=True):
self._inputs = []
self.data_index = -1
self.data_index_end = -1
self.channel = str(channel)
self.synchronization_listener = synchronization_listener
self.restricted_access = restricted_access
def __getitem__(self, index):
if isinstance(index, six.string_types):
try:
return [x for x in self._inputs if x.name == index][0]
except:
pass
elif isinstance(index, int):
if index < len(self._inputs):
return self._inputs[index]
return None
def __iter__(self):
for k in self._inputs: yield k
def __len__(self):
return len(self._inputs)
def add(self, input):
"""Add an input to the group
Parameters:
input (beat.core.inputs.Input): The input to add
"""
input.group = self
self._inputs.append(input)
def hasMoreData(self):
"""Indicates if there is more data to process in the group"""
return bool([x for x in self._inputs if x.hasMoreData()])
def next(self):
"""Retrieve the next block of data on all the inputs"""
# Only for groups not managed by the platform
if self.restricted_access: raise RuntimeError('Not authorized')
# Only retrieve new data on the inputs where the current data expire first
lower_end_index = reduce(lambda x, y: min(x, y.data_index_end),
self._inputs[1:], self._inputs[0].data_index_end)
inputs_to_update = [x for x in self._inputs \
if x.data_index_end == lower_end_index]
inputs_up_to_date = [x for x in self._inputs if x not in inputs_to_update]
for input in inputs_to_update:
input.next()
input.data_same_as_previous = False
for input in inputs_up_to_date:
input.data_same_as_previous = True
# Compute the group's start and end indices
self.data_index = reduce(lambda x, y: max(x, y.data_index),
self._inputs[1:], self._inputs[0].data_index)
self.data_index_end = reduce(lambda x, y: min(x, y.data_index_end),
self._inputs[1:], self._inputs[0].data_index_end)
# Inform the synchronisation listener
if self.synchronization_listener is not None:
self.synchronization_listener.onIntervalChanged(self.data_index,
self.data_index_end)
#----------------------------------------------------------
class RemoteException(Exception):
def __init__(self, kind, message):
......@@ -396,14 +257,18 @@ class RemoteInput:
#----------------------------------------------------------
class RemoteInputGroup:
"""Allows to access a group of inputs synchronized together, via a socket.
class InputGroup:
"""Represents a group of inputs synchronized together
The inputs can be either "local" ones (reading data from the cache) or
"remote" ones (using a socket to communicate with a database view output
located inside a docker container).
The other end of the socket must be managed by a message handler (see
:py:class:`beat.backend.python.message_handler.MessageHandler`)
A group implementing this interface is provided to the algorithms (see
:py:class:`beat.core.inputs.InputList`).
:py:class:`beat.backend.python.inputs.InputList`).
See :py:class:`beat.core.inputs.Input`
......@@ -429,20 +294,39 @@ class RemoteInputGroup:
channel (str): Name of the data channel of the group
synchronization_listener (beat.core.outputs.SynchronizationListener):
Synchronization listener to use
restricted_access (bool): Indicates if the algorithm can freely use the
inputs
socket (object): A 0MQ socket for writing the data to the server process
Attributes:
data_index (int): Index of the last block of data received on the inputs
(see the section *Inputs synchronization* of the User's Guide)
data_index_end (int): End index of the last block of data received on the
inputs (see the section *Inputs synchronization* of the User's Guide)
channel (str): Name of the data channel of the group
synchronization_listener (beat.core.outputs.SynchronizationListener):
Synchronization listener used
"""
def __init__(self, channel, restricted_access, socket):
def __init__(self, channel, synchronization_listener=None,
restricted_access=True):
self._inputs = []
self.channel = str(channel)
self.restricted_access = restricted_access
self.socket = socket
self.comm_time = 0.
self._inputs = []
self.data_index = -1
self.data_index_end = -1
self.channel = str(channel)
self.synchronization_listener = synchronization_listener
self.restricted_access = restricted_access
self.socket = None
self.comm_time = 0.
def __getitem__(self, index):
......@@ -471,16 +355,38 @@ class RemoteInputGroup:
Parameters:
input (beat.backend.python.inputs.Input): The input to add
input (beat.backend.python.inputs.Input or beat.backend.python.inputs.RemoteInput):
The input to add
"""
if isinstance(input, RemoteInput) and (self.socket is None):
self.socket = input.socket
input.group = self
self._inputs.append(input)
def localInputs(self):
for k in [ x for x in self._inputs if isinstance(x, Input) ]: yield k
def remoteInputs(self):
for k in [ x for x in self._inputs if isinstance(x, RemoteInput) ]: yield k
def hasMoreData(self):
"""Indicates if there is more data to process in the group"""
# First process the local inputs
res = bool([x for x in self.localInputs() if x.hasMoreData()])
if res:
return True
# Next process the remote inputs
if self.socket is None:
return False
logger.debug('send: (hmd) has-more-data %s', self.channel)
_start = time.time()
......@@ -502,38 +408,69 @@ class RemoteInputGroup:
def next(self):
"""Retrieve the next block of data on all the inputs"""
logger.debug('send: (nxt) next %s', self.channel)
self.socket.send('nxt', zmq.SNDMORE)
self.socket.send(self.channel)
# Only for groups not managed by the platform
if self.restricted_access:
raise RuntimeError('Not authorized')
# read all incomming data
_start = time.time()
more = True
parts = []
while more:
parts.append(self.socket.recv())
if parts[-1] == 'err':
self.comm_time += time.time() - _start
process_error(self.socket)
# Only retrieve new data on the inputs where the current data expire first
lower_end_index = reduce(lambda x, y: min(x, y.data_index_end),
self._inputs[1:], self._inputs[0].data_index_end)
inputs_to_update = [x for x in self._inputs \
if x.data_index_end == lower_end_index]
inputs_up_to_date = [x for x in self._inputs if x not in inputs_to_update]
more = self.socket.getsockopt(zmq.RCVMORE)
n = int(parts.pop(0))
logger.debug('recv: %d (inputs)', n)
for k in range(n):
name = parts.pop(0)
logger.debug('recv: %s (data follows)', name)
inpt = self[name]
if inpt is None:
raise RuntimeError("Could not find input `%s' at input group for " \
"channel `%s' while performing `next' operation on this group " \
"(current reading position is %d/%d)" % \
(name, self.channel, k, n))
inpt.data_index = int(parts.pop(0))
inpt.data_index_end = int(parts.pop(0))
inpt.unpack(parts.pop(0))
inpt.nb_data_blocks_read += 1
self.comm_time += time.time() - _start
for input in [ x for x in inputs_to_update if isinstance(x, Input) ]:
input.next()
input.data_same_as_previous = False
remote_inputs_to_update = list([ x for x in inputs_to_update if isinstance(x, RemoteInput) ])
if len(remote_inputs_to_update) > 0:
logger.debug('send: (nxt) next %s', self.channel)
self.socket.send('nxt', zmq.SNDMORE)
self.socket.send(self.channel)
# read all incomming data
_start = time.time()
more = True
parts = []
while more:
parts.append(self.socket.recv())
if parts[-1] == 'err':
self.comm_time += time.time() - _start
process_error(self.socket)
more = self.socket.getsockopt(zmq.RCVMORE)
n = int(parts.pop(0))
logger.debug('recv: %d (inputs)', n)
for k in range(n):
name = parts.pop(0)
logger.debug('recv: %s (data follows)', name)
inpt = self[name]
if inpt is None:
raise RuntimeError("Could not find input `%s' at input group for " \
"channel `%s' while performing `next' operation on this group " \
"(current reading position is %d/%d)" % \
(name, self.channel, k, n))
inpt.data_index = int(parts.pop(0))
inpt.data_index_end = int(parts.pop(0))
inpt.unpack(parts.pop(0))
inpt.nb_data_blocks_read += 1
self.comm_time += time.time() - _start
for input in inputs_up_to_date:
input.data_same_as_previous = True
# Compute the group's start and end indices
self.data_index = reduce(lambda x, y: max(x, y.data_index),
self._inputs[1:], self._inputs[0].data_index)
self.data_index_end = reduce(lambda x, y: min(x, y.data_index_end),
self._inputs[1:], self._inputs[0].data_index_end)
# Inform the synchronisation listener
if self.synchronization_listener is not None:
self.synchronization_listener.onIntervalChanged(self.data_index,
self.data_index_end)
#----------------------------------------------------------
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment