#!/usr/bin/env python
# vim: set fileencoding=utf-8 :

###################################################################################
#                                                                                 #
# Copyright (c) 2019 Idiap Research Institute, http://www.idiap.ch/               #
# Contact: beat.support@idiap.ch                                                  #
#                                                                                 #
# Redistribution and use in source and binary forms, with or without              #
# modification, are permitted provided that the following conditions are met:     #
#                                                                                 #
# 1. Redistributions of source code must retain the above copyright notice, this  #
# list of conditions and the following disclaimer.                                #
#                                                                                 #
# 2. Redistributions in binary form must reproduce the above copyright notice,    #
# this list of conditions and the following disclaimer in the documentation       #
# and/or other materials provided with the distribution.                          #
#                                                                                 #
# 3. Neither the name of the copyright holder nor the names of its contributors   #
# may be used to endorse or promote products derived from this software without   #
# specific prior written permission.                                              #
#                                                                                 #
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND #
# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED   #
# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE          #
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE    #
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL      #
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR      #
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER      #
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,   #
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE   #
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.            #
#                                                                                 #
###################################################################################

"""
=======
helpers
=======

This module implements various helper methods and classes
"""

import os
import errno

import logging

from .data import CachedDataSource
from .data import RemoteDataSource
from .data import CachedDataSink
from .data import getAllFilenames
from .data_loaders import DataLoaderList
from .data_loaders import DataLoader
from .inputs import InputList
from .inputs import Input
from .inputs import InputGroup
from .outputs import SynchronizationListener
from .outputs import OutputList
from .outputs import Output
from .algorithm import Algorithm


logger = logging.getLogger(__name__)


# ----------------------------------------------------------


def parse_inputs(inputs):
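    """Reduce the 'inputs' section of a configuration to the fields needed
    inside a container.

    Each input keeps its 'channel' and 'path'; inputs backed by a database
    additionally keep their 'database', 'protocol', 'set' and 'output'
    fields.
    """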
    data = {}
    for key, value in inputs.items():
        data[key] = dict(
                channel=value['channel'],
                path=value['path'],
            )
        if 'database' in value:
            db = dict(
                database=value['database'],
                protocol=value['protocol'],
                set=value['set'],
                output=value['output']
                )
            data[key].update(db)
    return data

def convert_loop_to_container(config):
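    """Prepare the 'loop' section of an experiment configuration for
    execution inside a container: keep the algorithm, its parameters and
    channel, record the current uid and normalize the inputs with
    parse_inputs().
    """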
    data = {
        'algorithm': config['algorithm'],
        'parameters': config['parameters'],
        'channel': config['channel'],
        'uid': os.getuid()
    }

    data['inputs'] = parse_inputs(config['inputs'])

    return data


def convert_experiment_configuration_to_container(config):
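    """Extract from an experiment configuration the settings that must be
    forwarded to the execution container: algorithm, parameters, channel,
    uid, the optional processing range, the normalized inputs and either
    the outputs or, for analyzers, the result description.
    """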
    data = {
        'algorithm': config['algorithm'],
        'parameters': config['parameters'],
        'channel': config['channel'],
        'uid': os.getuid(),
    }

    if 'range' in config:
        data['range'] = config['range']

    data['inputs'] = parse_inputs(config['inputs'])

    if 'outputs' in config:
        data['outputs'] = {
            k: {'channel': v['channel'], 'path': v['path']}
            for k, v in config['outputs'].items()
        }
    else:
        data['result'] = {'channel': config['channel'], 'path': config['result']['path']}

    if 'loop' in config:
        data['loop'] = convert_loop_to_container(config['loop'])

    return data


# ----------------------------------------------------------


class AccessMode:
    """Possible access modes"""
    NONE   = 0
    LOCAL  = 1
    REMOTE = 2


def create_inputs_from_configuration(config, algorithm, prefix, cache_root,
                                     cache_access=AccessMode.NONE,
                                     db_access=AccessMode.NONE,
                                     unpack=True, socket=None,
                                     databases=None,
                                     no_synchronisation_listeners=False):
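    """Build the inputs described by an execution configuration.

    Database-backed inputs are served either by local database views
    (db_access=AccessMode.LOCAL) or, through the given socket, by a remote
    process (db_access=AccessMode.REMOTE); other inputs are read from the
    cache (cache_access=AccessMode.LOCAL). Depending on the algorithm type,
    each input either becomes an Input grouped per channel on the returned
    InputList (legacy algorithms, and inputs synchronized with the main
    channel for sequential ones) or is registered on a per-channel
    DataLoader.

    Returns a tuple (input_list, data_loader_list).
    """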

    data_sources = []
    views = {}
    input_list = InputList()
    data_loader_list = DataLoaderList()

    # This is used for parallelization purposes
    start_index, end_index = config.get('range', (None, None))


    def _create_local_input(details):
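        """Create an Input reading from the cached data file described by
        `details`, restricted to the configured range when the input is
        synchronized with the main channel."""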
        data_source = CachedDataSource()

        filename = os.path.join(cache_root, details['path'] + '.data')

        if details['channel'] == config['channel']: # synchronized
            status = data_source.setup(
                      filename=filename,
                      prefix=prefix,
                      start_index=start_index,
                      end_index=end_index,
                      unpack=True,
                     )
        else:
            status = data_source.setup(
                      filename=filename,
                      prefix=prefix,
                      unpack=True,
                     )

        if not status:
            raise IOError("cannot load cache file `%s'" % details['path'])

        input = Input(name, algorithm.input_map[name], data_source)

        logger.debug("Input '%s' created: group='%s', dataformat='%s', filename='%s'" % \
                        (name, details['channel'], algorithm.input_map[name], filename))

        return input


    def _get_data_loader_for(details):
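        """Return the data loader in charge of the channel of `details`,
        creating and registering it on first use."""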
        data_loader = data_loader_list[details['channel']]
        if data_loader is None:
            data_loader = DataLoader(details['channel'])
            data_loader_list.add(data_loader)

            logger.debug("Data loader created: group='%s'" % details['channel'])

        return data_loader


    def _create_data_source(details):
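        """Set up a cached data source for `details` and attach it to the
        data loader of its channel."""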
        data_loader = _get_data_loader_for(details)

        filename = os.path.join(cache_root, details['path'] + '.data')

        data_source = CachedDataSource()
        result = data_source.setup(
            filename=filename,
            prefix=prefix,
            start_index=start_index,
            end_index=end_index,
            unpack=True,
        )

        if not result:
            raise IOError("cannot load cache file `%s'" % filename)

        data_loader.add(name, data_source)

        logger.debug("Input '%s' added to data loader: group='%s', dataformat='%s', filename='%s'" % \
                        (name, details['channel'], algorithm.input_map[name], filename))


    for name, details in config['inputs'].items():

        input = None

        if details.get('database', None) is not None:
            if db_access == AccessMode.LOCAL:
                if databases is None:
                    raise IOError("No databases provided")

                # Retrieve the database
                try:
                    db = databases[details['database']]
                except KeyError:
                    raise IOError("Database '%s' not found" % details['database'])

                # Create or retrieve the database view
                channel = details['channel']

                if channel not in views:
                    view = db.view(details['protocol'], details['set'])
                    view.setup(os.path.join(cache_root, details['path']), pack=False,
                               start_index=start_index, end_index=end_index)

                    views[channel] = view

                    logger.debug("Database view '%s/%s/%s' created: group='%s'" % \
                                    (details['database'], details['protocol'], details['set'],
                                     channel))
                else:
                    view = views[channel]

                data_source = view.data_sources[details['output']]

                if (algorithm.type == Algorithm.LEGACY) or \
                   ((algorithm.type == Algorithm.SEQUENTIAL) and (details['channel'] == config['channel'])):
                    input = Input(name, algorithm.input_map[name], data_source)

                    logger.debug("Input '%s' created: group='%s', dataformat='%s', database-output='%s/%s/%s:%s'" % \
                                    (name, details['channel'], algorithm.input_map[name],
                                    details['database'], details['protocol'], details['set'],
                                    details['output']))
                else:
                    data_loader = _get_data_loader_for(details)
                    data_loader.add(name, data_source)

                    logger.debug("DatabaseOutputDataSource '%s' created: group='%s', dataformat='%s', database-output='%s/%s/%s:%s'" % \
                                    (name, channel, algorithm.input_map[name], details['database'],
                                     details['protocol'], details['set'], details['output']))

            elif db_access == AccessMode.REMOTE:
                if socket is None:
                    raise IOError("No socket provided for remote data sources")

                data_source = RemoteDataSource()
                result = data_source.setup(
                    socket=socket,
                    input_name=name,
                    dataformat_name=algorithm.input_map[name],
                    prefix=prefix,
                    unpack=True
                )

                if not result:
                    raise IOError("cannot setup remote data source '%s'" % name)


                if (algorithm.type == Algorithm.LEGACY) or \
                   ((algorithm.type == Algorithm.SEQUENTIAL) and (details['channel'] == config['channel'])):
                    input = Input(name, algorithm.input_map[name], data_source)

                    logger.debug("Input '%s' created: group='%s', dataformat='%s', database-file='%s'" % \
                                    (name, details['channel'], algorithm.input_map[name],
                                    details['path']))
                else:
                    data_loader = _get_data_loader_for(details)
                    data_loader.add(name, data_source)

                    logger.debug("RemoteDataSource '%s' created: group='%s', dataformat='%s', connected to a database" % \
                                    (name, details['channel'], algorithm.input_map[name]))


        elif cache_access == AccessMode.LOCAL:

            if algorithm.type == Algorithm.LEGACY:
                input = _create_local_input(details)

            elif algorithm.type == Algorithm.SEQUENTIAL:
                if details['channel'] == config['channel']: # synchronized
                    input = _create_local_input(details)
                else:
                    _create_data_source(details)

            else: # autonomous algorithm types
                _create_data_source(details)

        else:
            continue

        # Synchronization bits
        if input is not None:
            group = input_list.group(details['channel'])
            if group is None:
                synchronization_listener = None
                if not no_synchronisation_listeners:
                    synchronization_listener = SynchronizationListener()

                group = InputGroup(
                          details['channel'],
                          synchronization_listener=synchronization_listener,
                          restricted_access=(details['channel'] == config['channel'])
                        )
                input_list.add(group)
                logger.debug("Group '%s' created" % details['channel'])

            group.add(input)

    return (input_list, data_loader_list)


# ----------------------------------------------------------


def create_outputs_from_configuration(config, algorithm, prefix, cache_root,
                                      input_list=None, data_loaders=None):
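    """Build the outputs described by an execution configuration.

    Analyzers (configurations with a 'result' entry) get a single 'result'
    output; other algorithms get one output per entry in 'outputs'. When no
    processing range is configured, the range is inferred from the
    synchronized inputs so that every data sink covers the same indices.

    Returns a tuple (output_list, data_sinks).
    """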

    data_sinks = []
    output_list = OutputList()

    # This is used for parallelization purposes
    start_index, end_index = config.get('range', (None, None))

    # If the algorithm is an analyser
    if 'result' in config:
        output_config = {
            'result': config['result']
        }
    else:
        output_config = config['outputs']


    for name, details in output_config.items():

        synchronization_listener = None

        if 'result' in config:
            dataformat_name = 'analysis:' + algorithm.name
            dataformat = algorithm.result_dataformat()
        else:
            dataformat_name = algorithm.output_map[name]
            dataformat = algorithm.dataformats[dataformat_name]

        if input_list is not None:
            input_group = input_list.group(config['channel'])
            if input_group is not None:
                synchronization_listener = input_group.synchronization_listener

        path = os.path.join(cache_root, details['path'] + '.data')
        dirname = os.path.dirname(path)
        # Make sure that the directory exists while taking care of race
        # conditions. see: http://stackoverflow.com/questions/273192/check-if-a-directory-exists-and-create-it-if-necessary
        try:
            if dirname:
                os.makedirs(dirname)
        except OSError as exception:
            if exception.errno != errno.EEXIST:
                raise

        if start_index is None:
            input_path = None

            for k, v in config['inputs'].items():
                if v['channel'] != config['channel']:
                    continue

                if 'database' not in v:
                    input_path = os.path.join(cache_root, v['path'] + '.data')
                    break

            if input_path is not None:
                (data_filenames, indices_filenames, data_checksum_filenames, indices_checksum_filenames) = \
                        getAllFilenames(input_path)

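                # Index filenames encode the last data index of each chunk
                # in their second-to-last dot-separated field; the overall
                # end index is the largest of these.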
                end_indices = [ int(x.split('.')[-2]) for x in indices_filenames ]
                end_indices.sort()

                start_index = 0
                end_index = end_indices[-1]

            else:
                for k, v in config['inputs'].items():
                    if v['channel'] != config['channel']:
                        continue

                    start_index = 0

                    if (input_list is not None) and (input_list[k] is not None):
                        end_index = input_list[k].data_source.last_data_index()
                        break
                    elif data_loaders is not None:
                        end_index = data_loaders.main_loader.data_index_end
                        break

        data_sink = CachedDataSink()
        data_sinks.append(data_sink)

        status = data_sink.setup(
            filename=path,
            dataformat=dataformat,
            start_index=start_index,
            end_index=end_index,
            encoding='binary'
        )

        if not status:
            raise IOError("Cannot create cache sink '%s'" % details['path'])

        output_list.add(Output(name, data_sink,
                               synchronization_listener=synchronization_listener,
                               force_start_index=start_index)
        )

        if 'result' not in config:
            logger.debug("Output '%s' created: group='%s', dataformat='%s', filename='%s'" % \
                            (name, details['channel'], dataformat_name, path))
        else:
            logger.debug("Output '%s' created: dataformat='%s', filename='%s'" % \
                            (name, dataformat_name, path))

    return (output_list, data_sinks)
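

# ----------------------------------------------------------


# Minimal usage sketch. The configuration, algorithm and paths below are
# hypothetical; actual values come from the experiment being executed:
#
#   config = convert_experiment_configuration_to_container(experiment_config)
#
#   inputs, loaders = create_inputs_from_configuration(
#       config, algorithm, '/path/to/prefix', '/path/to/cache',
#       cache_access=AccessMode.LOCAL)
#
#   outputs, sinks = create_outputs_from_configuration(
#       config, algorithm, '/path/to/prefix', '/path/to/cache',
#       input_list=inputs, data_loaders=loaders)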