#!/usr/bin/env python # vim: set fileencoding=utf-8 : ############################################################################### # # # Copyright (c) 2016 Idiap Research Institute, http://www.idiap.ch/ # # Contact: beat.support@idiap.ch # # # # This file is part of the beat.backend.python module of the BEAT platform. # # # # Commercial License Usage # # Licensees holding valid commercial BEAT licenses may use this file in # # accordance with the terms contained in a written agreement between you # # and Idiap. For further information contact tto@idiap.ch # # # # Alternatively, this file may be used under the terms of the GNU Affero # # Public License version 3 as published by the Free Software and appearing # # in the file LICENSE.AGPL included in the packaging of this file. # # The BEAT platform is distributed in the hope that it will be useful, but # # WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY # # or FITNESS FOR A PARTICULAR PURPOSE. # # # # You should have received a copy of the GNU Affero Public License along # # with the BEAT platform. If not, see http://www.gnu.org/licenses/. 
#                                                                             #
###############################################################################


import time

import logging
logger = logging.getLogger(__name__)

from functools import reduce

import six
import zmq


class Input:
    """Represents the input of a processing block

    A list of those inputs must be provided to the algorithms (see
    :py:class:`beat.backend.python.inputs.InputList`)


    Parameters:

      name (str): Name of the input

      data_format (str): Data format accepted by the input

      data_source (beat.core.platform.data.DataSource): Source of data to be
        used by the input


    Attributes:

      group (beat.core.inputs.InputGroup): Group containing this input (set
        when the input is added to a group)

      name (str): Name of the input (algorithm-specific)

      data (beat.core.baseformat.baseformat): The last block of data received
        on the input

      data_index (int): Index of the last block of data received on the input
        (see the section *Inputs synchronization* of the User's Guide)

      data_index_end (int): End index of the last block of data received on
        the input (see the section *Inputs synchronization* of the User's
        Guide)

      data_same_as_previous (bool): ``False`` right after :py:meth:`next`
        fetched a new block; the owning group sets it to ``True`` when it
        advances without refreshing this input

      data_format (str): Data format accepted by the input

      data_source (beat.core.data.DataSource): Source of data used by the
        input

      nb_data_blocks_read (int): Number of data blocks read so far

    """

    def __init__(self, name, data_format, data_source):

        self.group = None
        self.name = name
        self.data = None
        self.data_index = -1
        self.data_index_end = -1
        self.data_same_as_previous = False
        self.data_format = data_format
        self.data_source = data_source
        self.nb_data_blocks_read = 0


    def isDataUnitDone(self):
        """Indicates if the current data unit will change at the next iteration

        This is the case when the current block of this input ends exactly
        where the current interval of its group ends. The input must have
        been added to a group first (``group`` is read here).
        """

        return (self.data_index_end == self.group.data_index_end)


    def hasMoreData(self):
        """Indicates if there is more data to process on the input

        Delegates to the data source.
        """

        return self.data_source.hasMoreData()


    def next(self):
        """Retrieves the next block of data

        The data source returns a ``(data, start index, end index)`` triple
        which replaces :py:attr:`data`, :py:attr:`data_index` and
        :py:attr:`data_index_end`; the read counter is incremented.
        """

        (self.data, self.data_index, self.data_index_end) = self.data_source.next()
        self.data_same_as_previous = False
        self.nb_data_blocks_read += 1
#----------------------------------------------------------


def _find_input(inputs, index):
    """Looks up inputs by name, position or slice

    Shared lookup logic of :py:class:`InputGroup` and
    :py:class:`RemoteInputGroup`.


    Parameters:

      inputs (list): The inputs to search

      index (str, int or slice): Name of the input, its position in
        ``inputs`` (negative values count from the end) or a slice of
        positions


    Returns:

      The matching input, a list of inputs when ``index`` is a slice, or
      ``None`` if no input matches (or the key type is not supported)

    """

    if isinstance(index, slice):
        return inputs[index]

    if isinstance(index, six.string_types):
        for candidate in inputs:
            if candidate.name == index:
                return candidate
        return None

    if isinstance(index, six.integer_types) and \
            -len(inputs) <= index < len(inputs):
        return inputs[index]

    return None


