helpers.py 5.59 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
#!/usr/bin/env python
# vim: set fileencoding=utf-8 :

###############################################################################
#                                                                             #
# Copyright (c) 2017 Idiap Research Institute, http://www.idiap.ch/           #
# Contact: beat.support@idiap.ch                                              #
#                                                                             #
# This file is part of the beat.backend.python module of the BEAT platform.   #
#                                                                             #
# Commercial License Usage                                                    #
# Licensees holding valid commercial BEAT licenses may use this file in       #
# accordance with the terms contained in a written agreement between you      #
# and Idiap. For further information contact tto@idiap.ch                     #
#                                                                             #
# Alternatively, this file may be used under the terms of the GNU Affero      #
# Public License version 3 as published by the Free Software and appearing    #
# in the file LICENSE.AGPL included in the packaging of this file.            #
# The BEAT platform is distributed in the hope that it will be useful, but    #
# WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY  #
# or FITNESS FOR A PARTICULAR PURPOSE.                                        #
#                                                                             #
# You should have received a copy of the GNU Affero Public License along      #
# with the BEAT platform. If not, see http://www.gnu.org/licenses/.           #
#                                                                             #
###############################################################################


29
30
31
32
33
import os

import logging
logger = logging.getLogger(__name__)

34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
from . import data
from . import inputs
from . import outputs



def convert_experiment_configuration_to_container(config, proxy_mode):
  """Builds the container configuration from an experiment configuration.

  Only the entries needed by the container are copied over.

  Parameters:

    config (dict): experiment configuration. It must contain ``algorithm``,
      ``parameters``, ``channel`` and ``inputs``, and either ``outputs`` or
      ``result`` (with a ``path`` entry). ``range`` is optional and copied
      verbatim when present.

    proxy_mode: stored as-is under the ``proxy_mode`` key.


  Returns:

    dict: the container configuration. Each input keeps its ``channel`` and
    ``path`` and gains a boolean ``database`` flag. That flag is ``True``
    when the input declares a ``database`` key, whatever its value. Outputs
    keep only ``channel`` and ``path``. When there are no outputs, a
    ``result`` entry is built on the configuration's main ``channel``.
  """
  # Named ``container`` so the module-level ``data`` import is not shadowed
  container = {
    'proxy_mode': proxy_mode,
    'algorithm': config['algorithm'],
    'parameters': config['parameters'],
    'channel': config['channel'],
  }

  if 'range' in config:
    container['range'] = config['range']

  # ``'database' in details`` works on Python 2 and 3; ``dict.has_key`` is
  # Python-2 only and raises AttributeError on Python 3
  container['inputs'] = {
    name: {
      'channel': details['channel'],
      'path': details['path'],
      'database': 'database' in details,
    }
    for name, details in config['inputs'].items()
  }

  if 'outputs' in config:
    container['outputs'] = {
      name: {'channel': details['channel'], 'path': details['path']}
      for name, details in config['outputs'].items()
    }
  else:
    container['result'] = {
      'channel': config['channel'],
      'path': config['result']['path'],
    }

  return container


#----------------------------------------------------------


class CacheAccess:
  """Integer constants selecting how ``create_inputs_from_configuration``
  obtains non-database inputs.

  NONE (0) creates no such inputs; they are skipped. LOCAL (1) reads cached
  ``.data`` files located under ``cache_root``. REMOTE (2) creates remote
  inputs that talk through a socket.
  """
  NONE, LOCAL, REMOTE = range(3)


def _create_remote_input(name, details, algorithm, socket, unpack, note=''):
  """Returns an ``inputs.RemoteInput`` for input ``name`` that reads through ``socket``.

  ``note`` is appended verbatim to the debug log message.

  Raises:

    IOError: if ``socket`` is ``None``
  """
  if socket is None:
    raise IOError("No socket provided for remote inputs")

  dataformat_name = algorithm.input_map[name]
  remote_input = inputs.RemoteInput(name, algorithm.dataformats[dataformat_name],
                                    socket, unpack=unpack)

  logger.debug("RemoteInput '%s' created: group='%s', dataformat='%s'%s",
               name, details['channel'], dataformat_name, note)

  return remote_input


