helpers.py 14.9 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
#!/usr/bin/env python
# vim: set fileencoding=utf-8 :

###############################################################################
#                                                                             #
# Copyright (c) 2017 Idiap Research Institute, http://www.idiap.ch/           #
# Contact: beat.support@idiap.ch                                              #
#                                                                             #
# This file is part of the beat.backend.python module of the BEAT platform.   #
#                                                                             #
# Commercial License Usage                                                    #
# Licensees holding valid commercial BEAT licenses may use this file in       #
# accordance with the terms contained in a written agreement between you      #
# and Idiap. For further information contact tto@idiap.ch                     #
#                                                                             #
# Alternatively, this file may be used under the terms of the GNU Affero      #
# Public License version 3 as published by the Free Software and appearing    #
# in the file LICENSE.AGPL included in the packaging of this file.            #
# The BEAT platform is distributed in the hope that it will be useful, but    #
# WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY  #
# or FITNESS FOR A PARTICULAR PURPOSE.                                        #
#                                                                             #
# You should have received a copy of the GNU Affero Public License along      #
# with the BEAT platform. If not, see http://www.gnu.org/licenses/.           #
#                                                                             #
###############################################################################

28
29
30
31
32
33
34
"""
=======
helpers
=======

This module implements various helper methods and classes
"""
35

36
import os
37
import errno
38
39
import logging

40
from .data import CachedDataSource
41
from .data import RemoteDataSource
42
43
44
45
46
47
48
49
50
51
52
from .data import CachedDataSink
from .data import getAllFilenames
from .data_loaders import DataLoaderList
from .data_loaders import DataLoader
from .inputs import InputList
from .inputs import Input
from .inputs import InputGroup
from .outputs import SynchronizationListener
from .outputs import OutputList
from .outputs import Output
from .algorithm import Algorithm
53

54
55
logger = logging.getLogger(__name__)

56

57
# ----------------------------------------------------------
58

59

60
def convert_experiment_configuration_to_container(config):
    """Convert a BEAT experiment configuration to its containerized form.

    Builds a reduced configuration dictionary suitable for execution inside a
    container: it keeps the algorithm, parameters and synchronization channel,
    records the current user id, and flattens the input/output descriptions to
    only the fields required at runtime (``channel``, ``path`` and, for inputs,
    whether they come from a database).

    Parameters:
      config (dict): original experiment configuration; must contain the keys
        ``algorithm``, ``parameters``, ``channel`` and ``inputs``, plus either
        ``outputs`` (regular algorithm) or ``result`` (analyzer), and may
        contain a ``range`` (used for parallelization).

    Returns:
      dict: the container-ready configuration
    """

    data = {
        'algorithm': config['algorithm'],
        'parameters': config['parameters'],
        'channel': config['channel'],
        'uid': os.getuid(),
    }

    # 'range' is only present when this execution is one slice of a
    # parallelized job
    if 'range' in config:
        data['range'] = config['range']

    # For each input only keep the channel, the cache path and whether it is
    # backed by a database (the presence of the 'database' key)
    data['inputs'] = {
        k: {
            'channel': v['channel'],
            'path': v['path'],
            'database': 'database' in v,
        }
        for k, v in config['inputs'].items()
    }

    if 'outputs' in config:
        data['outputs'] = {
            k: {'channel': v['channel'], 'path': v['path']}
            for k, v in config['outputs'].items()
        }
    else:
        # Analyzers have a single 'result' output, synchronized with the
        # main channel
        data['result'] = {'channel': config['channel'], 'path': config['result']['path']}

    return data
81
82


83
# ----------------------------------------------------------
84
85


86
class AccessMode:
    """Enumeration of the possible data access modes."""

    # No access at all
    NONE = 0
    # Data is read/written directly from the local filesystem
    LOCAL = 1
    # Data is accessed through a remote connection (socket)
    REMOTE = 2
91
92
93


