#!/usr/bin/env python
# vim: set fileencoding=utf-8 :

###############################################################################
#                                                                             #
# Copyright (c) 2017 Idiap Research Institute, http://www.idiap.ch/           #
# Contact: beat.support@idiap.ch                                              #
#                                                                             #
# This file is part of the beat.backend.python module of the BEAT platform.   #
#                                                                             #
# Commercial License Usage                                                    #
# Licensees holding valid commercial BEAT licenses may use this file in       #
# accordance with the terms contained in a written agreement between you      #
# and Idiap. For further information contact tto@idiap.ch                     #
#                                                                             #
# Alternatively, this file may be used under the terms of the GNU Affero      #
# Public License version 3 as published by the Free Software Foundation and   #
# appearing in the file LICENSE.AGPL included in the packaging of this file.  #
# The BEAT platform is distributed in the hope that it will be useful, but    #
# WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY  #
# or FITNESS FOR A PARTICULAR PURPOSE.                                        #
#                                                                             #
# You should have received a copy of the GNU Affero Public License along      #
# with the BEAT platform. If not, see http://www.gnu.org/licenses/.           #
#                                                                             #
###############################################################################
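
"""Helper functions used to instantiate the inputs and outputs of an
algorithm from an experiment configuration, either locally (cache files and
database views) or remotely (through a socket)."""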


import os
import errno
import logging
logger = logging.getLogger(__name__)

from . import data
from . import inputs
from . import outputs



def convert_experiment_configuration_to_container(config, proxy_mode):
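  """Convert the experiment configuration to the format used inside a
  container.

  Adds the UID of the current user and the proxy mode, and summarizes the
  inputs and outputs (or the result, for analyzers) to their channel and
  path.
  """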
  data = {
    'proxy_mode': proxy_mode,
    'algorithm': config['algorithm'],
    'parameters': config['parameters'],
    'channel': config['channel'],
    'uid': os.getuid(),
  }

  if 'range' in config:
    data['range'] = config['range']

  data['inputs'] = {
    k: { 'channel': v['channel'], 'path': v['path'], 'database': 'database' in v }
    for k, v in config['inputs'].items()
  }

  if 'outputs' in config:
    data['outputs'] = {
      k: { 'channel': v['channel'], 'path': v['path'] }
      for k, v in config['outputs'].items()
    }
  else:
    data['result'] = { 'channel': config['channel'], 'path': config['result']['path'] }

  return data


#----------------------------------------------------------


class AccessMode:
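  """Possible ways to access the data: not at all, from the local machine,
  or remotely (through a socket)"""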
  NONE   = 0
  LOCAL  = 1
  REMOTE = 2


def create_inputs_from_configuration(config, algorithm, prefix, cache_root,
                                     cache_access=AccessMode.NONE,
                                     db_access=AccessMode.NONE,
                                     unpack=True, socket=None,
                                     databases=None):
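  """Build the list of inputs of an algorithm from the experiment
  configuration.

  Inputs connected to a database are instantiated from a local database view
  (db_access=AccessMode.LOCAL) or proxied through the given socket
  (db_access=AccessMode.REMOTE). The other inputs are read from the cache
  (cache_access=AccessMode.LOCAL) or received through the socket
  (cache_access=AccessMode.REMOTE). Returns the tuple
  (input_list, data_sources).

  A minimal usage sketch (the configuration, algorithm and prefix are
  assumed to come from a BEAT experiment description):

    input_list, data_sources = create_inputs_from_configuration(
        config, algorithm, prefix, '/path/to/cache',
        cache_access=AccessMode.LOCAL)
  """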

  data_sources = []
  views = {}
  input_list = inputs.InputList()

  # This is used for parallelization purposes
  start_index, end_index = config.get('range', (None, None))

  for name, details in config['inputs'].items():

    if details.get('database', False):
      if db_access == AccessMode.LOCAL:
        if databases is None:
          raise IOError("No databases provided")

        # Retrieve the database
        try:
          db = databases[details['database']]
        except KeyError:
          raise IOError("Database '%s' not found" % details['database'])

        # Create or retrieve the database view
        channel = details['channel']

        if channel not in views:
          view = db.view(details['protocol'], details['set'])
          view.prepare_outputs()
          view.setup()
          views[channel] = view

          logger.debug("Database view '%s/%s/%s' created: group='%s'" % \
                          (details['database'], details['protocol'], details['set'],
                           channel))
        else:
          view = views[channel]

        # Creation of the input
        data_source = data.MemoryDataSource(view.done, next_callback=view.next)

        output = view.outputs[details['output']]
        output.data_sink.data_sources.append(data_source)

        input = inputs.Input(name, algorithm.input_map[name], data_source)

        logger.debug("Input '%s' created: group='%s', dataformat='%s', database-output='%s/%s/%s:%s'" % \
                        (name, channel, algorithm.input_map[name], details['database'],
                         details['protocol'], details['set'], details['output']))

      elif db_access == AccessMode.REMOTE:
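        # The database is not accessible locally: data comes from the
        # other end of the socket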
        if socket is None:
          raise IOError("No socket provided for remote inputs")

        input = inputs.RemoteInput(name, algorithm.dataformats[algorithm.input_map[name]],
                                   socket, unpack=unpack)

        logger.debug("RemoteInput '%s' created: group='%s', dataformat='%s', connected to a database" % \
                        (name, details['channel'], algorithm.input_map[name]))

    elif cache_access == AccessMode.LOCAL:
      data_source = data.CachedDataSource()
      data_sources.append(data_source)

      filename = os.path.join(cache_root, details['path'] + '.data')

      if details['channel'] == config['channel']: # synchronized
        status = data_source.setup(
                  filename=filename,
                  prefix=prefix,
                  force_start_index=start_index,
                  force_end_index=end_index,
                  unpack=True,
                 )
      else:
        status = data_source.setup(
                  filename=filename,
                  prefix=prefix,
                  unpack=True,
                 )

      if not status:
        raise IOError("cannot load cache file `%s'" % details['path'])

      input = inputs.Input(name, algorithm.input_map[name], data_source)

      logger.debug("Input '%s' created: group='%s', dataformat='%s', filename='%s'" % \
                      (name, details['channel'], algorithm.input_map[name], filename))

    elif cache_access == AccessMode.REMOTE:
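      # The cache is not accessible locally: data is received through
      # the socket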
      if socket is None:
        raise IOError("No socket provided for remote inputs")

      input = inputs.RemoteInput(name, algorithm.dataformats[algorithm.input_map[name]],
                                 socket, unpack=unpack)

      logger.debug("RemoteInput '%s' created: group='%s', dataformat='%s'" % \
                      (name, details['channel'], algorithm.input_map[name]))

    else:
      continue

    # Synchronization bits
    group = input_list.group(details['channel'])
    if group is None:
      group = inputs.InputGroup(
                details['channel'],
                synchronization_listener=outputs.SynchronizationListener(),
                restricted_access=(details['channel'] == config['channel'])
              )
      input_list.add(group)
      logger.debug("Group '%s' created" % details['channel'])

    group.add(input)

  return (input_list, data_sources)



#----------------------------------------------------------


def create_outputs_from_configuration(config, algorithm, prefix, cache_root, input_list,
                                      cache_access=AccessMode.NONE, socket=None):
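  """Build the list of outputs of an algorithm from the experiment
  configuration.

  Outputs are written to the cache (cache_access=AccessMode.LOCAL) or sent
  through the given socket (cache_access=AccessMode.REMOTE). For analyzers
  (configurations with a 'result' entry), a single output named 'result' is
  created. Returns the tuple (output_list, data_sinks).
  """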

  data_sinks = []
  output_list = outputs.OutputList()

  # This is used for parallelization purposes
  start_index, end_index = config.get('range', (None, None))

  # If the algorithm is an analyser
  if 'result' in config:
    output_config = {
      'result': config['result']
    }
  else:
    output_config = config['outputs']


  for name, details in output_config.items():

    if 'result' in config:
      dataformat_name = 'analysis:' + algorithm.name
      dataformat = algorithm.result_dataformat()
    else:
      dataformat_name = algorithm.output_map[name]
      dataformat = algorithm.dataformats[dataformat_name]


    if cache_access == AccessMode.LOCAL:

      path = os.path.join(cache_root, details['path'] + '.data')
      dirname = os.path.dirname(path)
      # Make sure that the directory exists while taking care of race
      # conditions. see: http://stackoverflow.com/questions/273192/check-if-a-directory-exists-and-create-it-if-necessary
      try:
        if (len(dirname) > 0):
          os.makedirs(dirname)
      except OSError as exception:
        if exception.errno != errno.EEXIST:
          raise

      data_sink = data.CachedDataSink()
      data_sinks.append(data_sink)

      status = data_sink.setup(
          filename=path,
          dataformat=dataformat,
          encoding='binary',
          max_size=0, # in bytes, for individual file chunks
      )

      if not status:
        raise IOError("Cannot create cache sink '%s'" % details['path'])


      synchronization_listener = None

      if 'result' not in config:
        input_group = input_list.group(details['channel'])
        if (input_group is not None) and hasattr(input_group, 'synchronization_listener'):
          synchronization_listener = input_group.synchronization_listener

      output_list.add(outputs.Output(name, data_sink,
          synchronization_listener=synchronization_listener,
          force_start_index=start_index or 0)
      )

      if 'result' not in config:
        logger.debug("Output '%s' created: group='%s', dataformat='%s', filename='%s'" % \
                        (name, details['channel'], dataformat_name, path))
      else:
        logger.debug("Output '%s' created: dataformat='%s', filename='%s'" % \
                        (name, dataformat_name, path))

    elif cache_access == AccessMode.REMOTE:
      if socket is None:
        raise IOError("No socket provided for remote outputs")

      output_list.add(outputs.RemoteOutput(name, dataformat, socket))

      logger.debug("RemoteOutput '%s' created: group='%s', dataformat='%s'" % \
                      (name, details['channel'], dataformat_name))

    else:
      continue

  return (output_list, data_sinks)