Commit df468c7b authored by Philip ABBET's avatar Philip ABBET

Add more tests for the helper functions

parent 1819d452
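
For orientation, the helper exercised by the new tests, create_inputs_from_configuration(), returns a two-element tuple (input_list, data_loader_list), as the updated callers and the test file below show. The following is a condensed sketch of the call shape, not code from this commit; the absolute import paths are assumed and the keyword arguments mirror the ones the tests pass explicitly.

from beat.backend.python.algorithm import Algorithm
from beat.backend.python.helpers import AccessMode
from beat.backend.python.helpers import create_inputs_from_configuration


def build_inputs(configuration, prefix, cache_root, databases):
    # Mirrors the keyword arguments used by the tests further down
    algorithm = Algorithm(prefix, configuration['algorithm'])
    return create_inputs_from_configuration(
        configuration, algorithm, prefix, cache_root,
        cache_access=AccessMode.NONE,    # no cached inputs in these tests
        db_access=AccessMode.LOCAL,      # AccessMode.REMOTE additionally needs a ZMQ socket
        unpack=True,
        databases=databases,             # e.g. {'integers_db/1': Database(prefix, 'integers_db/1')}
        socket=None,
        no_synchronisation_listeners=False,
    )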
@@ -161,7 +161,7 @@ class DBExecutor(object):
start_index, end_index = (None, None)
view.setup(os.path.join(cache_root, details['path']),
start_index=start_index, end_index=end_index)
start_index=start_index, end_index=end_index)
self.views[key] = view
@@ -101,7 +101,7 @@ class Executor(object):
if self.algorithm.type == Algorithm.LEGACY:
# Loads algorithm inputs
(self.input_list, self.data_loaders, _) = create_inputs_from_configuration(
(self.input_list, self.data_loaders) = create_inputs_from_configuration(
self.data, self.algorithm, self.prefix, cache_root,
cache_access=AccessMode.LOCAL, db_access=AccessMode.REMOTE,
socket=self.socket
@@ -113,7 +113,7 @@ class Executor(object):
)
else:
(self.input_list, self.data_loaders, _) = create_inputs_from_configuration(
(self.input_list, self.data_loaders) = create_inputs_from_configuration(
self.data, self.algorithm, self.prefix, cache_root,
cache_access=AccessMode.LOCAL, db_access=AccessMode.REMOTE
)
@@ -33,6 +33,7 @@ import logging
logger = logging.getLogger(__name__)
from .data import CachedDataSource
from .data import RemoteDataSource
from .data import CachedDataSink
from .data import getAllFilenames
from .data_loaders import DataLoaderList
@@ -99,7 +100,6 @@ def create_inputs_from_configuration(config, algorithm, prefix, cache_root,
def _create_local_input(details):
data_source = CachedDataSource()
data_sources.append(data_source)
filename = os.path.join(cache_root, details['path'] + '.data')
@@ -183,8 +183,9 @@ def create_inputs_from_configuration(config, algorithm, prefix, cache_root,
if not views.has_key(channel):
view = db.view(details['protocol'], details['set'])
print details
view.setup()
view.setup(os.path.join(cache_root, details['path']),
start_index=start_index, end_index=end_index)
views[channel] = view
logger.debug("Database view '%s/%s/%s' created: group='%s'" % \
@@ -193,19 +194,28 @@ def create_inputs_from_configuration(config, algorithm, prefix, cache_root,
else:
view = views[channel]
data_loader = _get_data_loader_for(details)
data_loader.add(name, view.data_sources[details['output']])
data_source = view.data_sources[details['output']]
if (algorithm.type == Algorithm.LEGACY) or \
((algorithm.type == Algorithm.SEQUENTIAL) and (details['channel'] == config['channel'])):
input = Input(name, algorithm.input_map[name], data_source)
logger.debug("Input '%s' created: group='%s', dataformat='%s', database-output='%s/%s/%s:%s'" % \
(name, details['channel'], algorithm.input_map[name],
details['database'], details['protocol'], details['set'],
details['output']))
else:
data_loader = _get_data_loader_for(details)
data_loader.add(name, data_source)
logger.debug("DatabaseOutputDataSource '%s' created: group='%s', dataformat='%s', database-output='%s/%s/%s:%s'" % \
(name, channel, algorithm.input_map[name], details['database'],
details['protocol'], details['set'], details['output']))
logger.debug("DatabaseOutputDataSource '%s' created: group='%s', dataformat='%s', database-output='%s/%s/%s:%s'" % \
(name, channel, algorithm.input_map[name], details['database'],
details['protocol'], details['set'], details['output']))
elif db_access == AccessMode.REMOTE:
if socket is None:
raise IOError("No socket provided for remote data sources")
data_loader = _get_data_loader_for(details)
data_source = RemoteDataSource()
result = data_source.setup(
socket=socket,
@@ -218,10 +228,21 @@ def create_inputs_from_configuration(config, algorithm, prefix, cache_root,
if not result:
raise IOError("cannot setup remote data source '%s'" % name)
data_loader.add(name, data_source)
logger.debug("RemoteDataSource '%s' created: group='%s', dataformat='%s', connected to a database" % \
(name, details['channel'], algorithm.input_map[name]))
if (algorithm.type == Algorithm.LEGACY) or \
((algorithm.type == Algorithm.SEQUENTIAL) and (details['channel'] == config['channel'])):
input = Input(name, algorithm.input_map[name], data_source)
logger.debug("Input '%s' created: group='%s', dataformat='%s', database-output='%s/%s/%s:%s'" % \
(name, details['channel'], algorithm.input_map[name],
details['database'], details['protocol'], details['set'],
details['output']))
else:
data_loader = _get_data_loader_for(details)
data_loader.add(name, data_source)
logger.debug("RemoteDataSource '%s' created: group='%s', dataformat='%s', connected to a database" % \
(name, details['channel'], algorithm.input_map[name]))
elif cache_access == AccessMode.LOCAL:
@@ -259,7 +280,7 @@ def create_inputs_from_configuration(config, algorithm, prefix, cache_root,
group.add(input)
return (input_list, data_loader_list, data_sources)
return (input_list, data_loader_list)
#----------------------------------------------------------
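
The branching added above decides, for each database-backed input, whether to build an Input (legacy algorithms, and sequential algorithms whose input sits on the synchronised channel) or to register the data source on a data loader (autonomous algorithms and unsynchronised channels). A condensed, purely illustrative restatement of that rule follows; module paths are assumed and this is not the helper's actual code.

from beat.backend.python.algorithm import Algorithm
from beat.backend.python.inputs import Input


def wire_database_input(name, details, config, algorithm, data_source, group, data_loader):
    # Same condition as used twice above, for local and for remote database access
    synchronised = (details['channel'] == config['channel'])
    if (algorithm.type == Algorithm.LEGACY) or \
       ((algorithm.type == Algorithm.SEQUENTIAL) and synchronised):
        # The algorithm reads the data itself through an Input in its group
        group.add(Input(name, algorithm.input_map[name], data_source))
    else:
        # The data is fetched on behalf of the algorithm by a data loader
        data_loader.add(name, data_source)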
@@ -43,8 +43,6 @@ from ..dataformat import DataFormat
from ..data import CachedDataSink
from ..data import CachedDataSource
from ..helpers import convert_experiment_configuration_to_container
from ..helpers import create_inputs_from_configuration
from ..helpers import create_outputs_from_configuration
from ..helpers import AccessMode
from . import prefix
#!/usr/bin/env python
# vim: set fileencoding=utf-8 :
###############################################################################
# #
# Copyright (c) 2017 Idiap Research Institute, http://www.idiap.ch/ #
# Contact: beat.support@idiap.ch #
# #
# This file is part of the beat.backend.python module of the BEAT platform. #
# #
# Commercial License Usage #
# Licensees holding valid commercial BEAT licenses may use this file in #
# accordance with the terms contained in a written agreement between you #
# and Idiap. For further information contact tto@idiap.ch #
# #
# Alternatively, this file may be used under the terms of the GNU Affero #
# Public License version 3 as published by the Free Software and appearing #
# in the file LICENSE.AGPL included in the packaging of this file. #
# The BEAT platform is distributed in the hope that it will be useful, but #
# WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY #
# or FITNESS FOR A PARTICULAR PURPOSE. #
# #
# You should have received a copy of the GNU Affero Public License along #
# with the BEAT platform. If not, see http://www.gnu.org/licenses/. #
# #
###############################################################################
# Tests for the helper functions
import os
import logging
logger = logging.getLogger(__name__)
import unittest
import zmq
import tempfile
import shutil
from ..dbexecution import DBExecutor
from ..database import Database
from ..data_loaders import DataLoader
from ..data import RemoteDataSource
from ..hash import hashDataset
from ..hash import toPath
from ..algorithm import Algorithm
from ..helpers import create_inputs_from_configuration
from ..helpers import AccessMode
from ..message_handler import MessageHandler
from . import prefix
#----------------------------------------------------------
DB_VIEW_HASH = hashDataset('integers_db/1', 'double', 'double')
DB_INDEX_PATH = toPath(DB_VIEW_HASH, suffix='.db')
CONFIGURATION_DB_LEGACY = {
    'queue': 'queue',
    'algorithm': 'legacy/echo/1',
    'nb_slots': 1,
    'channel': 'integers',
    'parameters': {
    },
    'environment': {
        'name': 'Python 2.7',
        'version': '1.2.0'
    },
    'inputs': {
        'in': {
            'database': 'integers_db/1',
            'protocol': 'double',
            'set': 'double',
            'output': 'a',
            'endpoint': 'in',
            'channel': 'integers',
            'path': DB_INDEX_PATH,
            'hash': DB_VIEW_HASH,
        },
    },
    'outputs': {
        'out': {
            'endpoint': 'out',
            'channel': 'integers',
            'path': '20/61/b6/2df3c3bedd5366f4a625c5d87ffbf5a26007c46c456e9abf21b46c6681',
            'hash': '2061b62df3c3bedd5366f4a625c5d87ffbf5a26007c46c456e9abf21b46c6681',
        }
    }
}
CONFIGURATION_DB_SEQUENTIAL = {
    'queue': 'queue',
    'algorithm': 'sequential/echo/1',
    'nb_slots': 1,
    'channel': 'integers',
    'parameters': {
    },
    'environment': {
        'name': 'Python 2.7',
        'version': '1.2.0'
    },
    'inputs': {
        'in': {
            'database': 'integers_db/1',
            'protocol': 'double',
            'set': 'double',
            'output': 'a',
            'endpoint': 'in',
            'channel': 'integers',
            'path': DB_INDEX_PATH,
            'hash': DB_VIEW_HASH,
        },
    },
    'outputs': {
        'out': {
            'endpoint': 'out',
            'channel': 'integers',
            'path': '20/61/b6/2df3c3bedd5366f4a625c5d87ffbf5a26007c46c456e9abf21b46c6681',
            'hash': '2061b62df3c3bedd5366f4a625c5d87ffbf5a26007c46c456e9abf21b46c6681',
        }
    }
}
CONFIGURATION_DB_AUTONOMOUS = {
    'queue': 'queue',
    'algorithm': 'autonomous/echo/1',
    'nb_slots': 1,
    'channel': 'integers',
    'parameters': {
    },
    'environment': {
        'name': 'Python 2.7',
        'version': '1.2.0'
    },
    'inputs': {
        'in': {
            'database': 'integers_db/1',
            'protocol': 'double',
            'set': 'double',
            'output': 'a',
            'endpoint': 'in',
            'channel': 'integers',
            'path': DB_INDEX_PATH,
            'hash': DB_VIEW_HASH,
        },
    },
    'outputs': {
        'out': {
            'endpoint': 'out',
            'channel': 'integers',
            'path': '20/61/b6/2df3c3bedd5366f4a625c5d87ffbf5a26007c46c456e9abf21b46c6681',
            'hash': '2061b62df3c3bedd5366f4a625c5d87ffbf5a26007c46c456e9abf21b46c6681',
        }
    }
}
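
The three configurations above are identical except for their 'algorithm' entry ('legacy/echo/1', 'sequential/echo/1', 'autonomous/echo/1'). As an aside, and not how the test file is actually written, they could be derived from a single shared template:

import copy


def make_configuration(algorithm_name, template=CONFIGURATION_DB_LEGACY):
    # Copy the shared template and swap in the algorithm under test,
    # e.g. make_configuration('autonomous/echo/1')
    configuration = copy.deepcopy(template)
    configuration['algorithm'] = algorithm_name
    return configuration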
#----------------------------------------------------------
class TestCreateInputsFromConfiguration_RemoteDatabase(unittest.TestCase):

    def setUp(self, remote=True):
        self.remote = remote
        self.cache_root = tempfile.mkdtemp(prefix=__name__)

        database = Database(prefix, 'integers_db/1')
        view = database.view('double', 'double')
        view.index(os.path.join(self.cache_root, DB_INDEX_PATH))

        self.databases = {}
        self.databases['integers_db/1'] = database

        if remote:
            view.setup(os.path.join(self.cache_root, DB_INDEX_PATH))

            data_sources = {
                'in': view.data_sources['a'],
            }

            self.message_handler = MessageHandler('127.0.0.1', data_sources=data_sources)
            self.message_handler.start()

            self.zmq_context = zmq.Context()
            self.socket = self.zmq_context.socket(zmq.PAIR)
            self.socket.connect(self.message_handler.address)
        else:
            self.message_handler = None
            self.socket = None
    def tearDown(self):
        if self.message_handler is not None:
            self.message_handler.kill()
            self.message_handler.join()
            self.message_handler.destroy()
            self.message_handler = None

        if self.socket is not None:
            self.socket.setsockopt(zmq.LINGER, 0)
            self.socket.close()
            self.zmq_context.destroy()
            self.socket = None
            self.zmq_context = None

        shutil.rmtree(self.cache_root)
    def test_legacy_algorithm(self):
        algorithm = Algorithm(prefix, CONFIGURATION_DB_LEGACY['algorithm'])
        runner = algorithm.runner()

        if self.remote:
            db_access = AccessMode.REMOTE
        else:
            db_access = AccessMode.LOCAL

        (input_list, data_loader_list) = \
            create_inputs_from_configuration(CONFIGURATION_DB_LEGACY, algorithm,
                                             prefix, self.cache_root,
                                             cache_access=AccessMode.NONE,
                                             db_access=db_access,
                                             unpack=True,
                                             databases=self.databases,
                                             socket=self.socket,
                                             no_synchronisation_listeners=False)

        self.assertEqual(len(input_list), 1)
        self.assertEqual(len(data_loader_list), 0)
    def test_sequential_algorithm(self):
        algorithm = Algorithm(prefix, CONFIGURATION_DB_SEQUENTIAL['algorithm'])
        runner = algorithm.runner()

        if self.remote:
            db_access = AccessMode.REMOTE
        else:
            db_access = AccessMode.LOCAL

        (input_list, data_loader_list) = \
            create_inputs_from_configuration(CONFIGURATION_DB_SEQUENTIAL, algorithm,
                                             prefix, self.cache_root,
                                             cache_access=AccessMode.NONE,
                                             db_access=db_access,
                                             unpack=True,
                                             databases=self.databases,
                                             socket=self.socket,
                                             no_synchronisation_listeners=False)

        self.assertEqual(len(input_list), 1)
        self.assertEqual(len(data_loader_list), 0)
    def test_autonomous_algorithm(self):
        algorithm = Algorithm(prefix, CONFIGURATION_DB_AUTONOMOUS['algorithm'])
        runner = algorithm.runner()

        if self.remote:
            db_access = AccessMode.REMOTE
        else:
            db_access = AccessMode.LOCAL

        (input_list, data_loader_list) = \
            create_inputs_from_configuration(CONFIGURATION_DB_AUTONOMOUS, algorithm,
                                             prefix, self.cache_root,
                                             cache_access=AccessMode.NONE,
                                             db_access=db_access,
                                             unpack=True,
                                             databases=self.databases,
                                             socket=self.socket,
                                             no_synchronisation_listeners=False)

        self.assertEqual(len(input_list), 0)
        self.assertEqual(len(data_loader_list), 1)
#----------------------------------------------------------
class TestCreateInputsFromConfiguration_LocalDatabase(TestCreateInputsFromConfiguration_RemoteDatabase):

    def setUp(self):
        super(TestCreateInputsFromConfiguration_LocalDatabase, self).setUp(remote=False)
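
Both test classes are standard unittest.TestCase suites, so they can be driven by the usual unittest machinery. A minimal runner sketch, assuming it is appended to this test module (the repository's own test runner may differ):

import unittest

if __name__ == '__main__':
    loader = unittest.TestLoader()
    suite = unittest.TestSuite([
        loader.loadTestsFromTestCase(TestCreateInputsFromConfiguration_RemoteDatabase),
        loader.loadTestsFromTestCase(TestCreateInputsFromConfiguration_LocalDatabase),
    ])
    unittest.TextTestRunner(verbosity=2).run(suite)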