#!/usr/bin/env python
# vim: set fileencoding=utf-8 :

###############################################################################
#                                                                             #
# Copyright (c) 2017 Idiap Research Institute, http://www.idiap.ch/           #
# Contact: beat.support@idiap.ch                                              #
#                                                                             #
# This file is part of the beat.backend.python module of the BEAT platform.   #
#                                                                             #
# Commercial License Usage                                                    #
# Licensees holding valid commercial BEAT licenses may use this file in       #
# accordance with the terms contained in a written agreement between you      #
# and Idiap. For further information contact tto@idiap.ch                     #
#                                                                             #
# Alternatively, this file may be used under the terms of the GNU Affero      #
# Public License version 3 as published by the Free Software and appearing    #
# in the file LICENSE.AGPL included in the packaging of this file.            #
# The BEAT platform is distributed in the hope that it will be useful, but    #
# WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY  #
# or FITNESS FOR A PARTICULAR PURPOSE.                                        #
#                                                                             #
# You should have received a copy of the GNU Affero Public License along      #
# with the BEAT platform. If not, see http://www.gnu.org/licenses/.           #
#                                                                             #
###############################################################################


import os
import errno

import logging
logger = logging.getLogger(__name__)

from .data import CachedDataSource
from .data import RemoteDataSource
from .data import CachedDataSink
from .data import getAllFilenames
from .data_loaders import DataLoaderList
from .data_loaders import DataLoader
from .inputs import InputList
from .inputs import Input
from .inputs import InputGroup
from .outputs import SynchronizationListener
from .outputs import OutputList
from .outputs import Output
from .algorithm import Algorithm


#----------------------------------------------------------


def convert_experiment_configuration_to_container(config):
    """Convert a block configuration into the form shipped to a container.

    Parameters:
      config (dict): original block configuration; must contain the keys
        ``algorithm``, ``parameters``, ``channel`` and ``inputs``, and
        either ``outputs`` or ``result``; may contain ``range``.

    Returns:
      dict: container-side configuration, carrying the current user id and
      only the fields of each input/output needed inside the container.
    """

    data = {
        'algorithm': config['algorithm'],
        'parameters': config['parameters'],
        'channel': config['channel'],
        'uid': os.getuid(),
    }

    # Only present when the block processes a slice of the data
    # (parallelization)
    if 'range' in config:
        data['range'] = config['range']

    # Keep only the channel, path and database-membership of each input;
    # 'database' becomes a boolean flag on the container side
    data['inputs'] = {
        k: {'channel': v['channel'],
            'path': v['path'],
            'database': 'database' in v}
        for k, v in config['inputs'].items()
    }

    if 'outputs' in config:
        data['outputs'] = {
            k: {'channel': v['channel'], 'path': v['path']}
            for k, v in config['outputs'].items()
        }
    else:
        # Analyzers have no 'outputs': they produce a single 'result',
        # synchronized with the block's main channel
        data['result'] = {'channel': config['channel'],
                          'path': config['result']['path']}

    return data


#----------------------------------------------------------
class AccessMode:
    """Enumeration of the possible ways data may be accessed."""

    NONE = 0    # no access at all
    LOCAL = 1   # read/written directly by this process
    REMOTE = 2  # reached through a socket to another process


