#!/usr/bin/env python # vim: set fileencoding=utf-8 : ############################################################################### # # # Copyright (c) 2016 Idiap Research Institute, http://www.idiap.ch/ # # Contact: beat.support@idiap.ch # # # # This file is part of the beat.backend.python module of the BEAT platform. # # # # Commercial License Usage # # Licensees holding valid commercial BEAT licenses may use this file in # # accordance with the terms contained in a written agreement between you # # and Idiap. For further information contact tto@idiap.ch # # # # Alternatively, this file may be used under the terms of the GNU Affero # # Public License version 3 as published by the Free Software and appearing # # in the file LICENSE.AGPL included in the packaging of this file. # # The BEAT platform is distributed in the hope that it will be useful, but # # WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY # # or FITNESS FOR A PARTICULAR PURPOSE. # # # # You should have received a copy of the GNU Affero Public License along # # with the BEAT platform. If not, see http://www.gnu.org/licenses/. 
# # # ###############################################################################

import logging
import time
from functools import reduce  # noqa: F401 -- no longer used here, kept as a module-level name

import six
import zmq

logger = logging.getLogger(__name__)


# Reply frames exchanged with the other end of the socket. They are bytes
# literals because pyzmq sends and receives raw bytes; on Python 2 ``b'tru'``
# and ``'tru'`` are the same object type, so the comparison is unchanged there.
_REPLY_TRUE = b'tru'
_REPLY_ERROR = b'err'


def _to_bytes(value):
    """Returns ``value`` encoded as UTF-8 if it is a text string

    Byte strings (including Python 2 ``str``) are returned unchanged, so this
    is a no-op on Python 2 for the names handled by this module.
    """

    if isinstance(value, six.text_type):
        return value.encode('utf-8')

    return value


def _to_str(value):
    """Returns ``value`` as a native ``str``

    Only Python 3 ``bytes`` objects are decoded (as UTF-8, replacing invalid
    sequences); any other value is returned unchanged.
    """

    if isinstance(value, six.binary_type) and not isinstance(value, str):
        return value.decode('utf-8', 'replace')

    return value


class Input(object):
    """Represents the input of a processing block

    A list of those inputs must be provided to the algorithms (see
    :py:class:`beat.backend.python.inputs.InputList`)


    Parameters:

      name (str): Name of the input

      data_format (str): Data format accepted by the input

      data_source (beat.core.platform.data.DataSource): Source of data to be
        used by the input


    Attributes:

      group (beat.core.inputs.InputGroup): Group containing this input

      name (str): Name of the input (algorithm-specific)

      data (beat.core.baseformat.baseformat): The last block of data received
        on the input

      data_index (int): Index of the last block of data received on the input
        (see the section *Inputs synchronization* of the User's Guide)

      data_index_end (int): End index of the last block of data received on
        the input (see the section *Inputs synchronization* of the User's
        Guide)

      data_same_as_previous (bool): ``True`` until the first call to
        :py:meth:`next`; afterwards maintained by this class and by the
        owning group

      data_format (str): Data format accepted by the input

      data_source (beat.core.data.DataSource): Source of data used by the
        input

      nb_data_blocks_read (int): Number of data blocks read so far

    """

    def __init__(self, name, data_format, data_source):
        self.group = None
        self.name = str(name)
        self.data = None
        self.data_index = -1
        self.data_index_end = -1
        self.data_same_as_previous = True
        self.data_format = data_format
        self.data_source = data_source
        self.nb_data_blocks_read = 0

    def isDataUnitDone(self):
        """Indicates if the current data unit will change at the next
        iteration"""

        # This input already holds data while the group has not computed any
        # interval yet (its end index is still the initial -1)
        if (self.data_index_end >= 0) and (self.group.data_index_end == -1):
            return True

        return self.data_index_end == self.group.data_index_end

    def hasMoreData(self):
        """Indicates if there is more data to process on the input"""

        return self.data_source.hasMoreData()

    def hasDataChanged(self):
        """Indicates if the current data unit is different than the one at the
        previous iteration"""

        return not self.data_same_as_previous

    def next(self):
        """Retrieves the next block of data

        Raises:

          RuntimeError: if the group has restricted access, or if the data
            source returned no data
        """

        if self.group.restricted_access:
            raise RuntimeError('Not authorized')

        (self.data, self.data_index, self.data_index_end) = \
            self.data_source.next()

        if self.data is None:
            message = "User algorithm asked for more data for channel " \
                "`%s' on input `%s', but it is over (no more data). This " \
                "normally indicates a programming error on the user " \
                "side." % (self.group.channel, self.name)
            raise RuntimeError(message)

        self.data_same_as_previous = False
        self.nb_data_blocks_read += 1


