executor.py 8.16 KB
Newer Older
André Anjos's avatar
André Anjos committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
#!/usr/bin/env python
# vim: set fileencoding=utf-8 :

###############################################################################
#                                                                             #
# Copyright (c) 2016 Idiap Research Institute, http://www.idiap.ch/           #
# Contact: beat.support@idiap.ch                                              #
#                                                                             #
# This file is part of the beat.backend.python module of the BEAT platform.   #
#                                                                             #
# Commercial License Usage                                                    #
# Licensees holding valid commercial BEAT licenses may use this file in       #
# accordance with the terms contained in a written agreement between you      #
# and Idiap. For further information contact tto@idiap.ch                     #
#                                                                             #
# Alternatively, this file may be used under the terms of the GNU Affero      #
# Public License version 3 as published by the Free Software and appearing    #
# in the file LICENSE.AGPL included in the packaging of this file.            #
# The BEAT platform is distributed in the hope that it will be useful, but    #
# WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY  #
# or FITNESS FOR A PARTICULAR PURPOSE.                                        #
#                                                                             #
# You should have received a copy of the GNU Affero Public License along      #
# with the BEAT platform. If not, see http://www.gnu.org/licenses/.           #
#                                                                             #
###############################################################################


'''A class that can set up and execute algorithm blocks on the backend'''

import logging
logger = logging.getLogger(__name__)

import os
import time

import zmq
import simplejson

from . import algorithm
from . import inputs
from . import outputs