def create_inputs_from_configuration(config, algorithm, prefix, cache_root,
                                     cache_access=AccessMode.NONE,
                                     db_access=AccessMode.NONE,
                                     unpack=True, socket=None,
                                     databases=None,
                                     no_synchronisation_listeners=False):
    """Create the inputs and data loaders described by a block configuration.

    Parameters:
      config (dict): block configuration, with at least ``channel`` and
        ``inputs``; may contain ``range`` for parallelized execution.
      algorithm (Algorithm): algorithm the inputs are created for; its
        ``type`` decides whether each input is wrapped in an ``Input``
        (LEGACY, or SEQUENTIAL on the synchronized channel) or registered
        on a per-channel ``DataLoader``.
      prefix (str): path prefix forwarded to the data sources.
      cache_root (str): root directory of the cache files.
      cache_access (int): one of the ``AccessMode`` values, for cached data.
      db_access (int): one of the ``AccessMode`` values, for database data.
      unpack (bool): currently ignored (see note below).
      socket: ZMQ-like socket, required when ``db_access`` is REMOTE.
      databases (dict): mapping of database names to database objects,
        required when ``db_access`` is LOCAL.
      no_synchronisation_listeners (bool): if True, input groups are
        created without a ``SynchronizationListener``.

    Returns:
      tuple: ``(input_list, data_loader_list)``

    Raises:
      IOError: if a cache file, database, socket or remote data source
        cannot be set up.
    """

    views = {}

    input_list = InputList()
    data_loader_list = DataLoaderList()

    # This is used for parallelization purposes
    start_index, end_index = config.get('range', (None, None))

    # NOTE(review): the 'unpack' parameter is currently ignored; every data
    # source below is set up with unpack=True — confirm before honoring it.

    # The helpers below read the loop variable 'name' from the enclosing
    # scope; they are only called from inside the loop over config['inputs'].

    def _create_local_input(details):
        # Build an Input reading directly from a cache file; only the
        # synchronized channel is restricted to the parallelization range
        data_source = CachedDataSource()

        filename = os.path.join(cache_root, details['path'] + '.data')

        if details['channel'] == config['channel']:  # synchronized
            status = data_source.setup(
                      filename=filename,
                      prefix=prefix,
                      start_index=start_index,
                      end_index=end_index,
                      unpack=True,
                     )
        else:
            status = data_source.setup(
                      filename=filename,
                      prefix=prefix,
                      unpack=True,
                     )

        if not status:
            raise IOError("cannot load cache file `%s'" % details['path'])

        new_input = Input(name, algorithm.input_map[name], data_source)

        logger.debug("Input '%s' created: group='%s', dataformat='%s', filename='%s'" % \
                        (name, details['channel'], algorithm.input_map[name], filename))

        return new_input


    def _get_data_loader_for(details):
        # One DataLoader per channel, created lazily on first use
        data_loader = data_loader_list[details['channel']]
        if data_loader is None:
            data_loader = DataLoader(details['channel'])
            data_loader_list.add(data_loader)

            logger.debug("Data loader created: group='%s'" % details['channel'])

        return data_loader


    def _create_data_source(details):
        # Register a cached data source on its channel's data loader
        data_loader = _get_data_loader_for(details)

        filename = os.path.join(cache_root, details['path'] + '.data')

        data_source = CachedDataSource()
        result = data_source.setup(
            filename=filename,
            prefix=prefix,
            start_index=start_index,
            end_index=end_index,
            unpack=True,
        )

        if not result:
            raise IOError("cannot load cache file `%s'" % details['path'])

        data_loader.add(name, data_source)

        logger.debug("Input '%s' added to data loader: group='%s', dataformat='%s', filename='%s'" % \
                        (name, details['channel'], algorithm.input_map[name], filename))


    for name, details in config['inputs'].items():

        new_input = None

        if details.get('database', False):
            if db_access == AccessMode.LOCAL:
                if databases is None:
                    raise IOError("No databases provided")

                # Retrieve the database (narrowed from a bare 'except:' —
                # only a missing key means "not found")
                try:
                    db = databases[details['database']]
                except KeyError:
                    raise IOError("Database '%s' not found" % details['database'])

                # Create or retrieve the database view (one per channel)
                channel = details['channel']

                if channel not in views:
                    view = db.view(details['protocol'], details['set'])
                    view.setup(os.path.join(cache_root, details['path']), pack=False,
                               start_index=start_index, end_index=end_index)

                    views[channel] = view

                    logger.debug("Database view '%s/%s/%s' created: group='%s'" % \
                                    (details['database'], details['protocol'], details['set'],
                                     channel))
                else:
                    view = views[channel]

                data_source = view.data_sources[details['output']]

                # LEGACY algorithms (and SEQUENTIAL ones on the synchronized
                # channel) consume through an Input; others through a loader
                if (algorithm.type == Algorithm.LEGACY) or \
                   ((algorithm.type == Algorithm.SEQUENTIAL) and (details['channel'] == config['channel'])):
                    new_input = Input(name, algorithm.input_map[name], data_source)

                    logger.debug("Input '%s' created: group='%s', dataformat='%s', database-output='%s/%s/%s:%s'" % \
                                    (name, details['channel'], algorithm.input_map[name],
                                    details['database'], details['protocol'], details['set'],
                                    details['output']))
                else:
                    data_loader = _get_data_loader_for(details)
                    data_loader.add(name, data_source)

                    logger.debug("DatabaseOutputDataSource '%s' created: group='%s', dataformat='%s', database-output='%s/%s/%s:%s'" % \
                                    (name, channel, algorithm.input_map[name], details['database'],
                                     details['protocol'], details['set'], details['output']))

            elif db_access == AccessMode.REMOTE:
                if socket is None:
                    raise IOError("No socket provided for remote data sources")

                data_source = RemoteDataSource()
                result = data_source.setup(
                    socket=socket,
                    input_name=name,
                    dataformat_name=algorithm.input_map[name],
                    prefix=prefix,
                    unpack=True
                )

                if not result:
                    raise IOError("cannot setup remote data source '%s'" % name)

                if (algorithm.type == Algorithm.LEGACY) or \
                   ((algorithm.type == Algorithm.SEQUENTIAL) and (details['channel'] == config['channel'])):
                    new_input = Input(name, algorithm.input_map[name], data_source)

                    logger.debug("Input '%s' created: group='%s', dataformat='%s', database-file='%s'" % \
                                    (name, details['channel'], algorithm.input_map[name],
                                    details['path']))
                else:
                    data_loader = _get_data_loader_for(details)
                    data_loader.add(name, data_source)

                    logger.debug("RemoteDataSource '%s' created: group='%s', dataformat='%s', connected to a database" % \
                                    (name, details['channel'], algorithm.input_map[name]))

        elif cache_access == AccessMode.LOCAL:

            if algorithm.type == Algorithm.LEGACY:
                new_input = _create_local_input(details)

            elif algorithm.type == Algorithm.SEQUENTIAL:
                if details['channel'] == config['channel']:  # synchronized
                    new_input = _create_local_input(details)
                else:
                    _create_data_source(details)

            elif algorithm.type == Algorithm.AUTONOMOUS:
                _create_data_source(details)

        else:
            continue

        # Synchronization bits: every Input joins the group of its channel;
        # the group (and its optional listener) is created on first use
        if new_input is not None:
            group = input_list.group(details['channel'])
            if group is None:
                synchronization_listener = None
                if not no_synchronisation_listeners:
                    synchronization_listener = SynchronizationListener()

                group = InputGroup(
                          details['channel'],
                          synchronization_listener=synchronization_listener,
                          restricted_access=(details['channel'] == config['channel'])
                        )
                input_list.add(group)
                logger.debug("Group '%s' created" % details['channel'])

            group.add(new_input)

    return (input_list, data_loader_list)


