helpers.py 15.6 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
#!/usr/bin/env python
# vim: set fileencoding=utf-8 :

###############################################################################
#                                                                             #
# Copyright (c) 2017 Idiap Research Institute, http://www.idiap.ch/           #
# Contact: beat.support@idiap.ch                                              #
#                                                                             #
# This file is part of the beat.backend.python module of the BEAT platform.   #
#                                                                             #
# Commercial License Usage                                                    #
# Licensees holding valid commercial BEAT licenses may use this file in       #
# accordance with the terms contained in a written agreement between you      #
# and Idiap. For further information contact tto@idiap.ch                     #
#                                                                             #
# Alternatively, this file may be used under the terms of the GNU Affero      #
# Public License version 3 as published by the Free Software and appearing    #
# in the file LICENSE.AGPL included in the packaging of this file.            #
# The BEAT platform is distributed in the hope that it will be useful, but    #
# WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY  #
# or FITNESS FOR A PARTICULAR PURPOSE.                                        #
#                                                                             #
# You should have received a copy of the GNU Affero Public License along      #
# with the BEAT platform. If not, see http://www.gnu.org/licenses/.           #
#                                                                             #
###############################################################################

28
29
30
31
32
33
34
"""
=======
helpers
=======

This module implements various helper methods and classes
"""
35

36
import os
37
import errno
38
39
import logging

40
from .data import CachedDataSource
41
from .data import RemoteDataSource
42
43
44
45
46
47
48
49
50
51
52
from .data import CachedDataSink
from .data import getAllFilenames
from .data_loaders import DataLoaderList
from .data_loaders import DataLoader
from .inputs import InputList
from .inputs import Input
from .inputs import InputGroup
from .outputs import SynchronizationListener
from .outputs import OutputList
from .outputs import Output
from .algorithm import Algorithm
53

54
55
logger = logging.getLogger(__name__)

56

57
# ----------------------------------------------------------
58

59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
def convert_loop_to_container(config):
    """Convert the 'loop' section of an experiment configuration into a
    container-friendly dictionary.

    Only the information needed inside the container is kept: algorithm
    name, parameters, synchronization channel, the caller's uid, the input
    descriptions (when present) and the request/answer communication
    endpoints.

    Parameters:
        config (dict): Loop configuration; must contain 'algorithm',
            'parameters', 'channel', 'request' and 'answer', and may
            contain 'inputs'.

    Returns:
        dict: The container-ready loop configuration.
    """
    data = {
        'algorithm': config['algorithm'],
        'parameters': config['parameters'],
        'channel': config['channel'],
        'uid': os.getuid(),
    }

    if 'inputs' in config:
        # 'database' is reduced to a boolean flag: the container only needs
        # to know whether the input comes from a database or from the cache
        data['inputs'] = {
            k: {
                'channel': v['channel'],
                'path': v['path'],
                'database': 'database' in v,
            }
            for k, v in config['inputs'].items()
        }

    # Endpoints used to exchange requests/answers with the loop algorithm
    for item in ('request', 'answer'):
        data[item] = {
            'path': config[item]['path'],
            'channel': config[item]['channel'],
        }

    return data


83

84
def convert_experiment_configuration_to_container(config):
    """Convert a BEAT experiment configuration into a container-friendly
    dictionary.

    Only the information needed inside the container is kept: algorithm
    name, parameters, synchronization channel and the caller's uid, plus
    the descriptions of the inputs, the outputs (or the analyzer 'result'),
    and, when present, the parallelization 'range' and the 'loop'
    configuration.

    Parameters:
        config (dict): Experiment configuration; must contain 'algorithm',
            'parameters', 'channel' and 'inputs', and either 'outputs' or
            'result'. May also contain 'range' and 'loop'.

    Returns:
        dict: The container-ready configuration.
    """
    data = {
        'algorithm': config['algorithm'],
        'parameters': config['parameters'],
        'channel': config['channel'],
        'uid': os.getuid(),
    }

    # Parallelization range, when the experiment is processed in slices
    if 'range' in config:
        data['range'] = config['range']

    # 'database' is reduced to a boolean flag: the container only needs to
    # know whether the input comes from a database or from the cache
    data['inputs'] = {
        k: {
            'channel': v['channel'],
            'path': v['path'],
            'database': 'database' in v,
        }
        for k, v in config['inputs'].items()
    }

    if 'outputs' in config:
        data['outputs'] = {
            k: {'channel': v['channel'], 'path': v['path']}
            for k, v in config['outputs'].items()
        }
    else:
        # Analyzer: a single 'result' output on the main channel
        data['result'] = {'channel': config['channel'],
                          'path': config['result']['path']}

    if 'loop' in config:
        data['loop'] = convert_loop_to_container(config['loop'])

    return data