def create_inputs_from_configuration(config, algorithm, prefix, cache_root,
                                     cache_access=AccessMode.NONE,
                                     db_access=AccessMode.NONE,
                                     unpack=True, socket=None,
                                     databases=None,
                                     no_synchronisation_listeners=False):
    """Create the input list and data loaders for an algorithm execution.

    Depending on the algorithm type and on where the data lives (cache files,
    local database views or a remote process reachable through ``socket``),
    each entry of ``config['inputs']`` becomes either an :class:`Input`
    (grouped per synchronization channel in an :class:`InputList`) or a data
    source registered on a per-channel :class:`DataLoader`.

    Parameters:
      config (dict): execution configuration; uses ``inputs``, ``channel``
        and, for parallelized executions, ``range``
      algorithm (Algorithm): algorithm the inputs are created for (its
        ``type`` and ``input_map`` drive the decisions below)
      prefix (str): path to the BEAT prefix folder
      cache_root (str): root folder of the data cache
      cache_access (int): one of the ``AccessMode`` constants, for cached data
      db_access (int): one of the ``AccessMode`` constants, for database data
      unpack (bool): NOTE(review): currently ignored — every setup call below
        hard-codes ``unpack=True``; kept for backward compatibility
      socket: socket used to reach remote data sources (required when
        ``db_access == AccessMode.REMOTE``)
      databases (dict): mapping of database name to database object (required
        when ``db_access == AccessMode.LOCAL``)
      no_synchronisation_listeners (bool): if True, don't attach a
        :class:`SynchronizationListener` to the created input groups

    Returns:
      tuple: ``(InputList, DataLoaderList)``

    Raises:
      IOError: when a cache file can't be loaded, a database is missing, or a
        remote data source can't be set up
    """

    views = {}
    input_list = InputList()
    data_loader_list = DataLoaderList()

    # This is used for parallelization purposes
    start_index, end_index = config.get('range', (None, None))


    def _create_local_input(details):
        # Build an Input reading directly from a cache file; closes over the
        # loop variable 'name' below
        data_source = CachedDataSource()

        filename = os.path.join(cache_root, details['path'] + '.data')

        if details['channel'] == config['channel']: # synchronized
            status = data_source.setup(
                      filename=filename,
                      prefix=prefix,
                      start_index=start_index,
                      end_index=end_index,
                      unpack=True,
                     )
        else:
            # Non-synchronized channels are read in full, ignoring the range
            status = data_source.setup(
                      filename=filename,
                      prefix=prefix,
                      unpack=True,
                     )

        if not status:
            raise IOError("cannot load cache file `%s'" % details['path'])

        new_input = Input(name, algorithm.input_map[name], data_source)

        logger.debug("Input '%s' created: group='%s', dataformat='%s', filename='%s'" % \
                        (name, details['channel'], algorithm.input_map[name], filename))

        return new_input


    def _get_data_loader_for(details):
        # Return the DataLoader of this channel, creating it on first use
        data_loader = data_loader_list[details['channel']]
        if data_loader is None:
            data_loader = DataLoader(details['channel'])
            data_loader_list.add(data_loader)

            logger.debug("Data loader created: group='%s'" % details['channel'])

        return data_loader


    def _create_data_source(details):
        # Register a cached data source on the channel's data loader; closes
        # over the loop variable 'name' below
        data_loader = _get_data_loader_for(details)

        filename = os.path.join(cache_root, details['path'] + '.data')

        data_source = CachedDataSource()
        result = data_source.setup(
            filename=filename,
            prefix=prefix,
            start_index=start_index,
            end_index=end_index,
            unpack=True,
        )

        if not result:
            raise IOError("cannot load cache file `%s'" % details['path'])

        data_loader.add(name, data_source)

        logger.debug("Input '%s' added to data loader: group='%s', dataformat='%s', filename='%s'" % \
                        (name, details['channel'], algorithm.input_map[name], filename))


    for name, details in config['inputs'].items():

        new_input = None

        if details.get('database', False):
            if db_access == AccessMode.LOCAL:
                if databases is None:
                    raise IOError("No databases provided")

                # Retrieve the database
                try:
                    db = databases[details['database']]
                except KeyError:
                    raise IOError("Database '%s' not found" % details['database'])

                # Create or retrieve the database view (one per channel)
                channel = details['channel']

                if channel not in views:
                    view = db.view(details['protocol'], details['set'])
                    view.setup(os.path.join(cache_root, details['path']), pack=False,
                               start_index=start_index, end_index=end_index)

                    views[channel] = view

                    logger.debug("Database view '%s/%s/%s' created: group='%s'" % \
                                    (details['database'], details['protocol'], details['set'],
                                     channel))
                else:
                    view = views[channel]

                data_source = view.data_sources[details['output']]

                # Legacy algorithms (and the synchronized channel of
                # sequential ones) consume Inputs; everything else goes
                # through a data loader
                if (algorithm.type == Algorithm.LEGACY) or \
                   ((algorithm.type == Algorithm.SEQUENTIAL) and (details['channel'] == config['channel'])):
                    new_input = Input(name, algorithm.input_map[name], data_source)

                    logger.debug("Input '%s' created: group='%s', dataformat='%s', database-output='%s/%s/%s:%s'" % \
                                    (name, details['channel'], algorithm.input_map[name],
                                    details['database'], details['protocol'], details['set'],
                                    details['output']))
                else:
                    data_loader = _get_data_loader_for(details)
                    data_loader.add(name, data_source)

                    logger.debug("DatabaseOutputDataSource '%s' created: group='%s', dataformat='%s', database-output='%s/%s/%s:%s'" % \
                                    (name, channel, algorithm.input_map[name], details['database'],
                                     details['protocol'], details['set'], details['output']))

            elif db_access == AccessMode.REMOTE:
                if socket is None:
                    raise IOError("No socket provided for remote data sources")

                data_source = RemoteDataSource()
                result = data_source.setup(
                    socket=socket,
                    input_name=name,
                    dataformat_name=algorithm.input_map[name],
                    prefix=prefix,
                    unpack=True
                )

                if not result:
                    raise IOError("cannot setup remote data source '%s'" % name)

                if (algorithm.type == Algorithm.LEGACY) or \
                   ((algorithm.type == Algorithm.SEQUENTIAL) and (details['channel'] == config['channel'])):
                    new_input = Input(name, algorithm.input_map[name], data_source)

                    logger.debug("Input '%s' created: group='%s', dataformat='%s', database-file='%s'" % \
                                    (name, details['channel'], algorithm.input_map[name],
                                    details['path']))
                else:
                    data_loader = _get_data_loader_for(details)
                    data_loader.add(name, data_source)

                    logger.debug("RemoteDataSource '%s' created: group='%s', dataformat='%s', connected to a database" % \
                                    (name, details['channel'], algorithm.input_map[name]))


        elif cache_access == AccessMode.LOCAL:

            if algorithm.type == Algorithm.LEGACY:
                new_input = _create_local_input(details)

            elif algorithm.type == Algorithm.SEQUENTIAL:
                if details['channel'] == config['channel']: # synchronized
                    new_input = _create_local_input(details)
                else:
                    _create_data_source(details)

            else: # Algorithm.AUTONOMOUS or LOOP:
                _create_data_source(details)

        else:
            # Neither a database input nor locally-accessible cache: skip
            continue

        # Synchronization bits: group inputs per channel, creating the group
        # (and its optional synchronization listener) on first use
        if new_input is not None:
            group = input_list.group(details['channel'])
            if group is None:
                synchronization_listener = None
                if not no_synchronisation_listeners:
                    synchronization_listener = SynchronizationListener()

                group = InputGroup(
                          details['channel'],
                          synchronization_listener=synchronization_listener,
                          restricted_access=(details['channel'] == config['channel'])
                        )
                input_list.add(group)
                logger.debug("Group '%s' created" % details['channel'])

            group.add(new_input)

    return (input_list, data_loader_list)
