helpers.py 15.1 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
#!/usr/bin/env python
# vim: set fileencoding=utf-8 :

###############################################################################
#                                                                             #
# Copyright (c) 2017 Idiap Research Institute, http://www.idiap.ch/           #
# Contact: beat.support@idiap.ch                                              #
#                                                                             #
# This file is part of the beat.backend.python module of the BEAT platform.   #
#                                                                             #
# Commercial License Usage                                                    #
# Licensees holding valid commercial BEAT licenses may use this file in       #
# accordance with the terms contained in a written agreement between you      #
# and Idiap. For further information contact tto@idiap.ch                     #
#                                                                             #
# Alternatively, this file may be used under the terms of the GNU Affero      #
# Public License version 3 as published by the Free Software and appearing    #
# in the file LICENSE.AGPL included in the packaging of this file.            #
# The BEAT platform is distributed in the hope that it will be useful, but    #
# WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY  #
# or FITNESS FOR A PARTICULAR PURPOSE.                                        #
#                                                                             #
# You should have received a copy of the GNU Affero Public License along      #
# with the BEAT platform. If not, see http://www.gnu.org/licenses/.           #
#                                                                             #
###############################################################################

28
29
30
31
32
33
34
"""
=======
helpers
=======

This module implements various helper methods and classes
"""
35

36
import os
37
import errno
38
39
import logging

40
from .data import CachedDataSource
41
from .data import RemoteDataSource
42
43
44
45
46
47
48
49
50
51
52
from .data import CachedDataSink
from .data import getAllFilenames
from .data_loaders import DataLoaderList
from .data_loaders import DataLoader
from .inputs import InputList
from .inputs import Input
from .inputs import InputGroup
from .outputs import SynchronizationListener
from .outputs import OutputList
from .outputs import Output
from .algorithm import Algorithm
53

54
55
logger = logging.getLogger(__name__)

56

57
# ----------------------------------------------------------
58

59

60
def convert_experiment_configuration_to_container(config):
    """Reduce a full experiment configuration to the form shipped to a container.

    Only the settings needed inside the container are kept: the algorithm
    name, its parameters, the synchronization channel, the calling user's
    uid and the cache-relative paths of the inputs and outputs (or of the
    single result, for an analyzer).

    Parameters:

      config (dict): The original experiment configuration

    Returns:

      dict: The container-ready configuration
    """

    data = {
        'algorithm': config['algorithm'],
        'parameters': config['parameters'],
        'channel': config['channel'],
        'uid': os.getuid(),
    }

    # Optional (start_index, end_index) range, used for parallelization
    if 'range' in config:
        data['range'] = config['range']

    # Keep only the channel and path of each input, plus a boolean flag
    # telling whether the input comes from a database (instead of the cache)
    data['inputs'] = {
        k: {
            'channel': v['channel'],
            'path': v['path'],
            'database': 'database' in v,
        }
        for k, v in config['inputs'].items()
    }

    if 'outputs' in config:
        data['outputs'] = {
            k: {'channel': v['channel'], 'path': v['path']}
            for k, v in config['outputs'].items()
        }
    else:
        # Analyzers have no 'outputs' entry: they produce a single result
        # synchronized on the main channel
        data['result'] = {'channel': config['channel'], 'path': config['result']['path']}

    if 'loop' in config:
        data['loop'] = {
            k: {'channel': v['channel'], 'path': v['path']}
            for k, v in config['loop'].items()
        }

    return data
85
86


87
# ----------------------------------------------------------
88
89


90
class AccessMode:
    """Enumeration of the possible data access modes"""

    # No access to the data at all
    NONE = 0
    # Data is read/written directly on the local machine
    LOCAL = 1
    # Data is accessed through a remote connection
    REMOTE = 2
95
96
97