#----------------------------------------------------------


class RemoteException(Exception):
    """Error reported by the other end of the socket

    Parameters:

      kind (str): ``'sys'`` for a system error, anything else is treated as
        a user error

      message (str): Description of the error


    Attributes:

      system_error (str): The message if ``kind`` is ``'sys'``, else ``''``

      user_error (str): The message if ``kind`` is not ``'sys'``, else ``''``

    """

    def __init__(self, kind, message):
        # Frames read from a socket are bytes on Python 3: normalise them so
        # the comparison with 'sys' below behaves the same on both versions
        kind = _to_str(kind)
        message = _to_str(message)

        # Forward the message so that str(exception) is informative
        super(RemoteException, self).__init__(message)

        if kind == 'sys':
            self.system_error = message
            self.user_error = ''
        else:
            self.system_error = ''
            self.user_error = message


#----------------------------------------------------------


def process_error(socket):
    """Reads the two frames following an ``err`` reply and raises them

    Parameters:

      socket (object): Socket on which an ``err`` frame was just received;
        the next two frames are read as the error kind and the error message

    Raises:

      RemoteException: always
    """

    kind = socket.recv()
    message = socket.recv()
    raise RemoteException(kind, message)


#----------------------------------------------------------


class RemoteInput(object):
    """Allows to access the input of a processing block, via a socket.

    The other end of the socket must be managed by a message handler (see
    :py:class:`beat.backend.python.message_handler.MessageHandler`)

    A list of those inputs must be provided to the algorithms (see
    :py:class:`beat.backend.python.inputs.InputList`)


    Parameters:

      name (str): Name of the input

      data_format (object): An object with the preloaded data format for this
        input (see :py:class:`beat.backend.python.dataformat.DataFormat`).

      socket (object): A 0MQ socket for writing the data to the server process

      unpack (bool): If ``True`` (the default), received data is unpacked
        into a new ``data_format.type()`` object; otherwise the raw received
        frame is stored in ``data``


    Attributes:

      group (beat.backend.python.inputs.InputGroup): Group containing this
        input

      data (beat.core.baseformat.baseformat): The last block of data received
        on the input

      comm_time (float): Total time spent on communication, in seconds

      nb_data_blocks_read (int): Number of data blocks read so far

    """

    def __init__(self, name, data_format, socket, unpack=True):
        self.name = str(name)
        self.data_format = data_format
        self.socket = socket
        self.data = None
        self.data_index = -1
        self.data_index_end = -1
        self.group = None
        self.comm_time = 0.  # total time spent on communication
        self.nb_data_blocks_read = 0
        self._unpack = unpack

    def _request(self, command):
        """Sends ``command`` for this input and returns the first reply frame

        The request is made of three frames: the command, the channel of the
        group and the name of the input.
        """

        self.socket.send(command, zmq.SNDMORE)
        self.socket.send(_to_bytes(self.group.channel), zmq.SNDMORE)
        self.socket.send(_to_bytes(self.name))

        return self.socket.recv()

    def isDataUnitDone(self):
        """Indicates if the current data unit will change at the next
        iteration"""

        logger.debug('send: (idd) is-dataunit-done %s', self.name)

        _start = time.time()
        answer = self._request(b'idd')
        self.comm_time += time.time() - _start

        logger.debug('recv: %s', answer)

        # Same error handling as the other requests of this class
        if answer == _REPLY_ERROR:
            process_error(self.socket)

        return answer == _REPLY_TRUE

    def hasMoreData(self):
        """Indicates if there is more data to process on the input"""

        logger.debug('send: (hmd) has-more-data %s %s', self.group.channel,
                     self.name)

        _start = time.time()
        answer = self._request(b'hmd')
        self.comm_time += time.time() - _start

        logger.debug('recv: %s', answer)

        if answer == _REPLY_ERROR:
            process_error(self.socket)

        return answer == _REPLY_TRUE

    def hasDataChanged(self):
        """Indicates if the current data unit is different than the one at the
        previous iteration"""

        logger.debug('send: (hdc) has-data-changed %s %s', self.group.channel,
                     self.name)

        _start = time.time()
        answer = self._request(b'hdc')
        self.comm_time += time.time() - _start

        logger.debug('recv: %s', answer)

        if answer == _REPLY_ERROR:
            process_error(self.socket)

        return answer == _REPLY_TRUE

    def next(self):
        """Retrieves the next block of data

        The reply is read as three frames: the start index, the end index and
        the packed data.

        Raises:

          RemoteException: if the other end replied with an error
        """

        logger.debug('send: (nxt) next %s %s', self.group.channel, self.name)

        _start = time.time()
        answer = self._request(b'nxt')

        if answer == _REPLY_ERROR:
            self.comm_time += time.time() - _start
            process_error(self.socket)

        self.data_index = int(answer)
        self.data_index_end = int(self.socket.recv())

        self.unpack(self.socket.recv())
        self.comm_time += time.time() - _start

        self.nb_data_blocks_read += 1

    def unpack(self, packed):
        """Stores a received block of data in ``self.data``

        Parameters:

          packed: The data frame as received from the socket. It is unpacked
            into a new ``data_format.type()`` object, unless the input was
            created with ``unpack=False``, in which case it is stored as is.
        """

        logger.debug('recv: (size=%d), indexes=(%d, %d)', len(packed),
                     self.data_index, self.data_index_end)

        # NOTE: test the flag, not ``self.unpack`` (this very method, which is
        # always true)
        if self._unpack:
            self.data = self.data_format.type()
            self.data.unpack(packed)
        else:
            self.data = packed