class Executor(object):
  """Runs the user algorithm described by an execution block configuration.

  On construction, the executor reads ``configuration.json`` from
  ``directory``, loads the algorithm named in it from ``<directory>/prefix``
  and builds the input and output lists, all of which communicate through
  ``socket``. Typical usage is ``setup()`` followed by ``process()``.

  Parameters:

    socket (zmq.Socket): A pre-connected socket to send and receive messages
      from.

    directory (str): The path to a directory containing all the information
      required to run the user experiment.

    dataformat_cache (dict, optional): A dictionary mapping dataformat names to
      loaded dataformats. This parameter is optional and, if passed, may
      greatly speed-up database loading times as dataformats that are already
      loaded may be re-used. If you use this parameter, you must guarantee that
      the cache is refreshed as appropriate in case the underlying dataformats
      change.

    database_cache (dict, optional): A dictionary mapping database names to
      loaded databases. This parameter is optional and, if passed, may
      greatly speed-up database loading times as databases that are already
      loaded may be re-used. If you use this parameter, you must guarantee that
      the cache is refreshed as appropriate in case the underlying databases
      change. Note: this class itself does not consume this cache; the
      parameter is accepted so that the call signature stays stable.

    library_cache (dict, optional): A dictionary mapping library names to
      loaded libraries. This parameter is optional and, if passed, may greatly
      speed-up library loading times as libraries that are already loaded may
      be re-used. If you use this parameter, you must guarantee that the cache
      is refreshed as appropriate in case the underlying libraries change.
  """

  # Defaults for the I/O lists. The constructor only creates them when the
  # configuration declares 'inputs', 'outputs' or 'result'; having a default
  # lets process() report missing I/O with its RuntimeError instead of
  # failing on an attribute lookup.
  input_list = None
  output_list = None

  def __init__(self, socket, directory, dataformat_cache=None,
          database_cache=None, library_cache=None):

    self.socket = socket
    self.comm_time = 0. #total communication time

    self.configuration = os.path.join(directory, 'configuration.json')
    with open(self.configuration, 'rb') as f:
      self.data = simplejson.load(f)
    self.prefix = os.path.join(directory, 'prefix')
    self.runner = None  # created by setup()

    # temporary caches, if the user has not set them, for performance
    dataformat_cache = dataformat_cache if dataformat_cache is not None else {}
    library_cache = library_cache if library_cache is not None else {}

    self.algorithm = algorithm.Algorithm(self.prefix, self.data['algorithm'],
            dataformat_cache, library_cache)

    # Use algorithm names for inputs and outputs
    main_channel = self.data['channel']

    # Loads algorithm inputs, one group per synchronization channel
    if 'inputs' in self.data:
      self.input_list = inputs.InputList()
      for name, channel in self.data['inputs'].items():
        group = self.input_list.group(channel)
        if group is None:
          # first input seen on this channel: create its group
          group = inputs.InputGroup(channel, (channel == main_channel),
                      socket=self.socket)
          self.input_list.add(group)
        thisformat = self.algorithm.dataformats[self.algorithm.input_map[name]]
        group.add(inputs.Input(name, thisformat, self.socket))
      logger.debug("Loaded input list with %d group(s) and %d input(s)",
          self.input_list.nbGroups(), len(self.input_list))

    # Loads outputs (the channel of each output is not needed here)
    if 'outputs' in self.data:
      self.output_list = outputs.OutputList()
      for name in self.data['outputs']:
        thisformat = self.algorithm.dataformats[self.algorithm.output_map[name]]
        self.output_list.add(outputs.Output(name, thisformat, self.socket))
      logger.debug("Loaded output list with %d output(s)",
          len(self.output_list))

    # Loads results if it is an analyzer; this replaces any output list
    # created above with a list holding the single 'result' output
    if 'result' in self.data:
      self.output_list = outputs.OutputList()
      name = 'result'
      # Retrieve dataformats in the JSON of the algorithm
      analysis_format = self.algorithm.result_dataformat()
      analysis_format.name = 'analysis:' + self.algorithm.name
      self.output_list.add(outputs.Output(name, analysis_format, self.socket))
      logger.debug("Loaded output list for analyzer (1 single output)")


  def setup(self):
    """Sets up the algorithm to start processing

    Creates the algorithm runner and hands it the ``parameters`` entry of the
    configuration.

    Returns:

      Whatever the runner's ``setup()`` returns.
    """

    self.runner = self.algorithm.runner()
    retval = self.runner.setup(self.data['parameters'])
    logger.debug("User algorithm is setup")
    return retval


  def process(self):
    """Executes the user algorithm code using the current interpreter.

    Advances the main input group and calls the runner once per iteration
    while the input list reports more data. ``setup()`` must have been called
    beforehand, as it is what creates the runner used here.

    Returns:

      bool: ``False`` as soon as the runner reports a failure (in which case
      the "done" message is *not* sent), ``True`` once all data was processed
      and the "done" message was acknowledged.

    Raises:

      RuntimeError: if the input or output lists are missing or empty, or if
        an output still reports missing data after processing.
    """

    if not self.input_list or not self.output_list:
      raise RuntimeError("I/O for execution block has not yet been set up")

    # analyzers get the single result output, other blocks the whole list
    using_output = self.output_list[0] if self.analysis else self.output_list

    _start = time.time()

    while self.input_list.hasMoreData():
      main_group = self.input_list.main_group
      # access is unlocked only while the executor itself advances the group
      main_group.restricted_access = False
      main_group.next()
      main_group.restricted_access = True
      if not self.runner.process(self.input_list, using_output):
        return False

    missing_data_outputs = [x for x in self.output_list if x.isDataMissing()]

    proc_time = time.time() - _start

    if missing_data_outputs:
      raise RuntimeError("Missing data on the following output(s): %s" % \
              ', '.join([x.name for x in missing_data_outputs]))

    self.comm_time = sum(x.comm_time for x in self.input_list) + \
            sum(x.comm_time for x in self.output_list)
    # NOTE(review): this indexes the input list by integer once per group;
    # presumably meant to add per-group communication time - confirm against
    # InputList.__getitem__ that it does not count some inputs twice.
    self.comm_time += sum(self.input_list[k].comm_time
            for k in range(self.input_list.nbGroups()))

    # the loop may take no measurable time (e.g. no data, coarse clock);
    # avoid a ZeroDivisionError while computing a purely informative figure
    ratio = 100*self.comm_time/proc_time if proc_time > 0 else 0

    # some local information
    logger.debug("Total processing time was %.3f seconds" , proc_time)
    logger.debug("Time spent in I/O was %.3f seconds" , self.comm_time)
    logger.debug("I/O/Processing ratio is %d%%", ratio)

    # Handle the done command
    self.done()

    return True


  def done(self):
    """Indicates the infrastructure the execution is done

    Sends a two-part message (the ``don`` command followed by the total
    communication time formatted with ``%.6e``) and then blocks until a reply
    is received on the socket.
    """

    logger.debug('send: (don) done')
    # zmq frames must be byte strings: text strings are rejected on Python 3,
    # while byte literals behave the same as before on Python 2
    self.socket.send(b'don', zmq.SNDMORE)
    self.socket.send(('%.6e' % self.comm_time).encode('utf-8'))
    answer = self.socket.recv() #ack
    logger.debug('recv: %s', answer)


  @property
  def schema_version(self):
    """Returns the schema version (1 if the configuration does not set it)"""
    return self.data.get('schema_version', 1)


  @property
  def analysis(self):
    """A boolean that indicates if the current block is an analysis block"""
    return 'result' in self.data