helpers.py 15.7 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
#!/usr/bin/env python
# vim: set fileencoding=utf-8 :

###############################################################################
#                                                                             #
# Copyright (c) 2017 Idiap Research Institute, http://www.idiap.ch/           #
# Contact: beat.support@idiap.ch                                              #
#                                                                             #
# This file is part of the beat.backend.python module of the BEAT platform.   #
#                                                                             #
# Commercial License Usage                                                    #
# Licensees holding valid commercial BEAT licenses may use this file in       #
# accordance with the terms contained in a written agreement between you      #
# and Idiap. For further information contact tto@idiap.ch                     #
#                                                                             #
# Alternatively, this file may be used under the terms of the GNU Affero      #
# Public License version 3 as published by the Free Software and appearing    #
# in the file LICENSE.AGPL included in the packaging of this file.            #
# The BEAT platform is distributed in the hope that it will be useful, but    #
# WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY  #
# or FITNESS FOR A PARTICULAR PURPOSE.                                        #
#                                                                             #
# You should have received a copy of the GNU Affero Public License along      #
# with the BEAT platform. If not, see http://www.gnu.org/licenses/.           #
#                                                                             #
###############################################################################

28
29
30
31
32
33
34
"""
=======
helpers
=======

This module implements various helper methods and classes
"""
35

36
import os
37
import errno
38
39
import logging

40
from .data import CachedDataSource
41
from .data import RemoteDataSource
42
43
44
45
46
47
48
49
50
51
52
from .data import CachedDataSink
from .data import getAllFilenames
from .data_loaders import DataLoaderList
from .data_loaders import DataLoader
from .inputs import InputList
from .inputs import Input
from .inputs import InputGroup
from .outputs import SynchronizationListener
from .outputs import OutputList
from .outputs import Output
from .algorithm import Algorithm
53

54
55
logger = logging.getLogger(__name__)

56

57
# ----------------------------------------------------------
58

59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
def parse_inputs(inputs):
    """Extract the container-relevant fields from an inputs configuration.

    Keeps only ``channel`` and ``path`` for every input; when the input comes
    from a database, the ``database``, ``protocol``, ``set`` and ``output``
    fields are kept as well.  Any other field is dropped.

    Parameters:

      inputs (dict): Mapping of input name to its configuration dictionary.


    Returns:

      dict: Mapping of input name to the reduced configuration dictionary.
    """

    parsed = {}

    for input_name, input_conf in inputs.items():
        entry = {
            'channel': input_conf['channel'],
            'path': input_conf['path'],
        }

        # Database-backed inputs carry extra addressing information
        if 'database' in input_conf:
            entry['database'] = input_conf['database']
            entry['protocol'] = input_conf['protocol']
            entry['set'] = input_conf['set']
            entry['output'] = input_conf['output']

        parsed[input_name] = entry

    return parsed

76
77
78
79
80
81
82
83
def convert_loop_to_container(config):
    """Convert a loop configuration into its container representation.

    Parameters:

      config (dict): Loop configuration; must provide ``algorithm``,
        ``parameters``, ``channel`` and ``inputs``.


    Returns:

      dict: The container configuration, including the uid of the current
      process owner and the reduced inputs (see :func:`parse_inputs`).
    """

    container = {
        'algorithm': config['algorithm'],
        'parameters': config['parameters'],
        'channel': config['channel'],
        'uid': os.getuid(),
    }

    container['inputs'] = parse_inputs(config['inputs'])

    return container


89

90
def convert_experiment_configuration_to_container(config):
    """Convert an experiment execution configuration into the reduced
    dictionary shipped to the execution container.

    Parameters:

      config (dict): Execution configuration; must provide ``algorithm``,
        ``parameters``, ``channel`` and ``inputs``, plus either ``outputs``
        or ``result``.  Optional keys: ``range`` (parallelization) and
        ``loop``.


    Returns:

      dict: The container configuration.
    """

    container = {
        'algorithm': config['algorithm'],
        'parameters': config['parameters'],
        'channel': config['channel'],
        'uid': os.getuid(),
    }

    # Parallelization range, when present
    if 'range' in config:
        container['range'] = config['range']

    container['inputs'] = parse_inputs(config['inputs'])

    if 'outputs' in config:
        # Regular algorithm: keep only channel and path of each output
        container['outputs'] = {
            output_name: {'channel': output_conf['channel'],
                          'path': output_conf['path']}
            for output_name, output_conf in config['outputs'].items()
        }
    else:
        # Analyzer: a single result on the synchronized channel
        container['result'] = {'channel': config['channel'],
                               'path': config['result']['path']}

    if 'loop' in config:
        container['loop'] = convert_loop_to_container(config['loop'])

    return container