#----------------------------------------------------------
def create_outputs_from_configuration(config, algorithm, prefix, cache_root,
                                      input_list=None, data_loaders=None):
    """Create the outputs and cache sinks described by a block configuration.

    Parameters:
      config (dict): block configuration, with at least ``channel`` and
        ``inputs``, plus either ``outputs`` or (for analyzers) ``result``;
        may contain ``range`` for parallelized execution.
      algorithm (Algorithm): algorithm the outputs are created for.
      prefix (str): path prefix (kept for interface symmetry with
        ``create_inputs_from_configuration``; not used here).
      cache_root (str): root directory of the cache files.
      input_list (InputList): inputs of the same block; used to pick up the
        synchronization listener and, if needed, the data index range.
      data_loaders (DataLoaderList): data loaders of the same block; used as
        a fallback to determine the end index.

    Returns:
      tuple: ``(output_list, data_sinks)``

    Raises:
      IOError: if a cache sink cannot be set up.
      OSError: if the cache directory cannot be created.
    """

    data_sinks = []
    output_list = OutputList()

    # This is used for parallelization purposes
    start_index, end_index = config.get('range', (None, None))

    # If the algorithm is an analyser, it has a single implicit 'result'
    # output instead of a declared 'outputs' section
    if 'result' in config:
        output_config = {
            'result': config['result']
        }
    else:
        output_config = config['outputs']

    for name, details in output_config.items():

        synchronization_listener = None

        if 'result' in config:
            dataformat_name = 'analysis:' + algorithm.name
            dataformat = algorithm.result_dataformat()
        else:
            dataformat_name = algorithm.output_map[name]
            dataformat = algorithm.dataformats[dataformat_name]

        # Outputs are synchronized with the block's main channel
        if input_list is not None:
            input_group = input_list.group(config['channel'])
            if input_group is not None:
                synchronization_listener = input_group.synchronization_listener

        path = os.path.join(cache_root, details['path'] + '.data')
        dirname = os.path.dirname(path)
        # Make sure that the directory exists while taking care of race
        # conditions. see: http://stackoverflow.com/questions/273192/check-if-a-directory-exists-and-create-it-if-necessary
        try:
            if dirname:
                os.makedirs(dirname)
        except OSError as exception:
            if exception.errno != errno.EEXIST:
                raise

        if start_index is None:
            # No explicit range: infer the index range from a synchronized,
            # non-database input's cache files when possible
            input_path = None

            for k, v in config['inputs'].items():
                if v['channel'] != config['channel']:
                    continue

                if 'database' not in v:
                    input_path = os.path.join(cache_root, v['path'] + '.data')
                    break

            if input_path is not None:
                # getAllFilenames returns (data, indices, data checksums,
                # indices checksums); only the index filenames matter here.
                # Each index filename embeds its end index as the
                # second-to-last dot-separated component.
                indices_filenames = getAllFilenames(input_path)[1]

                start_index = 0
                end_index = max(int(x.split('.')[-2]) for x in indices_filenames)

            else:
                # Fall back to the live inputs / data loaders of the block
                for k, v in config['inputs'].items():
                    if v['channel'] != config['channel']:
                        continue

                    start_index = 0

                    if (input_list is not None) and (input_list[k] is not None):
                        end_index = input_list[k].data_source.last_data_index()
                        break
                    elif data_loaders is not None:
                        end_index = data_loaders.main_loader.data_index_end
                        break

        data_sink = CachedDataSink()
        data_sinks.append(data_sink)

        status = data_sink.setup(
            filename=path,
            dataformat=dataformat,
            start_index=start_index,
            end_index=end_index,
            encoding='binary'
        )

        if not status:
            raise IOError("Cannot create cache sink '%s'" % details['path'])

        output_list.add(Output(name, data_sink,
                               synchronization_listener=synchronization_listener,
                               force_start_index=start_index)
        )

        if 'result' not in config:
            logger.debug("Output '%s' created: group='%s', dataformat='%s', filename='%s'" % \
                            (name, details['channel'], dataformat_name, path))
        else:
            logger.debug("Output '%s' created: dataformat='%s', filename='%s'" % \
                            (name, dataformat_name, path))

    return (output_list, data_sinks)