#----------------------------------------------------------


class InputGroup(object):
    """Represents a group of inputs synchronized together

    The inputs can be either "local" ones (reading data from the cache) or
    "remote" ones (using a socket to communicate with a database view output
    located inside a docker container).

    The other end of the socket must be managed by a message handler (see
    :py:class:`beat.backend.python.message_handler.MessageHandler`)

    A group implementing this interface is provided to the algorithms (see
    :py:class:`beat.backend.python.inputs.InputList`).

    See :py:class:`beat.core.inputs.Input`

    Example:

      .. code-block:: python

         inputs = InputList()

         print(inputs['labels'].data_format)

         for index in range(0, len(inputs)):
             print(inputs[index].data_format)

         for input in inputs:
             print(input.data_format)

         for input in inputs[0:2]:
             print(input.data_format)


    Parameters:

      channel (str): Name of the data channel of the group

      synchronization_listener (beat.core.outputs.SynchronizationListener):
        Synchronization listener to use

      restricted_access (bool): Indicates if the algorithm can freely use the
        inputs


    Attributes:

      data_index (int): Index of the last block of data received on the inputs
        (see the section *Inputs synchronization* of the User's Guide)

      data_index_end (int): End index of the last block of data received on
        the inputs (see the section *Inputs synchronization* of the User's
        Guide)

      channel (str): Name of the data channel of the group

      synchronization_listener (beat.core.outputs.SynchronizationListener):
        Synchronization listener used

      socket (object): Socket of the first remote input added to the group,
        ``None`` if there is none

      comm_time (float): Total time spent on communication by the group, in
        seconds

    """

    def __init__(self, channel, synchronization_listener=None,
                 restricted_access=True):
        self._inputs = []
        self.data_index = -1
        self.data_index_end = -1
        self.channel = str(channel)
        self.synchronization_listener = synchronization_listener
        self.restricted_access = restricted_access
        self.socket = None
        self.comm_time = 0.

    def __getitem__(self, index):
        """Returns an input by name, by position or a list of inputs by slice

        ``None`` is returned when no input matches (unknown name, position
        past the end, unsupported index type).
        """

        if isinstance(index, six.string_types):
            for candidate in self._inputs:
                if candidate.name == index:
                    return candidate
            return None

        if isinstance(index, slice):
            return self._inputs[index]

        if isinstance(index, six.integer_types):
            if -len(self._inputs) <= index < len(self._inputs):
                return self._inputs[index]

        return None

    def __iter__(self):
        for k in self._inputs:
            yield k

    def __len__(self):
        return len(self._inputs)

    def add(self, input):
        """Add an input to the group

        Parameters:

          input (beat.backend.python.inputs.Input or
            beat.backend.python.inputs.RemoteInput): The input to add

        """

        # The socket of the first remote input is used for group-wide requests
        if isinstance(input, RemoteInput) and (self.socket is None):
            self.socket = input.socket

        input.group = self
        self._inputs.append(input)

    def localInputs(self):
        """Iterates over the :py:class:`Input` objects of the group"""

        for k in [x for x in self._inputs if isinstance(x, Input)]:
            yield k

    def remoteInputs(self):
        """Iterates over the :py:class:`RemoteInput` objects of the group"""

        for k in [x for x in self._inputs if isinstance(x, RemoteInput)]:
            yield k

    def hasMoreData(self):
        """Indicates if there is more data to process in the group"""

        # First process the local inputs
        if any(x.hasMoreData() for x in self.localInputs()):
            return True

        # Next process the remote inputs
        if self.socket is None:
            return False

        logger.debug('send: (hmd) has-more-data %s', self.channel)

        _start = time.time()

        self.socket.send(b'hmd', zmq.SNDMORE)
        self.socket.send(_to_bytes(self.channel))

        answer = self.socket.recv()
        self.comm_time += time.time() - _start

        logger.debug('recv: %s', answer)

        if answer == _REPLY_ERROR:
            process_error(self.socket)

        return answer == _REPLY_TRUE

    def next(self):
        """Retrieve the next block of data on all the inputs

        Only the inputs whose current data expire first (lowest end index)
        are updated; the others are flagged as unchanged.

        Raises:

          RuntimeError: if the group has restricted access, or if the remote
            reply names an input that is not part of this group

          IndexError: if the group contains no input

          RemoteException: if the other end of the socket replied with an
            error
        """

        # Only for groups not managed by the platform
        if self.restricted_access:
            raise RuntimeError('Not authorized')

        if not self._inputs:
            raise IndexError("No input in the group for channel `%s'"
                             % self.channel)

        # Only retrieve new data on the inputs where the current data expire
        # first
        lower_end_index = min(x.data_index_end for x in self._inputs)
        inputs_to_update = [x for x in self._inputs
                            if x.data_index_end == lower_end_index]
        inputs_up_to_date = [x for x in self._inputs
                             if x.data_index_end != lower_end_index]

        for local_input in inputs_to_update:
            if isinstance(local_input, Input):
                local_input.next()
                local_input.data_same_as_previous = False

        if any(isinstance(x, RemoteInput) for x in inputs_to_update):
            logger.debug('send: (nxt) next %s', self.channel)
            self.socket.send(b'nxt', zmq.SNDMORE)
            self.socket.send(_to_bytes(self.channel))

            # read all incoming data
            _start = time.time()
            more = True
            parts = []
            while more:
                parts.append(self.socket.recv())
                if parts[-1] == _REPLY_ERROR:
                    self.comm_time += time.time() - _start
                    process_error(self.socket)
                more = self.socket.getsockopt(zmq.RCVMORE)

            # Reply layout: number of inputs, then for each input its name,
            # start index, end index and packed data
            n = int(parts.pop(0))
            logger.debug('recv: %d (inputs)', n)

            for k in range(n):
                name = _to_str(parts.pop(0))
                logger.debug('recv: %s (data follows)', name)

                inpt = self[name]
                if inpt is None:
                    raise RuntimeError("Could not find input `%s' at input group for " \
                        "channel `%s' while performing `next' operation on this group " \
                        "(current reading position is %d/%d)" % \
                        (name, self.channel, k, n))

                inpt.data_index = int(parts.pop(0))
                inpt.data_index_end = int(parts.pop(0))
                inpt.unpack(parts.pop(0))
                inpt.nb_data_blocks_read += 1

            self.comm_time += time.time() - _start

        for unchanged_input in inputs_up_to_date:
            unchanged_input.data_same_as_previous = True

        # Compute the group's start and end indices
        self.data_index = max(x.data_index for x in self._inputs)
        self.data_index_end = min(x.data_index_end for x in self._inputs)

        # Inform the synchronisation listener
        if self.synchronization_listener is not None:
            self.synchronization_listener.onIntervalChanged(
                self.data_index, self.data_index_end)