108
109


110
# ----------------------------------------------------------
111
112


113
class AccessMode:
    """Enumeration of the possible data access modes."""

    # No access to the data
    NONE = 0
    # Data available on the local machine (cache files)
    LOCAL = 1
    # Data reachable through a remote connection
    REMOTE = 2
118
119
120


def create_inputs_from_configuration(config, algorithm, prefix, cache_root,
                                     cache_access=AccessMode.NONE,
                                     db_access=AccessMode.NONE,
                                     unpack=True, socket=None,
                                     databases=None,
                                     no_synchronisation_listeners=False):
    """Create the input list and data loaders for an algorithm execution.

    Depending on the algorithm type and the access modes, each entry of
    ``config['inputs']`` either becomes a synchronized ``Input`` (grouped
    per channel in the returned ``InputList``) or is registered on a
    per-channel ``DataLoader`` (collected in the returned
    ``DataLoaderList``).

    Parameters:
        config (dict): Execution configuration; must contain 'inputs' and
            'channel', and may contain 'range' for parallelization.
        algorithm (Algorithm): The algorithm to create the inputs for.
        prefix (str): Path prefix forwarded to the data sources.
        cache_root (str): Root directory of the cache files.
        cache_access (int): One of the ``AccessMode`` values, for
            cache-based inputs.
        db_access (int): One of the ``AccessMode`` values, for
            database-provided inputs.
        unpack (bool): Unused here; kept for backward compatibility of the
            signature (data sources are always set up with ``unpack=True``).
        socket: ZeroMQ-like socket required when ``db_access`` is
            ``AccessMode.REMOTE``.
        databases (dict): Mapping of database name to database object,
            required when ``db_access`` is ``AccessMode.LOCAL``.
        no_synchronisation_listeners (bool): If True, do not attach
            ``SynchronizationListener`` objects to the input groups.

    Returns:
        tuple: ``(input_list, data_loader_list)``

    Raises:
        IOError: If a cache file cannot be loaded, a database is missing,
            or a remote data source cannot be set up.
    """

    views = {}
    input_list = InputList()
    data_loader_list = DataLoaderList()

    # This is used for parallelization purposes
    start_index, end_index = config.get('range', (None, None))


    def _create_local_input(details):
        # Build a synchronized Input backed by a cache file
        data_source = CachedDataSource()

        filename = os.path.join(cache_root, details['path'] + '.data')

        if details['channel'] == config['channel']: # synchronized
            # Restrict to the parallelization range on the main channel
            status = data_source.setup(
                      filename=filename,
                      prefix=prefix,
                      start_index=start_index,
                      end_index=end_index,
                      unpack=True,
                     )
        else:
            status = data_source.setup(
                      filename=filename,
                      prefix=prefix,
                      unpack=True,
                     )

        if not status:
            raise IOError("cannot load cache file `%s'" % details['path'])

        new_input = Input(name, algorithm.input_map[name], data_source)

        logger.debug("Input '%s' created: group='%s', dataformat='%s', filename='%s'" % \
                        (name, details['channel'], algorithm.input_map[name], filename))

        return new_input


    def _get_data_loader_for(details):
        # Retrieve (or lazily create) the data loader of the channel
        data_loader = data_loader_list[details['channel']]
        if data_loader is None:
            data_loader = DataLoader(details['channel'])
            data_loader_list.add(data_loader)

            logger.debug("Data loader created: group='%s'" % details['channel'])

        return data_loader


    def _create_data_source(details):
        # Register a cache-backed data source on the channel's data loader
        data_loader = _get_data_loader_for(details)

        filename = os.path.join(cache_root, details['path'] + '.data')

        data_source = CachedDataSource()
        result = data_source.setup(
            filename=filename,
            prefix=prefix,
            start_index=start_index,
            end_index=end_index,
            unpack=True,
        )

        if not result:
            raise IOError("cannot load cache file `%s'" % details['path'])

        data_loader.add(name, data_source)

        logger.debug("Input '%s' added to data loader: group='%s', dataformat='%s', filename='%s'" % \
                        (name, details['channel'], algorithm.input_map[name], filename))


    for name, details in config['inputs'].items():

        new_input = None

        if details.get('database', False):
            if db_access == AccessMode.LOCAL:
                if databases is None:
                    raise IOError("No databases provided")

                # Retrieve the database
                try:
                    db = databases[details['database']]
                except KeyError:
                    raise IOError("Database '%s' not found" % details['database'])

                # Create or retrieve the database view
                channel = details['channel']

                if channel not in views:
                    view = db.view(details['protocol'], details['set'])
                    view.setup(os.path.join(cache_root, details['path']), pack=False,
                               start_index=start_index, end_index=end_index)

                    views[channel] = view

                    logger.debug("Database view '%s/%s/%s' created: group='%s'" % \
                                    (details['database'], details['protocol'], details['set'],
                                     channel))
                else:
                    view = views[channel]

                data_source = view.data_sources[details['output']]

                # Legacy algorithms (and sequential ones on the synchronized
                # channel) consume a plain Input; otherwise the data source
                # is handed to the channel's data loader
                if (algorithm.type == Algorithm.LEGACY) or \
                   ((algorithm.type == Algorithm.SEQUENTIAL) and (details['channel'] == config['channel'])):
                    new_input = Input(name, algorithm.input_map[name], data_source)

                    logger.debug("Input '%s' created: group='%s', dataformat='%s', database-output='%s/%s/%s:%s'" % \
                                    (name, details['channel'], algorithm.input_map[name],
                                    details['database'], details['protocol'], details['set'],
                                    details['output']))
                else:
                    data_loader = _get_data_loader_for(details)
                    data_loader.add(name, data_source)

                    logger.debug("DatabaseOutputDataSource '%s' created: group='%s', dataformat='%s', database-output='%s/%s/%s:%s'" % \
                                    (name, channel, algorithm.input_map[name], details['database'],
                                     details['protocol'], details['set'], details['output']))

            elif db_access == AccessMode.REMOTE:
                if socket is None:
                    raise IOError("No socket provided for remote data sources")

                data_source = RemoteDataSource()
                result = data_source.setup(
                    socket=socket,
                    input_name=name,
                    dataformat_name=algorithm.input_map[name],
                    prefix=prefix,
                    unpack=True
                )

                if not result:
                    raise IOError("cannot setup remote data source '%s'" % name)

                if (algorithm.type == Algorithm.LEGACY) or \
                   ((algorithm.type == Algorithm.SEQUENTIAL) and (details['channel'] == config['channel'])):
                    new_input = Input(name, algorithm.input_map[name], data_source)

                    logger.debug("Input '%s' created: group='%s', dataformat='%s', database-file='%s'" % \
                                    (name, details['channel'], algorithm.input_map[name],
                                    details['path']))
                else:
                    data_loader = _get_data_loader_for(details)
                    data_loader.add(name, data_source)

                    logger.debug("RemoteDataSource '%s' created: group='%s', dataformat='%s', connected to a database" % \
                                    (name, details['channel'], algorithm.input_map[name]))


        elif cache_access == AccessMode.LOCAL:

            if algorithm.type == Algorithm.LEGACY:
                new_input = _create_local_input(details)

            elif algorithm.type == Algorithm.SEQUENTIAL:
                if details['channel'] == config['channel']: # synchronized
                    new_input = _create_local_input(details)
                else:
                    _create_data_source(details)

            else: # Algorithm.AUTONOMOUS or LOOP:
                _create_data_source(details)

        else:
            # Neither database nor local cache access: skip this input
            continue

        # Synchronization bits
        if new_input is not None:
            group = input_list.group(details['channel'])
            if group is None:
                synchronization_listener = None
                if not no_synchronisation_listeners:
                    synchronization_listener = SynchronizationListener()

                # Access is restricted (driven by the platform) only on the
                # main synchronization channel
                group = InputGroup(
                          details['channel'],
                          synchronization_listener=synchronization_listener,
                          restricted_access=(details['channel'] == config['channel'])
                        )
                input_list.add(group)
                logger.debug("Group '%s' created" % details['channel'])

            group.add(new_input)

    return (input_list, data_loader_list)