def create_inputs_from_configuration(config, algorithm, prefix, cache_root,
                                     cache_access=AccessMode.NONE,
                                     db_access=AccessMode.NONE,
                                     unpack=True, socket=None,
                                     databases=None,
                                     no_synchronisation_listeners=False):
    """Build the input list and data loaders of an algorithm from its configuration.

    For each entry of ``config['inputs']`` a data source is created, either
    from a database view (local or remote) or from a cache file, depending on
    ``db_access``/``cache_access`` and on the algorithm type:

    - LEGACY algorithms (and SEQUENTIAL ones, for inputs synchronized with the
      main channel) get classic ``Input`` objects grouped in an ``InputList``;
    - all other inputs are attached to per-channel ``DataLoader`` objects
      collected in a ``DataLoaderList``.

    Parameters:

      config (dict): Container configuration (as produced by
        ``convert_experiment_configuration_to_container``)

      algorithm (Algorithm): The algorithm the inputs are created for

      prefix (str): Path of the prefix holding the declarations

      cache_root (str): Root path of the cache

      cache_access (int): One of the ``AccessMode`` values, for cached inputs

      db_access (int): One of the ``AccessMode`` values, for database inputs

      unpack (bool): NOTE(review): currently unused; the nested setups
        hard-code ``unpack=True`` — confirm before relying on it

      socket: ZeroMQ-like socket, required when ``db_access`` is REMOTE

      databases (dict): Maps database names to database objects, required
        when ``db_access`` is LOCAL

      no_synchronisation_listeners (bool): If True, input groups are created
        without a ``SynchronizationListener``

    Returns:

      tuple: ``(input_list, data_loader_list)``

    Raises:

      IOError: If a cache file cannot be loaded, a database is missing, or
        a remote data source cannot be set up
    """

    data_sources = []  # NOTE(review): populated nowhere in this function; kept for interface stability
    views = {}  # channel name -> database view (so a view is built only once per channel)
    input_list = InputList()
    data_loader_list = DataLoaderList()

    # This is used for parallelization purposes
    start_index, end_index = config.get('range', (None, None))


    def _create_local_input(details):
        # Build an Input backed by a cache file; the parallelization range is
        # only applied to inputs synchronized with the main channel
        data_source = CachedDataSource()

        filename = os.path.join(cache_root, details['path'] + '.data')

        if details['channel'] == config['channel']: # synchronized
            status = data_source.setup(
                      filename=filename,
                      prefix=prefix,
                      start_index=start_index,
                      end_index=end_index,
                      unpack=True,
                     )
        else:
            status = data_source.setup(
                      filename=filename,
                      prefix=prefix,
                      unpack=True,
                     )

        if not status:
            raise IOError("cannot load cache file `%s'" % details['path'])

        input = Input(name, algorithm.input_map[name], data_source)

        logger.debug("Input '%s' created: group='%s', dataformat='%s', filename='%s'" % \
                        (name, details['channel'], algorithm.input_map[name], filename))

        return input


    def _get_data_loader_for(details):
        # Return the data loader of the channel, creating it on first use
        data_loader = data_loader_list[details['channel']]
        if data_loader is None:
            data_loader = DataLoader(details['channel'])
            data_loader_list.add(data_loader)

            logger.debug("Data loader created: group='%s'" % details['channel'])

        return data_loader


    def _create_data_source(details):
        # Attach a cache-backed data source to the channel's data loader
        # (used for non-synchronized inputs of SEQUENTIAL algorithms and for
        # all inputs of AUTONOMOUS/LOOP algorithms)
        data_loader = _get_data_loader_for(details)

        filename = os.path.join(cache_root, details['path'] + '.data')

        data_source = CachedDataSource()
        result = data_source.setup(
            filename=filename,
            prefix=prefix,
            start_index=start_index,
            end_index=end_index,
            unpack=True,
        )

        if not result:
            raise IOError("cannot load cache file `%s'" % details['path'])

        data_loader.add(name, data_source)

        logger.debug("Input '%s' added to data loader: group='%s', dataformat='%s', filename='%s'" % \
                        (name, details['channel'], algorithm.input_map[name], filename))


    for name, details in config['inputs'].items():

        input = None

        if details.get('database', False):
            if db_access == AccessMode.LOCAL:
                if databases is None:
                    raise IOError("No databases provided")

                # Retrieve the database
                try:
                    db = databases[details['database']]
                except:
                    # NOTE(review): bare except also hides unrelated errors;
                    # KeyError is presumably the intended case
                    raise IOError("Database '%s' not found" % details['database'])

                # Create of retrieve the database view
                channel = details['channel']

                if channel not in views:
                    view = db.view(details['protocol'], details['set'])
                    view.setup(os.path.join(cache_root, details['path']), pack=False,
                               start_index=start_index, end_index=end_index)

                    views[channel] = view

                    logger.debug("Database view '%s/%s/%s' created: group='%s'" % \
                                    (details['database'], details['protocol'], details['set'],
                                     channel))
                else:
                    view = views[channel]

                data_source = view.data_sources[details['output']]

                # LEGACY algorithms (and SEQUENTIAL ones on the synchronized
                # channel) consume classic Input objects; other types go
                # through a DataLoader
                if (algorithm.type == Algorithm.LEGACY) or \
                   ((algorithm.type == Algorithm.SEQUENTIAL) and (details['channel'] == config['channel'])):
                    input = Input(name, algorithm.input_map[name], data_source)

                    logger.debug("Input '%s' created: group='%s', dataformat='%s', database-output='%s/%s/%s:%s'" % \
                                    (name, details['channel'], algorithm.input_map[name],
                                    details['database'], details['protocol'], details['set'],
                                    details['output']))
                else:
                    data_loader = _get_data_loader_for(details)
                    data_loader.add(name, data_source)

                    logger.debug("DatabaseOutputDataSource '%s' created: group='%s', dataformat='%s', database-output='%s/%s/%s:%s'" % \
                                    (name, channel, algorithm.input_map[name], details['database'],
                                     details['protocol'], details['set'], details['output']))

            elif db_access == AccessMode.REMOTE:
                if socket is None:
                    raise IOError("No socket provided for remote data sources")

                # The database data is served by another process over the socket
                data_source = RemoteDataSource()
                result = data_source.setup(
                    socket=socket,
                    input_name=name,
                    dataformat_name=algorithm.input_map[name],
                    prefix=prefix,
                    unpack=True
                )

                if not result:
                    raise IOError("cannot setup remote data source '%s'" % name)

                # Same Input-vs-DataLoader split as in the LOCAL branch above
                if (algorithm.type == Algorithm.LEGACY) or \
                   ((algorithm.type == Algorithm.SEQUENTIAL) and (details['channel'] == config['channel'])):
                    input = Input(name, algorithm.input_map[name], data_source)

                    logger.debug("Input '%s' created: group='%s', dataformat='%s', database-file='%s'" % \
                                    (name, details['channel'], algorithm.input_map[name],
                                    details['path']))
                else:
                    data_loader = _get_data_loader_for(details)
                    data_loader.add(name, data_source)

                    logger.debug("RemoteDataSource '%s' created: group='%s', dataformat='%s', connected to a database" % \
                                    (name, details['channel'], algorithm.input_map[name]))

        elif cache_access == AccessMode.LOCAL:

            if algorithm.type == Algorithm.LEGACY:
                input = _create_local_input(details)

            elif algorithm.type == Algorithm.SEQUENTIAL:
                if details['channel'] == config['channel']: # synchronized
                    input = _create_local_input(details)
                else:
                    _create_data_source(details)

            else: # Algorithm.AUTONOMOUS or LOOP:
                _create_data_source(details)

        else:
            # Input is neither database-backed nor cache-accessible: skip it
            continue

        # Synchronization bits
        if input is not None:
            group = input_list.group(details['channel'])
            if group is None:
                synchronization_listener = None
                if not no_synchronisation_listeners:
                    synchronization_listener = SynchronizationListener()

                # Only the main channel's group gets restricted access
                group = InputGroup(
                          details['channel'],
                          synchronization_listener=synchronization_listener,
                          restricted_access=(details['channel'] == config['channel'])
                        )
                input_list.add(group)
                logger.debug("Group '%s' created" % details['channel'])

            group.add(input)

    return (input_list, data_loader_list)
