dbexecution.py 9.82 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
#!/usr/bin/env python
# vim: set fileencoding=utf-8 :

###############################################################################
#                                                                             #
# Copyright (c) 2016 Idiap Research Institute, http://www.idiap.ch/           #
# Contact: beat.support@idiap.ch                                              #
#                                                                             #
# This file is part of the beat.backend.python module of the BEAT platform.   #
#                                                                             #
# Commercial License Usage                                                    #
# Licensees holding valid commercial BEAT licenses may use this file in       #
# accordance with the terms contained in a written agreement between you      #
# and Idiap. For further information contact tto@idiap.ch                     #
#                                                                             #
# Alternatively, this file may be used under the terms of the GNU Affero      #
# Public License version 3 as published by the Free Software and appearing    #
# in the file LICENSE.AGPL included in the packaging of this file.            #
# The BEAT platform is distributed in the hope that it will be useful, but    #
# WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY  #
# or FITNESS FOR A PARTICULAR PURPOSE.                                        #
#                                                                             #
# You should have received a copy of the GNU Affero Public License along      #
# with the BEAT platform. If not, see http://www.gnu.org/licenses/.           #
#                                                                             #
###############################################################################


'''Execution utilities'''

import os
import sys
import glob
import errno
import tempfile
import subprocess

import logging
logger = logging.getLogger(__name__)

import simplejson

# from . import schema
from . import database
from . import inputs
from . import outputs
from . import data
from . import message_handler