318
319


320
# ----------------------------------------------------------
321
322


323
324
def create_outputs_from_configuration(config, algorithm, prefix, cache_root,
                                      input_list=None, data_loaders=None):
    """Create the outputs and cache data sinks for an algorithm execution.

    For each entry of ``config['outputs']`` (or the single ``'result'``
    entry when the algorithm is an analyzer), a ``CachedDataSink`` is set
    up on the corresponding cache file and wrapped in an ``Output``.

    Parameters:
        config (dict): Execution configuration; must contain 'channel' and
            'inputs', plus either 'outputs' or 'result'. May contain
            'range' for parallelization.
        algorithm (Algorithm): The algorithm to create the outputs for.
        prefix (str): Path prefix (unused here; kept for signature
            symmetry with ``create_inputs_from_configuration``).
        cache_root (str): Root directory of the cache files.
        input_list (InputList): Optional; used to retrieve the
            synchronization listener and to infer the index range.
        data_loaders (DataLoaderList): Optional; fallback source for the
            end index when no usable input is found in ``input_list``.

    Returns:
        tuple: ``(output_list, data_sinks)``

    Raises:
        IOError: If a cache sink cannot be created.
    """

    data_sinks = []
    output_list = OutputList()

    # This is used for parallelization purposes
    start_index, end_index = config.get('range', (None, None))

    # If the algorithm is an analyser
    if 'result' in config:
        output_config = {
            'result': config['result']
        }
    else:
        output_config = config['outputs']


    for name, details in output_config.items():

        synchronization_listener = None

        # Analyzers write a synthetic 'analysis:<algorithm>' dataformat;
        # regular algorithms use the dataformat declared in their output map
        if 'result' in config:
            dataformat_name = 'analysis:' + algorithm.name
            dataformat = algorithm.result_dataformat()
        else:
            dataformat_name = algorithm.output_map[name]
            dataformat =  algorithm.dataformats[dataformat_name]

        # Outputs are synchronized with the main channel's input group
        if input_list is not None:
            input_group = input_list.group(config['channel'])
            if input_group is not None:
                synchronization_listener = input_group.synchronization_listener

        path = os.path.join(cache_root, details['path'] + '.data')
        dirname = os.path.dirname(path)
        # Make sure that the directory exists while taking care of race
        # conditions. see: http://stackoverflow.com/questions/273192/check-if-a-directory-exists-and-create-it-if-necessary
        try:
            if (len(dirname) > 0):
                os.makedirs(dirname)
        except OSError as exception:
            if exception.errno != errno.EEXIST:
                raise

        # No explicit range was given: infer [start_index, end_index] from
        # the inputs synchronized with the main channel
        if start_index is None:
            input_path = None

            # Prefer a cache-backed (non-database) input on the main channel
            for k, v in config['inputs'].items():
                if v['channel'] != config['channel']:
                    continue

                if 'database' not in v:
                    input_path = os.path.join(cache_root, v['path'] + '.data')
                    break

            if input_path is not None:
                # Derive the overall end index from the cache file names
                # (the index files encode their end index in the filename)
                (data_filenames, indices_filenames, data_checksum_filenames, indices_checksum_filenames) = \
                        getAllFilenames(input_path)

                end_indices = [ int(x.split('.')[-2]) for x in indices_filenames ]
                end_indices.sort()

                start_index = 0
                end_index = end_indices[-1]

            else:
                # Only database inputs on the main channel: ask the live
                # input (or the main data loader) for the last data index
                for k, v in config['inputs'].items():
                    if v['channel'] != config['channel']:
                        continue

                    start_index = 0

                    if (input_list is not None) and (input_list[k] is not None):
                        end_index = input_list[k].data_source.last_data_index()
                        break
                    elif data_loaders is not None:
                        end_index = data_loaders.main_loader.data_index_end
                        break

        data_sink = CachedDataSink()
        data_sinks.append(data_sink)

        status = data_sink.setup(
            filename=path,
            dataformat=dataformat,
            start_index=start_index,
            end_index=end_index,
            encoding='binary'
        )

        if not status:
            raise IOError("Cannot create cache sink '%s'" % details['path'])

        output_list.add(Output(name, data_sink,
                               synchronization_listener=synchronization_listener,
                               force_start_index=start_index)
        )

        # NOTE(review): the analyzer branch logs without a group because the
        # 'result' entry has no 'channel' key — confirm against the platform
        if 'result' not in config:
            logger.debug("Output '%s' created: group='%s', dataformat='%s', filename='%s'" % \
                            (name, details['channel'], dataformat_name, path))
        else:
            logger.debug("Output '%s' created: dataformat='%s', filename='%s'" % \
                            (name, dataformat_name, path))

    return (output_list, data_sinks)