295
296


297
# ----------------------------------------------------------
298
299


300
301
def create_outputs_from_configuration(config, algorithm, prefix, cache_root,
                                      input_list=None, data_loaders=None):
    """Build the output list and cache data sinks of an algorithm.

    For each entry of ``config['outputs']`` (or for the single
    ``config['result']`` when the algorithm is an analyzer) a
    ``CachedDataSink`` is set up under ``cache_root`` and wrapped in an
    ``Output``. When no explicit parallelization range is given, the
    start/end indices are inferred from the synchronized inputs.

    Parameters:

      config (dict): Container configuration (as produced by
        ``convert_experiment_configuration_to_container``)

      algorithm (Algorithm): The algorithm the outputs are created for

      prefix (str): Path of the prefix holding the declarations
        (NOTE(review): currently unused in this function)

      cache_root (str): Root path of the cache

      input_list (InputList): Inputs of the algorithm, used to retrieve the
        synchronization listener and to infer the index range

      data_loaders (DataLoaderList): Data loaders of the algorithm, used as a
        fallback to infer the index range

    Returns:

      tuple: ``(output_list, data_sinks)``

    Raises:

      IOError: If a cache sink cannot be created
    """

    data_sinks = []
    output_list = OutputList()

    # This is used for parallelization purposes
    start_index, end_index = config.get('range', (None, None))

    # If the algorithm is an analyser
    if 'result' in config:
        output_config = {
            'result': config['result']
        }
    else:
        output_config = config['outputs']


    for name, details in output_config.items():

        synchronization_listener = None

        # Resolve the data format of this output: analyzers use their own
        # result format, other algorithms use their declared output map
        if 'result' in config:
            dataformat_name = 'analysis:' + algorithm.name
            dataformat = algorithm.result_dataformat()
        else:
            dataformat_name = algorithm.output_map[name]
            dataformat =  algorithm.dataformats[dataformat_name]

        # Outputs are synchronized with the main channel's input group
        if input_list is not None:
            input_group = input_list.group(config['channel'])
            if input_group is not None:
                synchronization_listener = input_group.synchronization_listener

        path = os.path.join(cache_root, details['path'] + '.data')
        dirname = os.path.dirname(path)
        # Make sure that the directory exists while taking care of race
        # conditions. see: http://stackoverflow.com/questions/273192/check-if-a-directory-exists-and-create-it-if-necessary
        try:
            if (len(dirname) > 0):
                os.makedirs(dirname)
        except OSError as exception:
            if exception.errno != errno.EEXIST:
                raise

        # No explicit range was provided: infer (start_index, end_index)
        # from the inputs synchronized with the main channel
        if start_index is None:
            input_path = None

            # Prefer a cached (non-database) synchronized input, whose index
            # files directly give the last data index
            for k, v in config['inputs'].items():
                if v['channel'] != config['channel']:
                    continue

                if 'database' not in v:
                    input_path = os.path.join(cache_root, v['path'] + '.data')
                    break

            if input_path is not None:
                (data_filenames, indices_filenames, data_checksum_filenames, indices_checksum_filenames) = \
                        getAllFilenames(input_path)

                # Index filenames embed their last data index as the
                # second-to-last dot-separated component
                end_indices = [ int(x.split('.')[-2]) for x in indices_filenames ]
                end_indices.sort()

                start_index = 0
                end_index = end_indices[-1]

            else:
                # Fall back to asking a live input or the main data loader
                for k, v in config['inputs'].items():
                    if v['channel'] != config['channel']:
                        continue

                    start_index = 0

                    if (input_list is not None) and (input_list[k] is not None):
                        end_index = input_list[k].data_source.last_data_index()
                        break
                    elif data_loaders is not None:
                        end_index = data_loaders.main_loader.data_index_end
                        break

        data_sink = CachedDataSink()
        data_sinks.append(data_sink)

        status = data_sink.setup(
            filename=path,
            dataformat=dataformat,
            start_index=start_index,
            end_index=end_index,
            encoding='binary'
        )

        if not status:
            raise IOError("Cannot create cache sink '%s'" % details['path'])

        output_list.add(Output(name, data_sink,
                               synchronization_listener=synchronization_listener,
                               force_start_index=start_index)
        )

        if 'result' not in config:
            logger.debug("Output '%s' created: group='%s', dataformat='%s', filename='%s'" % \
                            (name, details['channel'], dataformat_name, path))
        else:
            # Analyzer results carry no channel information in their details
            logger.debug("Output '%s' created: dataformat='%s', filename='%s'" % \
                            (name, dataformat_name, path))

    return (output_list, data_sinks)