#----------------------------------------------------------


class InputList(object):
    """Represents the list of inputs of a processing block

    Inputs are organized by groups. The inputs inside a group are all
    synchronized together (see the section *Inputs synchronization* of the
    User's Guide).

    A list implementing this interface is provided to the algorithms

    One group of inputs is always considered as the **main** one, and is used
    to drive the algorithm. The usage of the other groups is left to the
    algorithm.

    See :py:class:`beat.core.inputs.Input`
    See :py:class:`beat.core.inputs.InputGroup`


    Example:

      .. code-block:: python

         inputs = InputList()
         ...

         # Retrieve an input by name
         input = inputs['labels']

         # Retrieve an input by index
         for index in range(0, len(inputs)):
             input = inputs[index]

         # Iteration over all inputs
         for input in inputs:
             ...

         # Iteration over some inputs
         for input in inputs[0:2]:
             ...

         # Retrieve the group an input belongs to, by input name
         group = inputs.groupOf('label')

         # Retrieve the group an input belongs to
         input = inputs['labels']
         group = input.group


    Attributes:

      main_group (beat.core.inputs.InputGroup): Main group (for data-driven
        algorithms); this is the first group added with restricted access

    """

    def __init__(self):
        self._groups = []
        self.main_group = None

    def add(self, group):
        """Add a group to the list

        :param beat.core.platform.inputs.InputGroup group: The group to add
        """

        if group.restricted_access and (self.main_group is None):
            self.main_group = group

        self._groups.append(group)

    def __getitem__(self, index):
        """Returns an input by name, by position or a list of inputs by slice

        Positions count the inputs of all the groups, in the order the groups
        were added. ``None`` is returned when no input matches.
        """

        if isinstance(index, six.string_types):
            for group in self._groups:
                found = group[index]
                if found is not None:
                    return found
            return None

        if isinstance(index, slice):
            return list(self)[index]

        if isinstance(index, six.integer_types):
            # Negative positions count from the end of the whole list
            if index < 0:
                index += len(self)
                if index < 0:
                    return None

            for group in self._groups:
                if index < len(group):
                    return group[index]
                index -= len(group)

        return None

    def __iter__(self):
        for group in self._groups:
            for k in group:
                yield k

    def __len__(self):
        return sum(len(group) for group in self._groups)

    def nbGroups(self):
        """Returns the number of groups in the list"""

        return len(self._groups)

    def groupOf(self, input_name):
        """Returns the first group containing the given input, else ``None``"""

        for group in self._groups:
            if group[input_name] is not None:
                return group

        return None

    def hasMoreData(self):
        """Indicates if there is more data to process in any group"""

        return any(group.hasMoreData() for group in self._groups)

    def group(self, name_or_index):
        """Returns a group by channel name or by position

        ``None`` is returned for an unknown channel name or an unsupported
        argument type. A position is used to index the list of groups
        directly, so an out-of-range position raises ``IndexError``.
        """

        if isinstance(name_or_index, six.string_types):
            for group in self._groups:
                if group.channel == name_or_index:
                    return group
            return None

        if isinstance(name_or_index, six.integer_types):
            return self._groups[name_or_index]

        return None