class DBExecutor(object):
  """Executor specialised in database views


  Parameters:

    prefix (str): Establishes the prefix of your installation.

    data (dict, str): The piece of data representing the block to be executed.
      It must validate against the schema defined for execution blocks. If a
      string is passed, it is supposed to be a fully qualified absolute path to
      a JSON file containing the block execution information.

    dataformat_cache (dict, optional): A dictionary mapping dataformat names to
      loaded dataformats. This parameter is optional and, if passed, may
      greatly speed-up database loading times as dataformats that are already
      loaded may be re-used. If you use this parameter, you must guarantee that
      the cache is refreshed as appropriate in case the underlying dataformats
      change.

    database_cache (dict, optional): A dictionary mapping database names to
      loaded databases. This parameter is optional and, if passed, may
      greatly speed-up database loading times as databases that are already
      loaded may be re-used. If you use this parameter, you must guarantee that
      the cache is refreshed as appropriate in case the underlying databases
      change. Databases loaded by this executor are added to it.


  Attributes:

    prefix (str): The installation prefix passed to the constructor.

    dataformat_cache (dict): The dataformat cache in use (the one passed by the
      caller, or a private empty dictionary).

    errors (list): A list containing errors found while loading this execution
      block.

    data (dict): The original data for this executor, as loaded by our JSON
      decoder.

    databases (dict): A dictionary in which keys are strings with database
      names and values are :py:class:`database.Database`, representing the
      databases required for running this block. The dictionary may be empty
      in case all inputs are taken from the file cache.

    views (dict): A dictionary in which the keys are tuples pointing to the
      ``(<database-name>, <protocol>, <set>)`` and the value is a tuple
      ``(view, start_index, end_index)``: the view for that particular
      combination of details, plus the forced index range used when setting
      it up (both ``None`` when no range is forced). The dictionary may be
      empty in case all inputs are taken from the file cache.

    input_list (beat.core.inputs.InputList): A list of inputs that will be
      served to the algorithm.

    data_sources (list): A list with all data-sources created by our execution
      loader.

    handler: The message handler created by :py:meth:`process`, or ``None``.

  """

  def __init__(self, prefix, data, dataformat_cache=None, database_cache=None):

    self.prefix = prefix

    # some attributes
    self.databases = {}
    self.views = {}
    self.input_list = None
    self.data_sources = []
    self.handler = None
    self.errors = []
    self.data = None

    # temporary caches, if the user has not set them, for performance
    database_cache = database_cache if database_cache is not None else {}
    self.dataformat_cache = dataformat_cache if dataformat_cache is not None else {}

    self._load(data, database_cache)


  def _load(self, data, database_cache):
    """Loads the block execution information

    Parameters:

      data (dict, str): The block execution information, or the path to a
        JSON file containing it.

      database_cache (dict): Already-loaded databases, keyed by name. Newly
        loaded databases are stored into it.

    Errors (missing file, invalid databases) are appended to ``self.errors``
    instead of being raised.
    """

    # reset
    self.data = None
    self.errors = []
    self.databases = {}
    self.views = {}
    self.input_list = None
    self.data_sources = []

    if not isinstance(data, dict): #user has passed a file name
      if not os.path.exists(data):
        self.errors.append('File not found: %s' % data)
        return

      with open(data) as f:
        self.data = simplejson.load(f)
    else:
      self.data = data

    # this runs basic validation, including JSON loading if required
    # self.data, self.errors = schema.validate('execution', data)
    # if self.errors: return #don't proceed with the rest of validation

    # load databases
    for details in self.data['inputs'].values():
      if 'database' not in details:
        continue  # not a dataset input: nothing to load for it here

      db_name = details['database']

      if db_name not in self.databases:

        if db_name in database_cache: #reuse
          db = database_cache[db_name]
        else: #load it
          db = database.Database(self.prefix, db_name, self.dataformat_cache)
          database_cache[db.name] = db

        self.databases[db_name] = db

        if not db.valid:
          self.errors += db.errors
          continue

      # Always look up the database of *this* input: otherwise ``db`` could
      # still refer to the database loaded for a previous input.
      db = self.databases[db_name]
      if not db.valid:
        # do not add errors again, they were reported on first encounter
        continue

      # create and load the required views
      key = (db_name, details['protocol'], details['set'])
      if key not in self.views:
        view = db.view(details['protocol'], details['set'])

        if details['channel'] == self.data['channel']: #synchronized
          start_index, end_index = self.data.get('range', (None, None))
        else:
          start_index, end_index = (None, None)
        view.prepare_outputs()
        self.views[key] = (view, start_index, end_index)


  def __enter__(self):
    """Prepares inputs and outputs for the processing task

    Returns:

      DBExecutor: this very executor

    Raises:

      RuntimeError: in case a database view cannot be properly setup

    """

    self._prepare_inputs()

    # The setup() of a database view may call isConnected() on an input
    # to set the index at the right location when parallelization is enabled.
    # This is why setup() should be called after initialized the inputs.
    for key, (view, start_index, end_index) in self.views.items():

      if (start_index is None) and (end_index is None):
        status = view.setup()
      else:
        status = view.setup(force_start_index=start_index,
                            force_end_index=end_index)

      if not status:
        # ``key`` is a tuple: wrap it so that ``%`` does not treat it as the
        # whole argument list (which would raise a TypeError instead)
        raise RuntimeError("Could not setup database view `%s'" % (key,))

    return self


  def __exit__(self, exc_type, exc_value, traceback):
    """Drops the input list and data sources built by :py:meth:`__enter__`

    Returns ``None``, so any exception raised in the ``with`` body propagates.
    """
    self.input_list = None
    self.data_sources = []


  def _prepare_inputs(self):
    """Prepares all input required by the execution."""

    self.input_list = inputs.InputList()

    # This is used for parallelization purposes
    start_index, end_index = self.data.get('range', (None, None))

    for name, details in self.data['inputs'].items():

      if 'database' in details: #it is a dataset input

        view_key = (details['database'], details['protocol'], details['set'])
        view = self.views[view_key][0]

        data_source = data.MemoryDataSource(view.done, next_callback=view.next)
        self.data_sources.append(data_source)
        output = view.outputs[details['output']]

        # if it's a synchronized channel, makes the output start at the right
        # index, otherwise, it gets lost
        if start_index is not None and \
                details['channel'] == self.data['channel']:
          output.last_written_data_index = start_index - 1
        output.data_sink.data_sources.append(data_source)

        # Synchronization bits
        group = self.input_list.group(details['channel'])
        if group is None:
          group = inputs.InputGroup(
                  details['channel'],
                  synchronization_listener=outputs.SynchronizationListener(),
                  restricted_access=(details['channel'] == self.data['channel'])
                  )
          self.input_list.add(group)

        input_db = self.databases[details['database']]
        input_dataformat_name = input_db.set(details['protocol'], details['set'])['outputs'][details['output']]
        group.add(inputs.Input(name, self.dataformat_cache[input_dataformat_name], data_source))


  def process(self, zmq_context, zmq_socket):
    """Creates and starts the message handler serving ``self.input_list``

    Parameters:

      zmq_context: Passed through to ``message_handler.MessageHandler``.

      zmq_socket: Passed through to ``message_handler.MessageHandler``.

    Call :py:meth:`wait` afterwards to block until the handler is done.
    """

    self.handler = message_handler.MessageHandler(self.input_list, zmq_context, zmq_socket)
    self.handler.start()


  @property
  def valid(self):
    """A boolean that indicates if this executor is valid or not"""

    return not bool(self.errors)


  def wait(self):
    """Joins the handler started by :py:meth:`process`, then forgets it

    Does nothing when no handler has been started.
    """

    if self.handler is None:
      return
    self.handler.join()
    self.handler = None


  def __str__(self):
    return simplejson.dumps(self.data, indent=4)