class InputGroup:
    """Represents a group of inputs synchronized together

    A group implementing this interface is provided to the algorithms (see
    :py:class:`beat.backend.python.inputs.InputList`).

    See :py:class:`beat.core.inputs.Input`

    Example:

      .. code-block:: python

         inputs = InputList()

         print inputs['labels'].data_format

         for index in range(0, len(inputs)):
             print inputs[index].data_format

         for input in inputs:
             print input.data_format

         for input in inputs[0:2]:
             print input.data_format


    Parameters:

      channel (str): Name of the data channel of the group

      synchronization_listener (beat.core.outputs.SynchronizationListener):
        Synchronization listener to use

      restricted_access (bool): Indicates if the algorithm can freely use the
        inputs


    Attributes:

      data_index (int): Index of the last block of data received on the inputs
        (see the section *Inputs synchronization* of the User's Guide)

      data_index_end (int): End index of the last block of data received on
        the inputs (see the section *Inputs synchronization* of the User's
        Guide)

      channel (str): Name of the data channel of the group

      synchronization_listener (beat.core.outputs.SynchronizationListener):
        Synchronization listener used

      restricted_access (bool): If set, :py:meth:`next` refuses to run (the
        group is driven by the platform)

    """

    def __init__(self, channel, synchronization_listener=None,
                 restricted_access=True):

        self._inputs = []
        self.data_index = -1
        self.data_index_end = -1
        self.channel = channel
        self.synchronization_listener = synchronization_listener
        self.restricted_access = restricted_access


    def __getitem__(self, index):
        """Retrieves inputs by name, position or slice (``None`` if not found)"""

        return _find_input(self._inputs, index)


    def __iter__(self):
        return iter(self._inputs)


    def __len__(self):
        return len(self._inputs)


    def add(self, input):
        """Add an input to the group

        Parameters:

          input (beat.core.inputs.Input): The input to add (its ``group``
            attribute is set to this group)

        """

        input.group = self
        self._inputs.append(input)


    def hasMoreData(self):
        """Indicates if there is more data to process in the group

        Stops at the first input reporting more data.
        """

        return any(x.hasMoreData() for x in self._inputs)


    def next(self):
        """Retrieve the next block of data on all the inputs

        Only the inputs whose current block ends first (lowest
        ``data_index_end``) are advanced; the others keep their data and get
        ``data_same_as_previous = True``. The group's interval is then the
        intersection of the inputs' intervals, and the synchronization
        listener (if any) is notified of it.

        Raises:

          RuntimeError: if the group is managed by the platform
            (``restricted_access``) or contains no input

        """

        # Only for groups not managed by the platform
        if self.restricted_access:
            raise RuntimeError('Not authorized')

        if not self._inputs:
            raise RuntimeError("Input group for channel `%s' has no inputs" %
                               self.channel)

        # Only retrieve new data on the inputs where the current data expire
        # first (decided for all inputs before any of them is advanced)
        lower_end_index = min(x.data_index_end for x in self._inputs)
        needs_update = [x.data_index_end == lower_end_index
                        for x in self._inputs]

        for input_, update in zip(self._inputs, needs_update):
            if update:
                input_.next()
                input_.data_same_as_previous = False
            else:
                input_.data_same_as_previous = True

        # Compute the group's start and end indices
        self.data_index = max(x.data_index for x in self._inputs)
        self.data_index_end = min(x.data_index_end for x in self._inputs)

        # Inform the synchronisation listener
        if self.synchronization_listener is not None:
            self.synchronization_listener.onIntervalChanged(self.data_index,
                                                            self.data_index_end)


#----------------------------------------------------------