113
114


115
# ----------------------------------------------------------
116
117


118
class AccessMode:
    """Possible access modes for data (cache or database)."""

    # Data is not accessed at all
    NONE = 0

    # Data is read locally (cache files or local database views)
    LOCAL = 1

    # Data is fetched through a socket from a remote process
    REMOTE = 2
123
124
125


def create_inputs_from_configuration(config, algorithm, prefix, cache_root,
                                     cache_access=AccessMode.NONE,
                                     db_access=AccessMode.NONE,
                                     unpack=True, socket=None,
                                     databases=None,
                                     no_synchronisation_listeners=False):
    """Build the input list and data loader list for an algorithm execution.

    Parameters:

      config (dict): Execution configuration; must provide ``inputs`` and
        ``channel``, and may provide ``range`` for parallelization.

      algorithm (Algorithm): The algorithm to feed; its ``type`` and
        ``input_map`` decide whether each input becomes a plain ``Input`` or
        is registered on a ``DataLoader``.

      prefix (str): Prefix folder passed to the data sources' ``setup()``.

      cache_root (str): Root folder where cached data files are read from.

      cache_access (int): ``AccessMode`` value for cached (non-database)
        inputs; only ``AccessMode.LOCAL`` is acted upon here.

      db_access (int): ``AccessMode`` value for database inputs
        (``LOCAL`` or ``REMOTE``).

      unpack (bool): Kept for backward-compatibility; the data sources are
        currently always set up with ``unpack=True`` regardless of this
        value.

      socket: Socket used by ``RemoteDataSource``; required when
        ``db_access`` is ``AccessMode.REMOTE``.

      databases (dict): Mapping of database name to database object;
        required when ``db_access`` is ``AccessMode.LOCAL``.

      no_synchronisation_listeners (bool): If True, input groups are created
        without a ``SynchronizationListener``.


    Returns:

      tuple: ``(input_list, data_loader_list)``


    Raises:

      IOError: If a cache file cannot be loaded, a required database or
        socket is missing, or a remote data source cannot be set up.
    """

    views = {}
    input_list = InputList()
    data_loader_list = DataLoaderList()

    # This is used for parallelization purposes
    start_index, end_index = config.get('range', (None, None))


    def _create_local_input(details):
        # Creates an Input reading straight from a cached data file; the
        # parallelization range is only applied on the synchronized channel
        data_source = CachedDataSource()

        filename = os.path.join(cache_root, details['path'] + '.data')

        if details['channel'] == config['channel']: # synchronized
            status = data_source.setup(
                      filename=filename,
                      prefix=prefix,
                      start_index=start_index,
                      end_index=end_index,
                      unpack=True,
                     )
        else:
            status = data_source.setup(
                      filename=filename,
                      prefix=prefix,
                      unpack=True,
                     )

        if not status:
            raise IOError("cannot load cache file `%s'" % details['path'])

        input = Input(name, algorithm.input_map[name], data_source)

        logger.debug("Input '%s' created: group='%s', dataformat='%s', filename='%s'" % \
                        (name, details['channel'], algorithm.input_map[name], filename))

        return input


    def _get_data_loader_for(details):
        # Returns the data loader of the input's channel, creating it on
        # first use
        data_loader = data_loader_list[details['channel']]
        if data_loader is None:
            data_loader = DataLoader(details['channel'])
            data_loader_list.add(data_loader)

            logger.debug("Data loader created: group='%s'" % details['channel'])

        return data_loader


    def _create_data_source(details):
        # Registers a cached data source for this input on its channel's
        # data loader
        data_loader = _get_data_loader_for(details)

        filename = os.path.join(cache_root, details['path'] + '.data')

        data_source = CachedDataSource()
        result = data_source.setup(
            filename=filename,
            prefix=prefix,
            start_index=start_index,
            end_index=end_index,
            unpack=True,
        )

        if not result:
            raise IOError("cannot load cache file `%s'" % filename)

        data_loader.add(name, data_source)

        logger.debug("Input '%s' added to data loader: group='%s', dataformat='%s', filename='%s'" % \
                        (name, details['channel'], algorithm.input_map[name], filename))


    for name, details in config['inputs'].items():

        input = None

        if details.get('database', None) is not None:

            if db_access == AccessMode.LOCAL:
                if databases is None:
                    raise IOError("No databases provided")

                # Retrieve the database
                try:
                    db = databases[details['database']]
                except KeyError:
                    # NOTE(review): was a bare `except:`; only a missing key
                    # means "database not found"
                    raise IOError("Database '%s' not found" % details['database'])

                # Create or retrieve the database view (one view per channel)
                channel = details['channel']

                if channel not in views:
                    view = db.view(details['protocol'], details['set'])
                    view.setup(os.path.join(cache_root, details['path']), pack=False,
                               start_index=start_index, end_index=end_index)

                    views[channel] = view

                    logger.debug("Database view '%s/%s/%s' created: group='%s'" % \
                                    (details['database'], details['protocol'], details['set'],
                                     channel))
                else:
                    view = views[channel]

                data_source = view.data_sources[details['output']]

                # Legacy algorithms (and sequential ones on the synchronized
                # channel) consume a plain Input; everything else goes
                # through a data loader
                if (algorithm.type == Algorithm.LEGACY) or \
                   ((algorithm.type == Algorithm.SEQUENTIAL) and (details['channel'] == config['channel'])):
                    input = Input(name, algorithm.input_map[name], data_source)

                    logger.debug("Input '%s' created: group='%s', dataformat='%s', database-output='%s/%s/%s:%s'" % \
                                    (name, details['channel'], algorithm.input_map[name],
                                    details['database'], details['protocol'], details['set'],
                                    details['output']))
                else:
                    data_loader = _get_data_loader_for(details)
                    data_loader.add(name, data_source)

                    logger.debug("DatabaseOutputDataSource '%s' created: group='%s', dataformat='%s', database-output='%s/%s/%s:%s'" % \
                                    (name, channel, algorithm.input_map[name], details['database'],
                                     details['protocol'], details['set'], details['output']))

            elif db_access == AccessMode.REMOTE:
                if socket is None:
                    raise IOError("No socket provided for remote data sources")

                data_source = RemoteDataSource()
                result = data_source.setup(
                    socket=socket,
                    input_name=name,
                    dataformat_name=algorithm.input_map[name],
                    prefix=prefix,
                    unpack=True
                )

                if not result:
                    raise IOError("cannot setup remote data source '%s'" % name)

                if (algorithm.type == Algorithm.LEGACY) or \
                   ((algorithm.type == Algorithm.SEQUENTIAL) and (details['channel'] == config['channel'])):
                    input = Input(name, algorithm.input_map[name], data_source)

                    logger.debug("Input '%s' created: group='%s', dataformat='%s', database-file='%s'" % \
                                    (name, details['channel'], algorithm.input_map[name],
                                    details['path']))
                else:
                    data_loader = _get_data_loader_for(details)
                    data_loader.add(name, data_source)

                    logger.debug("RemoteDataSource '%s' created: group='%s', dataformat='%s', connected to a database" % \
                                    (name, details['channel'], algorithm.input_map[name]))


        elif cache_access == AccessMode.LOCAL:

            if algorithm.type == Algorithm.LEGACY:
                input = _create_local_input(details)

            elif algorithm.type == Algorithm.SEQUENTIAL:
                if details['channel'] == config['channel']: # synchronized
                    input = _create_local_input(details)
                else:
                    _create_data_source(details)

            else: # Algorithm.AUTONOMOUS or LOOP:
                _create_data_source(details)

        else:
            # Neither database nor local cache access: nothing to wire
            continue

        # Synchronization bits: plain Inputs are grouped by channel, with a
        # shared listener per group (unless explicitly disabled)
        if input is not None:
            group = input_list.group(details['channel'])
            if group is None:
                synchronization_listener = None
                if not no_synchronisation_listeners:
                    synchronization_listener = SynchronizationListener()

                group = InputGroup(
                          details['channel'],
                          synchronization_listener=synchronization_listener,
                          restricted_access=(details['channel'] == config['channel'])
                        )
                input_list.add(group)
                logger.debug("Group '%s' created" % details['channel'])

            group.add(input)

    return (input_list, data_loader_list)
