#!/usr/bin/env python
# vim: set fileencoding=utf-8 :

###############################################################################
#                                                                             #
# Copyright (c) 2017 Idiap Research Institute, http://www.idiap.ch/           #
# Contact: beat.support@idiap.ch                                              #
#                                                                             #
# This file is part of the beat.backend.python module of the BEAT platform.   #
#                                                                             #
# Commercial License Usage                                                    #
# Licensees holding valid commercial BEAT licenses may use this file in       #
# accordance with the terms contained in a written agreement between you      #
# and Idiap. For further information contact tto@idiap.ch                     #
#                                                                             #
# Alternatively, this file may be used under the terms of the GNU Affero      #
# Public License version 3 as published by the Free Software and appearing    #
# in the file LICENSE.AGPL included in the packaging of this file.            #
# The BEAT platform is distributed in the hope that it will be useful, but    #
# WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY  #
# or FITNESS FOR A PARTICULAR PURPOSE.                                        #
#                                                                             #
# You should have received a copy of the GNU Affero Public License along      #
# with the BEAT platform. If not, see http://www.gnu.org/licenses/.           #
#                                                                             #
###############################################################################


import os
import errno

import logging
logger = logging.getLogger(__name__)

from .data import CachedDataSource
from .data import CachedDataSink
from .data import RemoteDataSource
from .data import getAllFilenames
from .data_loaders import DataLoaderList
from .data_loaders import DataLoader
from .inputs import InputList
from .inputs import Input
from .inputs import InputGroup
from .outputs import SynchronizationListener
from .outputs import OutputList
from .outputs import Output
from .algorithm import Algorithm


#----------------------------------------------------------


def convert_experiment_configuration_to_container(config):
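    """Convert a local experiment configuration into the container format.

    A minimal sketch of the expected input, assuming a typical execution
    configuration (key names follow the code below; the values and cache
    paths are illustrative only)::

        config = {
            'algorithm': 'user/algorithm/1',
            'parameters': {},
            'channel': 'main',
            'inputs': {'in': {'channel': 'main', 'path': '<cache-path>'}},
            'outputs': {'out': {'channel': 'main', 'path': '<cache-path>'}},
        }
        data = convert_experiment_configuration_to_container(config)
    """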
    data = {
        'algorithm': config['algorithm'],
        'parameters': config['parameters'],
        'channel': config['channel'],
        'uid': os.getuid(),
    }

    if 'range' in config:
        data['range'] = config['range']

    data['inputs'] = {
        k: {
            'channel': v['channel'],
            'path': v['path'],
            'database': 'database' in v,
        }
        for k, v in config['inputs'].items()
    }

    if 'outputs' in config:
        data['outputs'] = {
            k: {'channel': v['channel'], 'path': v['path']}
            for k, v in config['outputs'].items()
        }
    else:
        data['result'] = {
            'channel': config['channel'],
            'path': config['result']['path'],
        }

    return data


#----------------------------------------------------------


class AccessMode:
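    """Possible modes of access to the data: none, local files or a remote
    socket connection"""
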
    NONE   = 0
    LOCAL  = 1
    REMOTE = 2


def create_inputs_from_configuration(config, algorithm, prefix, cache_root,
                                     cache_access=AccessMode.NONE,
                                     db_access=AccessMode.NONE,
                                     unpack=True, socket=None,
                                     databases=None,
                                     no_synchronisation_listeners=False):
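    """Create the input objects for an algorithm from the given configuration.

    Depending on the algorithm type, an input is either wrapped in an
    ``Input`` and added to the ``InputList`` (legacy algorithms and
    synchronized sequential inputs) or registered as a data source on the
    per-channel ``DataLoader`` (everything else).

    Returns a tuple ``(input_list, data_loader_list, data_sources)``.
    """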

    data_sources = []
    views = {}
    input_list = InputList()
    data_loader_list = DataLoaderList()

    # This is used for parallelization purposes
    start_index, end_index = config.get('range', (None, None))
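    # 'range', when present, is a (start_index, end_index) pair restricting
    # which data indices this process handles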


    def _create_local_input(details):
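        # Wraps a CachedDataSource in an Input object; note that `name` is
        # resolved from the enclosing loop over config['inputs'] when this
        # helper is called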
        data_source = CachedDataSource()
        data_sources.append(data_source)

        filename = os.path.join(cache_root, details['path'] + '.data')

        if details['channel'] == config['channel']: # synchronized
            status = data_source.setup(
                      filename=filename,
                      prefix=prefix,
                      start_index=start_index,
                      end_index=end_index,
                      unpack=True,
                     )
        else:
            status = data_source.setup(
                      filename=filename,
                      prefix=prefix,
                      unpack=True,
                     )

        if not status:
            raise IOError("cannot load cache file `%s'" % details['path'])

        input = Input(name, algorithm.input_map[name], data_source)

        logger.debug("Input '%s' created: group='%s', dataformat='%s', filename='%s'" % \
                        (name, details['channel'], algorithm.input_map[name], filename))

        return input


    def _get_data_loader_for(details):
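        # Data loaders are created lazily, one per channel, and shared by
        # all inputs on that channel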
        data_loader = data_loader_list[details['channel']]
        if data_loader is None:
            data_loader = DataLoader(details['channel'])
            data_loader_list.add(data_loader)

            logger.debug("Data loader created: group='%s'" % details['channel'])

        return data_loader


    def _create_data_source(details):
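        # Registers a CachedDataSource on the channel's data loader instead
        # of wrapping it in an Input (used by non-legacy algorithms)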
        data_loader = _get_data_loader_for(details)

        filename = os.path.join(cache_root, details['path'] + '.data')

        data_source = CachedDataSource()
        result = data_source.setup(
            filename=filename,
            prefix=prefix,
            start_index=start_index,
            end_index=end_index,
            unpack=True,
        )

        if not result:
            raise IOError("cannot load cache file `%s'" % details['path'])

        data_loader.add(name, data_source)

        logger.debug("Input '%s' added to data loader: group='%s', dataformat='%s', filename='%s'" % \
                        (name, details['channel'], algorithm.input_map[name], filename))


    for name, details in config['inputs'].items():

        input = None

        if details.get('database', False):
            if db_access == AccessMode.LOCAL:
                if databases is None:
                    raise IOError("No databases provided")

                # Retrieve the database
                try:
                    db = databases[details['database']]
                except KeyError:
                    raise IOError("Database '%s' not found" % details['database'])

                # Create or retrieve the database view
                channel = details['channel']

                if channel not in views:
                    view = db.view(details['protocol'], details['set'])
                    view.setup()
                    views[channel] = view

                    logger.debug("Database view '%s/%s/%s' created: group='%s'" % \
                                    (details['database'], details['protocol'], details['set'],
                                     channel))
                else:
                    view = views[channel]

                data_loader = _get_data_loader_for(details)
                data_loader.add(name, view.data_sources[details['output']])

                logger.debug("DatabaseOutputDataSource '%s' created: group='%s', dataformat='%s', database-output='%s/%s/%s:%s'" % \
                                (name, channel, algorithm.input_map[name], details['database'],
                                 details['protocol'], details['set'], details['output']))

            elif db_access == AccessMode.REMOTE:
                if socket is None:
                    raise IOError("No socket provided for remote data sources")

                data_loader = _get_data_loader_for(details)

                data_source = RemoteDataSource()
                result = data_source.setup(
                    socket=socket,
                    input_name=name,
                    dataformat_name=algorithm.input_map[name],
                    prefix=prefix,
                    unpack=True
                )

                if not result:
                    raise IOError("cannot setup remote data source '%s'" % name)

                data_loader.add(name, data_source)

                logger.debug("RemoteDataSource '%s' created: group='%s', dataformat='%s', connected to a database" % \
                                (name, details['channel'], algorithm.input_map[name]))


        elif cache_access == AccessMode.LOCAL:

            if algorithm.type == Algorithm.LEGACY:
                input = _create_local_input(details)

            elif algorithm.type == Algorithm.SEQUENTIAL:
                if details['channel'] == config['channel']: # synchronized
                    input = _create_local_input(details)
                else:
                    _create_data_source(details)

            elif algorithm.type == Algorithm.AUTONOMOUS:
                _create_data_source(details)

        else:
            continue

        # Synchronization bits
        if input is not None:
            group = input_list.group(details['channel'])
            if group is None:
                synchronization_listener = None
                if not no_synchronisation_listeners:
                    synchronization_listener = SynchronizationListener()

                group = InputGroup(
                          details['channel'],
                          synchronization_listener=synchronization_listener,
                          restricted_access=(details['channel'] == config['channel'])
                        )
                input_list.add(group)
                logger.debug("Group '%s' created" % details['channel'])

            group.add(input)

    return (input_list, data_loader_list, data_sources)


#----------------------------------------------------------


def create_outputs_from_configuration(config, algorithm, prefix, cache_root, input_list):
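    """Create the output objects for an algorithm from the given configuration.

    For an analyzer (a configuration carrying a 'result' entry), a single
    output named 'result' is created; otherwise one output per entry in
    config['outputs'].

    Returns a tuple ``(output_list, data_sinks)``.
    """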

    data_sinks = []
    output_list = OutputList()

    # This is used for parallelization purposes
    start_index, end_index = config.get('range', (None, None))

    # If the algorithm is an analyser
    if 'result' in config:
        output_config = {
            'result': config['result']
        }
    else:
        output_config = config['outputs']


    for name, details in output_config.items():

        synchronization_listener = None

        if 'result' in config:
            dataformat_name = 'analysis:' + algorithm.name
            dataformat = algorithm.result_dataformat()
        else:
            dataformat_name = algorithm.output_map[name]
            dataformat = algorithm.dataformats[dataformat_name]

            input_group = input_list.group(details['channel'])
            if input_group is not None:
                synchronization_listener = input_group.synchronization_listener
299

        path = os.path.join(cache_root, details['path'] + '.data')
        dirname = os.path.dirname(path)
        # Make sure that the directory exists while taking care of race
        # conditions. see: http://stackoverflow.com/questions/273192/check-if-a-directory-exists-and-create-it-if-necessary
        try:
            if len(dirname) > 0:
                os.makedirs(dirname)
        except OSError as exception:
            if exception.errno != errno.EEXIST:
                raise

        if start_index is None:
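            # No explicit range was given: infer the full extent of the data
            # from the indices files of a synchronized input in the cache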
            for k, v in config['inputs'].items():
                if v['channel'] == config['channel']:
                    input_path = os.path.join(cache_root, v['path'] + '.data')
                    break

            (data_filenames, indices_filenames, data_checksum_filenames, indices_checksum_filenames) = \
                    getAllFilenames(input_path)

            end_indices = [ int(x.split('.')[-2]) for x in indices_filenames ]
            end_indices.sort()

            start_index = 0
            end_index = end_indices[-1]

        data_sink = CachedDataSink()
        data_sinks.append(data_sink)

        status = data_sink.setup(
            filename=path,
            dataformat=dataformat,
            start_index=start_index,
            end_index=end_index,
            encoding='binary'
        )

        if not status:
            raise IOError("Cannot create cache sink '%s'" % details['path'])

        output_list.add(Output(name, data_sink,
                               synchronization_listener=synchronization_listener,
                               force_start_index=start_index)
        )

        if 'result' not in config:
            logger.debug("Output '%s' created: group='%s', dataformat='%s', filename='%s'" % \
                            (name, details['channel'], dataformat_name, path))
        else:
            logger.debug("Output '%s' created: dataformat='%s', filename='%s'" % \
                            (name, dataformat_name, path))

    return (output_list, data_sinks)
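

#----------------------------------------------------------


# Minimal usage sketch (assumes `config` comes from an execution request and
# that `algorithm`, `prefix`, `cache_root` and `databases` were resolved by
# the caller; the variable names are illustrative, not part of this API):
#
#     input_list, data_loaders, data_sources = create_inputs_from_configuration(
#         config, algorithm, prefix, cache_root,
#         cache_access=AccessMode.LOCAL, db_access=AccessMode.LOCAL,
#         databases=databases)
#
#     output_list, data_sinks = create_outputs_from_configuration(
#         config, algorithm, prefix, cache_root, input_list)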