#!/usr/bin/env python
# vim: set fileencoding=utf-8 :

###############################################################################
#                                                                             #
# Copyright (c) 2017 Idiap Research Institute, http://www.idiap.ch/           #
# Contact: beat.support@idiap.ch                                              #
#                                                                             #
# This file is part of the beat.backend.python module of the BEAT platform.   #
#                                                                             #
# Commercial License Usage                                                    #
# Licensees holding valid commercial BEAT licenses may use this file in       #
# accordance with the terms contained in a written agreement between you      #
# and Idiap. For further information contact tto@idiap.ch                     #
#                                                                             #
# Alternatively, this file may be used under the terms of the GNU Affero      #
# Public License version 3 as published by the Free Software Foundation and   #
# appearing in the file LICENSE.AGPL included in the packaging of this file.  #
# The BEAT platform is distributed in the hope that it will be useful, but    #
# WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY  #
# or FITNESS FOR A PARTICULAR PURPOSE.                                        #
#                                                                             #
# You should have received a copy of the GNU Affero Public License along      #
# with the BEAT platform. If not, see http://www.gnu.org/licenses/.           #
#                                                                             #
###############################################################################


import os
import errno

import logging
logger = logging.getLogger(__name__)

from . import data
from . import inputs
from . import outputs



def convert_experiment_configuration_to_container(config, proxy_mode):
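  """Converts an experiment-level configuration into the container-side one.

  Copies the algorithm, parameters and synchronization channel, the optional
  ``range`` used for parallelization, and reduces each input (and each
  output, or the single result for analyzers) to its synchronization channel
  and cache path.
  """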
  data = {
    'proxy_mode': proxy_mode,
    'algorithm': config['algorithm'],
    'parameters': config['parameters'],
    'channel': config['channel'],
  }

  if 'range' in config:
    data['range'] = config['range']

  # note: 'dict.has_key()' was removed in Python 3; use the 'in' operator
  data['inputs'] = dict([(k, { 'channel': v['channel'], 'path': v['path'],
                               'database': 'database' in v })
                         for k, v in config['inputs'].items()])

  if 'outputs' in config:
    data['outputs'] = dict([(k, { 'channel': v['channel'], 'path': v['path'] })
                            for k, v in config['outputs'].items()])
  else:
    data['result'] = { 'channel': config['channel'],
                       'path': config['result']['path'] }

  return data
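
# A minimal usage sketch. The 'config' layout below is hypothetical, inferred
# only from the keys accessed above, and is not a specification:
#
#   config = {
#     'algorithm': 'user/integers_echo/1',
#     'parameters': {},
#     'channel': 'main',
#     'inputs': {
#       'in': { 'channel': 'main', 'path': 'ab/cd/ef/0123' },
#     },
#     'result': { 'path': 'ab/cd/ef/4567' },
#   }
#   container = convert_experiment_configuration_to_container(config, True)
#   # container['inputs']['in']['database'] -> False ('database' key not set)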


#----------------------------------------------------------


class CacheAccess:
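  """Pseudo-enumeration of cache access levels: no access at all (NONE),
  direct access to the files on disk (LOCAL) or access proxied through a
  socket (REMOTE)."""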
  NONE   = 0
  LOCAL  = 1
  REMOTE = 2


def create_inputs_from_configuration(config, algorithm, prefix, cache_root,
                                     cache_access=CacheAccess.NONE, unpack=True,
                                     socket=None):
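  """Builds the input list and data sources described by ``config``.

  Inputs backed by a database, as well as every input when ``cache_access``
  is ``CacheAccess.REMOTE``, are proxied through ``socket`` as ``RemoteInput``
  objects; with ``CacheAccess.LOCAL``, ``.data`` files are read from
  ``cache_root`` via ``CachedDataSource`` objects. Inputs are grouped per
  synchronization channel.

  Returns a tuple ``(input_list, data_sources)``.
  """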

  data_sources = []
  input_list = inputs.InputList()

  # This is used for parallelization purposes
  start_index, end_index = config.get('range', (None, None))

  for name, details in config['inputs'].items():

    if details.get('database', False):
      if socket is None:
        raise IOError("No socket provided for remote inputs")

      input = inputs.RemoteInput(name, algorithm.dataformats[algorithm.input_map[name]],
                                 socket, unpack=unpack)

      logger.debug("RemoteInput '%s' created: group='%s', dataformat='%s', connected to a database" % \
                      (name, details['channel'], algorithm.input_map[name]))

    elif cache_access == CacheAccess.LOCAL:
      data_source = data.CachedDataSource()
      data_sources.append(data_source)

      filename = os.path.join(cache_root, details['path'] + '.data')

      if details['channel'] == config['channel']: # synchronized
        status = data_source.setup(
                  filename=filename,
                  prefix=prefix,
                  force_start_index=start_index,
                  force_end_index=end_index,
                  unpack=unpack,
                 )
      else:
        status = data_source.setup(
                  filename=filename,
                  prefix=prefix,
                  unpack=unpack,
                 )

      if not status:
        raise IOError("cannot load cache file `%s'" % details['path'])

      input = inputs.Input(name, algorithm.input_map[name], data_source)

      logger.debug("Input '%s' created: group='%s', dataformat='%s', filename='%s'" % \
                      (name, details['channel'], algorithm.input_map[name], filename))

    elif cache_access == CacheAccess.REMOTE:
      if socket is None:
        raise IOError("No socket provided for remote inputs")

      input = inputs.RemoteInput(name, algorithm.dataformats[algorithm.input_map[name]],
                                 socket, unpack=unpack)

      logger.debug("RemoteInput '%s' created: group='%s', dataformat='%s'" % \
                      (name, details['channel'], algorithm.input_map[name]))

    else:
      continue

    # Synchronization bits
    group = input_list.group(details['channel'])
    if group is None:
      group = inputs.InputGroup(
                details['channel'],
                synchronization_listener=outputs.SynchronizationListener(),
                restricted_access=(details['channel'] == config['channel'])
              )
      input_list.add(group)
      logger.debug("Group '%s' created" % details['channel'])

    group.add(input)

  return (input_list, data_sources)
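
# A minimal consumption sketch. 'hasMoreData()' and 'next()' on input groups
# belong to the wider beat.backend.python inputs API and are assumed here,
# not defined in this module:
#
#   input_list, data_sources = create_inputs_from_configuration(
#       config, algorithm, prefix, cache_root,
#       cache_access=CacheAccess.LOCAL)
#   group = input_list.group(config['channel'])
#   while group.hasMoreData():  # assumed API
#       group.next()            # advances all synchronized inputs of the group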



#----------------------------------------------------------


def create_outputs_from_configuration(config, algorithm, prefix, cache_root, input_list,
                                      cache_access=CacheAccess.NONE, socket=None):
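  """Builds the output list and data sinks described by ``config``.

  For analyzers (a ``result`` entry in ``config``) a single output named
  ``result`` is created with the ``analysis:<algorithm>`` dataformat. With
  ``CacheAccess.LOCAL``, each output writes a ``.data`` file under
  ``cache_root`` through a ``CachedDataSink``; with ``CacheAccess.REMOTE``,
  ``socket`` is required and data flows through ``RemoteOutput`` objects.
  Synchronized outputs reuse the synchronization listener of the matching
  input group in ``input_list``.

  Returns a tuple ``(output_list, data_sinks)``.
  """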

  data_sinks = []
  output_list = outputs.OutputList()

  # This is used for parallelization purposes
  start_index, end_index = config.get('range', (None, None))

  # If the algorithm is an analyser
  if 'result' in config:
    output_config = {
      'result': config['result']
    }
  else:
    output_config = config['outputs']


  for name, details in output_config.items():

    if 'result' in config:
      dataformat_name = 'analysis:' + algorithm.name
      dataformat = algorithm.result_dataformat()
    else:
      dataformat_name = algorithm.output_map[name]
      dataformat = algorithm.dataformats[dataformat_name]


    if cache_access == CacheAccess.LOCAL:

      path = os.path.join(cache_root, details['path'] + '.data')
      dirname = os.path.dirname(path)
      # Make sure that the directory exists while taking care of race
      # conditions. see: http://stackoverflow.com/questions/273192/check-if-a-directory-exists-and-create-it-if-necessary
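      # (since Python 3.2, 'os.makedirs(dirname, exist_ok=True)' achieves the
      # same effect without the explicit errno check)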
      try:
        if dirname:
          os.makedirs(dirname)
      except OSError as exception:
        if exception.errno != errno.EEXIST:
          raise

      data_sink = data.CachedDataSink()
      data_sinks.append(data_sink)

      status = data_sink.setup(
          filename=path,
          dataformat=dataformat,
          encoding='binary',
          max_size=0, # in bytes, for individual file chunks
      )

      if not status:
        raise IOError("Cannot create cache sink '%s'" % details['path'])


      synchronization_listener = None

      if 'result' not in config:
        input_group = input_list.group(details['channel'])
        if (input_group is not None) and hasattr(input_group, 'synchronization_listener'):
          synchronization_listener = input_group.synchronization_listener

      output_list.add(outputs.Output(name, data_sink,
          synchronization_listener=synchronization_listener,
          force_start_index=start_index or 0)
      )

      if 'result' not in config:
        logger.debug("Output '%s' created: group='%s', dataformat='%s', filename='%s'" % \
                        (name, details['channel'], dataformat_name, path))
      else:
        logger.debug("Output '%s' created: dataformat='%s', filename='%s'" % \
                        (name, dataformat_name, path))

    elif cache_access == CacheAccess.REMOTE:
      if socket is None:
        raise IOError("No socket provided for remote outputs")

      output_list.add(outputs.RemoteOutput(name, dataformat, socket))

      logger.debug("RemoteOutput '%s' created: group='%s', dataformat='%s'" % \
                      (name, details['channel'], dataformat_name))

    else:
      continue

  return (output_list, data_sinks)
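
# Illustrative end-to-end wiring, a sketch only: 'config', 'algorithm',
# 'prefix' and 'cache_root' are hypothetical values, and 'close()' on the
# data sinks comes from the wider data API, not from this module:
#
#   input_list, data_sources = create_inputs_from_configuration(
#       config, algorithm, prefix, cache_root,
#       cache_access=CacheAccess.LOCAL)
#   output_list, data_sinks = create_outputs_from_configuration(
#       config, algorithm, prefix, cache_root, input_list,
#       cache_access=CacheAccess.LOCAL)
#   # ... run the algorithm, writing onto 'output_list' ...
#   for sink in data_sinks:
#       sink.close()  # flushes pending data to the cache files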