class RemoteInput:
    """Allows to access the input of a processing block, via a socket.

    The other end of the socket must be managed by a message handler (see
    :py:class:`beat.backend.python.message_handler.MessageHandler`)

    A list of those inputs must be provided to the algorithms (see
    :py:class:`beat.backend.python.inputs.InputList`)


    Parameters:

      name (str): Name of the input

      data_format (object): An object with the preloaded data format for this
        input (see :py:class:`beat.backend.python.dataformat.DataFormat`).

      socket (object): A 0MQ socket used to send requests to, and receive
        data from, the server process


    Attributes:

      group (beat.backend.python.inputs.InputGroup): Group containing this
        input (its ``channel`` is sent along with every request)

      data (beat.core.baseformat.baseformat): The last block of data received
        on the input

      comm_time (float): Total time (in seconds) spent on communication

      nb_data_blocks_read (int): Number of data blocks read so far

    """

    def __init__(self, name, data_format, socket):

        self.name = name
        self.data_format = data_format
        self.socket = socket
        self.data = None
        self.data_index = -1
        self.data_index_end = -1
        self.group = None
        self.comm_time = 0. # total time spent on communication
        self.nb_data_blocks_read = 0


    def _send_request(self, command):
        """Sends ``command`` about this input to the server

        The request is a 3-part message: the command, the channel of the group
        containing this input and the name of this input. The caller receives
        the reply.
        """

        self.socket.send(command, zmq.SNDMORE)
        self.socket.send(self.group.channel, zmq.SNDMORE)
        self.socket.send(self.name)


    def _ask(self, command):
        """Sends ``command`` and returns whether the server answered ``'tru'``

        The whole round-trip is accounted in :py:attr:`comm_time`.
        """

        _start = time.time()
        self._send_request(command)
        answer = self.socket.recv()
        self.comm_time += time.time() - _start

        logger.debug('recv: %s', answer)

        # NOTE(review): str requests/replies assume Python 2 semantics
        # (str == bytes) -- confirm before running this under Python 3
        return answer == 'tru'


    def isDataUnitDone(self):
        """Indicates if the current data unit will change at the next iteration

        The answer is computed by the server (``idd`` request).
        """

        logger.debug('send: (idd) is-dataunit-done %s %s', self.group.channel,
                     self.name)
        return self._ask('idd')


    def hasMoreData(self):
        """Indicates if there is more data to process on the input

        The answer is computed by the server (``hmd`` request).
        """

        logger.debug('send: (hmd) has-more-data %s %s', self.group.channel,
                     self.name)
        return self._ask('hmd')


    def next(self):
        """Retrieves the next block of data

        Sends a ``nxt`` request, then receives the start index, the end index
        and the packed data of the new block.
        """

        logger.debug('send: (nxt) next %s %s', self.group.channel, self.name)

        _start = time.time()
        self._send_request('nxt')

        self.data_index = int(self.socket.recv())
        self.data_index_end = int(self.socket.recv())
        self.unpack(self.socket.recv())

        self.comm_time += time.time() - _start
        self.nb_data_blocks_read += 1


    def unpack(self, packed):
        """Decodes a packed block of data into :py:attr:`data`

        A fresh instance of ``data_format.type`` is created and filled from
        ``packed``; the current indexes are only used for logging.

        Parameters:

          packed: The packed data, as received from the socket

        """

        self.data = self.data_format.type()
        logger.debug('recv: (size=%d), indexes=(%d, %d)', len(packed),
                     self.data_index, self.data_index_end)
        self.data.unpack(packed)


#----------------------------------------------------------


class RemoteInputGroup:
    """Allows to access a group of inputs synchronized together, via a socket.

    The other end of the socket must be managed by a message handler (see
    :py:class:`beat.backend.python.message_handler.MessageHandler`)

    A group implementing this interface is provided to the algorithms (see
    :py:class:`beat.core.inputs.InputList`).

    See :py:class:`beat.core.inputs.Input`

    Example:

      .. code-block:: python

         inputs = InputList()

         print inputs['labels'].data_format

         for index in range(0, len(inputs)):
             print inputs[index].data_format

         for input in inputs:
             print input.data_format

         for input in inputs[0:2]:
             print input.data_format


    Parameters:

      channel (str): Name of the data channel of the group

      restricted_access (bool): Indicates if the algorithm can freely use the
        inputs

      socket (object): A 0MQ socket used to send requests to, and receive
        data from, the server process


    Attributes:

      comm_time (float): Total time (in seconds) spent on communication by
        the group-level requests

    """

    def __init__(self, channel, restricted_access, socket):

        self._inputs = []
        self.channel = channel
        self.restricted_access = restricted_access
        self.socket = socket
        self.comm_time = 0.


    def __getitem__(self, index):
        """Retrieves inputs by name, position or slice (``None`` if not found)"""

        return _find_input(self._inputs, index)


    def __iter__(self):
        return iter(self._inputs)


    def __len__(self):
        return len(self._inputs)


    def add(self, input):
        """Add an input to the group

        Parameters:

          input (beat.backend.python.inputs.Input): The input to add (its
            ``group`` attribute is set to this group)

        """

        input.group = self
        self._inputs.append(input)


    def hasMoreData(self):
        """Indicates if there is more data to process in the group

        The answer is computed by the server (``hmd`` request for the
        channel).
        """

        logger.debug('send: (hmd) has-more-data %s', self.channel)

        _start = time.time()

        self.socket.send('hmd', zmq.SNDMORE)
        self.socket.send(self.channel)

        answer = self.socket.recv()

        self.comm_time += time.time() - _start
        logger.debug('recv: %s', answer)

        return answer == 'tru'


    def next(self):
        """Retrieve the next block of data on all the inputs

        Sends a ``nxt`` request for the channel and reads the multi-part
        answer: the number of updated inputs followed, for each of them, by
        its name, start index, end index and packed data.

        Raises:

          RuntimeError: if the answer refers to an input that is not part of
            this group

        """

        logger.debug('send: (nxt) next %s', self.channel)

        # start timing before sending, like the other requests of this module,
        # so that ``comm_time`` covers the whole round-trip
        _start = time.time()

        self.socket.send('nxt', zmq.SNDMORE)
        self.socket.send(self.channel)

        # read all incoming data
        parts = []
        more = True
        while more:
            parts.append(self.socket.recv())
            more = self.socket.getsockopt(zmq.RCVMORE)

        n = int(parts.pop(0))
        logger.debug('recv: %d (inputs)', n)

        for k in range(n):
            name = parts.pop(0)
            logger.debug('recv: %s (data follows)', name)

            inpt = self[name]
            if inpt is None:
                raise RuntimeError("Could not find input `%s' at input group for " \
                    "channel `%s' while performing `next' operation on this group " \
                    "(current reading position is %d/%d)" % \
                    (name, self.channel, k, n))

            inpt.data_index = int(parts.pop(0))
            inpt.data_index_end = int(parts.pop(0))
            inpt.unpack(parts.pop(0))
            inpt.nb_data_blocks_read += 1

        self.comm_time += time.time() - _start