323
324


325
# ----------------------------------------------------------
326
327


328
329
def create_outputs_from_configuration(config, algorithm, prefix, cache_root,
                                      input_list=None, data_loaders=None):
    """Build the output list and cache sinks for an algorithm execution.

    Parameters:

      config (dict): Execution configuration; must provide ``channel`` and
        ``inputs``, plus either ``outputs`` (regular algorithm) or ``result``
        (analyzer).  May provide ``range`` for parallelization.

      algorithm (Algorithm): The producing algorithm; provides the output
        dataformats (``output_map``/``dataformats`` or
        ``result_dataformat()`` for analyzers).

      prefix (str): Prefix folder (unused directly here, kept for symmetry
        with ``create_inputs_from_configuration``).

      cache_root (str): Root folder where the ``.data`` files are written.

      input_list (InputList): Optional; used to pick up the synchronization
        listener of the synchronized channel and to infer the data range.

      data_loaders (DataLoaderList): Optional; fallback source for the end
        index when no suitable input is available.


    Returns:

      tuple: ``(output_list, data_sinks)``


    Raises:

      IOError: If a cache sink cannot be set up.
      OSError: If the output directory cannot be created (other than it
        already existing).
    """

    data_sinks = []
    output_list = OutputList()

    # This is used for parallelization purposes
    start_index, end_index = config.get('range', (None, None))

    # If the algorithm is an analyser, there is a single 'result' output on
    # the synchronized channel; otherwise use the declared outputs
    if 'result' in config:
        output_config = {
            'result': config['result']
        }
    else:
        output_config = config['outputs']


    for name, details in output_config.items():

        synchronization_listener = None

        # Resolve the dataformat of this output
        if 'result' in config:
            dataformat_name = 'analysis:' + algorithm.name
            dataformat = algorithm.result_dataformat()
        else:
            dataformat_name = algorithm.output_map[name]
            dataformat =  algorithm.dataformats[dataformat_name]

        # Outputs on the synchronized channel share the listener of the
        # corresponding input group, when one exists
        if input_list is not None:
            input_group = input_list.group(config['channel'])
            if input_group is not None:
                synchronization_listener = input_group.synchronization_listener

        path = os.path.join(cache_root, details['path'] + '.data')
        dirname = os.path.dirname(path)
        # Make sure that the directory exists while taking care of race
        # conditions. see: http://stackoverflow.com/questions/273192/check-if-a-directory-exists-and-create-it-if-necessary
        try:
            if (len(dirname) > 0):
                os.makedirs(dirname)
        except OSError as exception:
            if exception.errno != errno.EEXIST:
                raise

        # No explicit range given: infer it from the synchronized inputs
        if start_index is None:
            input_path = None

            # Prefer a cached (non-database) input on the synchronized
            # channel, whose index files tell us the full data range
            for k, v in config['inputs'].items():
                if v['channel'] != config['channel']:
                    continue

                if 'database' not in v:
                    input_path = os.path.join(cache_root, v['path'] + '.data')
                    break

            if input_path is not None:
                # The last end index found in the cache index filenames is
                # the end of the data
                (data_filenames, indices_filenames, data_checksum_filenames, indices_checksum_filenames) = \
                        getAllFilenames(input_path)

                end_indices = [ int(x.split('.')[-2]) for x in indices_filenames ]
                end_indices.sort()

                start_index = 0
                end_index = end_indices[-1]

            else:
                # Fall back to a live input or the main data loader for the
                # end index.  NOTE(review): if no synchronized input matches,
                # end_index may remain None — presumably handled by
                # data_sink.setup(); confirm against CachedDataSink
                for k, v in config['inputs'].items():
                    if v['channel'] != config['channel']:
                        continue

                    start_index = 0

                    if (input_list is not None) and (input_list[k] is not None):
                        end_index = input_list[k].data_source.last_data_index()
                        break
                    elif data_loaders is not None:
                        end_index = data_loaders.main_loader.data_index_end
                        break

        data_sink = CachedDataSink()
        data_sinks.append(data_sink)

        status = data_sink.setup(
            filename=path,
            dataformat=dataformat,
            start_index=start_index,
            end_index=end_index,
            encoding='binary'
        )

        if not status:
            raise IOError("Cannot create cache sink '%s'" % details['path'])

        output_list.add(Output(name, data_sink,
                               synchronization_listener=synchronization_listener,
                               force_start_index=start_index)
        )

        if 'result' not in config:
            logger.debug("Output '%s' created: group='%s', dataformat='%s', filename='%s'" % \
                            (name, details['channel'], dataformat_name, path))
        else:
            logger.debug("Output '%s' created: dataformat='%s', filename='%s'" % \
                            (name, dataformat_name, path))

    return (output_list, data_sinks)