291
292


293
# ----------------------------------------------------------
294
295


296
297
def create_outputs_from_configuration(config, algorithm, prefix, cache_root,
                                      input_list=None, data_loaders=None):
    """Create the output list and cache sinks for an algorithm execution.

    For each entry of ``config['outputs']`` (or the single ``result`` entry
    when the algorithm is an analyzer) a :class:`CachedDataSink` is set up
    under ``cache_root`` and wrapped in an :class:`Output`.

    Parameters:
      config (dict): execution configuration; uses ``outputs`` or ``result``,
        ``channel``, ``inputs`` and, for parallelized executions, ``range``
      algorithm (Algorithm): algorithm the outputs are created for
      prefix (str): path to the BEAT prefix folder (unused here; kept for a
        signature symmetric with create_inputs_from_configuration)
      cache_root (str): root folder of the data cache
      input_list (InputList): inputs of the same execution, used to pick up
        the synchronization listener and, if needed, the data range
      data_loaders (DataLoaderList): data loaders of the same execution, used
        as a fallback to determine the data range

    Returns:
      tuple: ``(OutputList, list_of_data_sinks)``

    Raises:
      IOError: when a cache sink cannot be created
    """

    data_sinks = []
    output_list = OutputList()

    # This is used for parallelization purposes
    start_index, end_index = config.get('range', (None, None))

    # If the algorithm is an analyser
    if 'result' in config:
        output_config = {
            'result': config['result']
        }
    else:
        output_config = config['outputs']


    for name, details in output_config.items():

        synchronization_listener = None

        # Determine the data format of this output: analyzers use their own
        # result format, regular algorithms use the declared output mapping
        if 'result' in config:
            dataformat_name = 'analysis:' + algorithm.name
            dataformat = algorithm.result_dataformat()
        else:
            dataformat_name = algorithm.output_map[name]
            dataformat =  algorithm.dataformats[dataformat_name]

        # Reuse the synchronization listener of the main channel's input
        # group, if available
        if input_list is not None:
            input_group = input_list.group(config['channel'])
            if input_group is not None:
                synchronization_listener = input_group.synchronization_listener

        path = os.path.join(cache_root, details['path'] + '.data')
        dirname = os.path.dirname(path)
        # Make sure that the directory exists while taking care of race
        # conditions. see: http://stackoverflow.com/questions/273192/check-if-a-directory-exists-and-create-it-if-necessary
        try:
            if (len(dirname) > 0):
                os.makedirs(dirname)
        except OSError as exception:
            if exception.errno != errno.EEXIST:
                raise

        # No explicit range given: derive the index range from the inputs of
        # the same channel. NOTE: once computed for the first output,
        # start_index/end_index are reused for the remaining outputs.
        if start_index is None:
            input_path = None

            # Prefer a cached (non-database) input on the same channel
            for k, v in config['inputs'].items():
                if v['channel'] != config['channel']:
                    continue

                if 'database' not in v:
                    input_path = os.path.join(cache_root, v['path'] + '.data')
                    break

            if input_path is not None:
                # Derive the end index from the highest index encoded in the
                # cache's index filenames (format: <path>.<start>.<end>.*)
                (data_filenames, indices_filenames, data_checksum_filenames, indices_checksum_filenames) = \
                        getAllFilenames(input_path)

                end_indices = [ int(x.split('.')[-2]) for x in indices_filenames ]
                end_indices.sort()

                start_index = 0
                end_index = end_indices[-1]

            else:
                # All same-channel inputs are database-backed: ask the live
                # input (or the main data loader) for the last data index
                for k, v in config['inputs'].items():
                    if v['channel'] != config['channel']:
                        continue

                    start_index = 0

                    if (input_list is not None) and (input_list[k] is not None):
                        end_index = input_list[k].data_source.last_data_index()
                        break
                    elif data_loaders is not None:
                        end_index = data_loaders.main_loader.data_index_end
                        break

        data_sink = CachedDataSink()
        data_sinks.append(data_sink)

        status = data_sink.setup(
            filename=path,
            dataformat=dataformat,
            start_index=start_index,
            end_index=end_index,
            encoding='binary'
        )

        if not status:
            raise IOError("Cannot create cache sink '%s'" % details['path'])

        output_list.add(Output(name, data_sink,
                               synchronization_listener=synchronization_listener,
                               force_start_index=start_index)
        )

        if 'result' not in config:
            logger.debug("Output '%s' created: group='%s', dataformat='%s', filename='%s'" % \
                            (name, details['channel'], dataformat_name, path))
        else:
            # Analyzer results carry no channel information
            logger.debug("Output '%s' created: dataformat='%s', filename='%s'" % \
                            (name, dataformat_name, path))

    return (output_list, data_sinks)