#----------------------------------------------------------


class InputList:
    """Represents the list of inputs of a processing block

    Inputs are organized by groups. The inputs inside a group are all
    synchronized together (see the section *Inputs synchronization* of the
    User's Guide).

    A list implementing this interface is provided to the algorithms

    One group of inputs is always considered as the **main** one, and is used
    to drive the algorithm. The usage of the other groups is left to the
    algorithm.

    See :py:class:`beat.core.inputs.Input`
    See :py:class:`beat.core.inputs.InputGroup`


    Example:

      .. code-block:: python

         inputs = InputList()
         ...

         # Retrieve an input by name
         input = inputs['labels']

         # Retrieve an input by index
         for index in range(0, len(inputs)):
             input = inputs[index]

         # Iteration over all inputs
         for input in inputs:
             ...

         # Iteration over some inputs
         for input in inputs[0:2]:
             ...

         # Retrieve the group an input belongs to, by input name
         group = inputs.groupOf('labels')

         # Retrieve the group an input belongs to
         input = inputs['labels']
         group = input.group


    Attributes:

      main_group (beat.core.inputs.InputGroup): Main group (for data-driven
        algorithms): the first group added with ``restricted_access`` set

    """

    def __init__(self):
        self._groups = []
        self.main_group = None


    def add(self, group):
        """Add a group to the list

        The first group added with ``restricted_access`` set becomes the main
        group.

        Parameters:

          group (beat.core.platform.inputs.InputGroup): The group to add

        """

        if group.restricted_access and (self.main_group is None):
            self.main_group = group

        self._groups.append(group)


    def __getitem__(self, index):
        """Retrieves inputs across all groups

        Parameters:

          index (str, int or slice): Name of an input (the first group that
            has it wins), global position of an input (negative values count
            from the end of the whole list) or a slice of global positions


        Returns:

          The matching input, a list of inputs for a slice, or ``None`` if no
          input matches (or the key type is not supported)

        """

        if isinstance(index, slice):
            return list(self)[index]

        if isinstance(index, six.string_types):
            for group in self._groups:
                found = group[index]
                if found is not None:
                    return found
            return None

        if isinstance(index, six.integer_types):
            if index < 0:
                index += len(self)
                if index < 0:
                    return None
            for group in self._groups:
                if index < len(group):
                    return group[index]
                index -= len(group)

        return None


    def __iter__(self):
        for group in self._groups:
            for input_ in group:
                yield input_


    def __len__(self):
        return sum(len(group) for group in self._groups)


    def nbGroups(self):
        """Returns the number of groups in the list"""

        return len(self._groups)


    def groupOf(self, input_name):
        """Returns the first group containing an input named ``input_name``

        Returns ``None`` if no group contains such an input.
        """

        for group in self._groups:
            if group[input_name] is not None:
                return group
        return None


    def hasMoreData(self):
        """Indicates if there is more data to process in any group

        Stops at the first group reporting more data.
        """

        return any(x.hasMoreData() for x in self._groups)


    def group(self, name):
        """Returns the group whose channel is ``name``, or ``None``"""

        for candidate in self._groups:
            if candidate.channel == name:
                return candidate
        return None