def create_inputs_from_configuration(config, algorithm, prefix, cache_root,
                                     cache_access=CacheAccess.NONE, unpack=True,
                                     socket=None):
  """Creates the algorithm inputs described by an experiment configuration.

  Each entry of ``config['inputs']`` becomes one of the following:

  * an ``inputs.RemoteInput``, when the entry has a truthy ``database`` key
    (regardless of ``cache_access``) or when ``cache_access`` is
    ``CacheAccess.REMOTE``;
  * an ``inputs.Input`` backed by a ``data.CachedDataSource`` reading
    ``<cache_root>/<path>.data``, when ``cache_access`` is
    ``CacheAccess.LOCAL``;
  * nothing at all (the entry is skipped) in any other case.

  Inputs are collected into ``inputs.InputGroup`` objects, one per channel.
  The group for ``config['channel']`` is created with
  ``restricted_access=True``.

  Parameters:

    config (dict): configuration. Reads ``inputs`` and ``channel``, plus the
      optional ``range`` pair of start/end indices used for parallelization.

    algorithm: object providing ``input_map`` (input name -> dataformat name)
      and ``dataformats`` (dataformat name -> dataformat object).

    prefix (str): forwarded to ``CachedDataSource.setup()``.

    cache_root (str): directory that the input ``path`` entries are
      relative to.

    cache_access (int): one of the ``CacheAccess`` constants.

    unpack (bool): forwarded to the created inputs and data sources.

    socket: required whenever a remote input has to be created.


  Returns:

    tuple: ``(input_list, data_sources)``, i.e. the ``inputs.InputList`` and
    the list of local ``CachedDataSource`` objects that were set up.


  Raises:

    IOError: if a remote input is needed but ``socket`` is ``None``, or if a
      cache file cannot be loaded.
  """
  data_sources = []
  input_list = inputs.InputList()

  # This is used for parallelization purposes
  start_index, end_index = config.get('range', (None, None))

  for name, details in config['inputs'].items():

    if details.get('database', False):
      # Database inputs are always remote, whatever ``cache_access`` says
      input_ = _create_remote_input(name, details, algorithm, socket, unpack,
                                    note=', connected to a database')

    elif cache_access == CacheAccess.LOCAL:
      data_source = data.CachedDataSource()
      data_sources.append(data_source)

      filename = os.path.join(cache_root, details['path'] + '.data')

      setup_args = dict(filename=filename, prefix=prefix, unpack=unpack)
      if details['channel'] == config['channel']: # synchronized
        # Only inputs on the configuration's own channel are restricted to
        # the requested index range
        setup_args.update(force_start_index=start_index,
                          force_end_index=end_index)

      if not data_source.setup(**setup_args):
        raise IOError("cannot load cache file `%s'" % details['path'])

      input_ = inputs.Input(name, algorithm.input_map[name], data_source)

      logger.debug("Input '%s' created: group='%s', dataformat='%s', filename='%s'",
                   name, details['channel'], algorithm.input_map[name], filename)

    elif cache_access == CacheAccess.REMOTE:
      input_ = _create_remote_input(name, details, algorithm, socket, unpack)

    else:
      # No cache access: non-database inputs cannot be created
      continue

    # Synchronization bits: one InputGroup per channel, created on first use
    group = input_list.group(details['channel'])
    if group is None:
      group = inputs.InputGroup(
                details['channel'],
                synchronization_listener=outputs.SynchronizationListener(),
                restricted_access=(details['channel'] == config['channel'])
              )
      input_list.add(group)
      logger.debug("Group '%s' created", details['channel'])

    group.add(input_)

  return (input_list, data_sources)