Commit 6cbe1625 authored by Samuel GAIST's avatar Samuel GAIST

[execution][algorithm] Update code to handle new loop user types

parent f87dffbb
Pipeline #31774 passed with stage
in 10 minutes and 24 seconds
......@@ -124,8 +124,6 @@ class AlgorithmExecutor(object):
self.prefix, self.data["algorithm"], dataformat_cache, library_cache
)
main_channel = self.data["channel"]
if db_socket:
db_access_mode = AccessMode.REMOTE
databases = None
......@@ -249,7 +247,7 @@ class AlgorithmExecutor(object):
inputs=self.input_list, outputs=self.output_list
)
elif self.algorithm.type == Algorithm.SEQUENTIAL:
elif self.algorithm.is_sequential:
if self.analysis:
result = self.runner.process(
inputs=self.input_list,
......@@ -261,6 +259,7 @@ class AlgorithmExecutor(object):
inputs=self.input_list,
data_loaders=self.data_loaders,
outputs=self.output_list,
loop_channel=self.loop_channel,
)
if not result:
......
......@@ -59,7 +59,6 @@ from ..data import CachedDataSink
from ..data import CachedDataSource
from ..helpers import convert_experiment_configuration_to_container
from ..helpers import AccessMode
from . import prefix
......@@ -67,48 +66,28 @@ from . import prefix
logger = logging.getLogger(__name__)
#----------------------------------------------------------
# ----------------------------------------------------------
CONFIGURATION = {
'algorithm': '',
'channel': 'main',
'parameters': {
"algorithm": "",
"channel": "main",
"parameters": {},
"inputs": {"in": {"path": "INPUT", "channel": "main"}},
"outputs": {"out": {"path": "OUTPUT", "channel": "main"}},
"loop": {
"algorithm": "",
"channel": "main",
"parameters": {"threshold": 1},
"inputs": {"in": {"path": "INPUT", "channel": "main"}},
},
'inputs': {
'in': {
'path': 'INPUT',
'channel': 'main',
}
},
'outputs': {
'out': {
'path': 'OUTPUT',
'channel': 'main'
}
},
'loop': {
'algorithm': '',
'channel': 'main',
'parameters': {
'threshold': 1
},
'inputs': {
'in': {
'path': 'INPUT',
'channel': 'main'
}
}
}
}
#----------------------------------------------------------
# ----------------------------------------------------------
class TestExecution(unittest.TestCase):
def setUp(self):
self.cache_root = tempfile.mkdtemp(prefix=__name__)
self.working_dir = tempfile.mkdtemp(prefix=__name__)
......@@ -120,7 +99,6 @@ class TestExecution(unittest.TestCase):
self.loop_socket = None
self.zmq_context = None
def tearDown(self):
shutil.rmtree(self.cache_root)
shutil.rmtree(self.working_dir)
......@@ -135,7 +113,6 @@ class TestExecution(unittest.TestCase):
handler.destroy()
handler = None
for socket in [self.executor_socket, self.loop_socket]:
if socket is not None:
socket.setsockopt(zmq.LINGER, 0)
......@@ -145,15 +122,18 @@ class TestExecution(unittest.TestCase):
self.zmq_context.destroy()
self.zmq_context = None
def writeData(self, input_name, indices, start_value):
filename = os.path.join(self.cache_root, CONFIGURATION['inputs'][input_name]['path'] + '.data')
filename = os.path.join(
self.cache_root, CONFIGURATION["inputs"][input_name]["path"] + ".data"
)
dataformat = DataFormat(prefix, 'user/single_integer/1')
dataformat = DataFormat(prefix, "user/single_integer/1")
self.assertTrue(dataformat.valid)
data_sink = CachedDataSink()
self.assertTrue(data_sink.setup(filename, dataformat, indices[0][0], indices[-1][1]))
self.assertTrue(
data_sink.setup(filename, dataformat, indices[0][0], indices[-1][1])
)
for i in indices:
data = dataformat.type()
......@@ -167,40 +147,39 @@ class TestExecution(unittest.TestCase):
data_sink.close()
del data_sink
def process(self, algorithm_name, loop_algorithm_name):
self.writeData('in', [(0, 0), (1, 1), (2, 2), (3, 3)], 1000)
self.writeData("in", [(0, 0), (1, 1), (2, 2), (3, 3)], 1000)
# -------------------------------------------------------------------------
config = deepcopy(CONFIGURATION)
config['algorithm'] = algorithm_name
config['loop']['algorithm'] = loop_algorithm_name
config["algorithm"] = algorithm_name
config["loop"]["algorithm"] = loop_algorithm_name
config = convert_experiment_configuration_to_container(config)
with open(os.path.join(self.working_dir, 'configuration.json'), 'wb') as f:
data = simplejson.dumps(config, indent=4).encode('utf-8')
with open(os.path.join(self.working_dir, "configuration.json"), "wb") as f:
data = simplejson.dumps(config, indent=4).encode("utf-8")
f.write(data)
working_prefix = os.path.join(self.working_dir, 'prefix')
working_prefix = os.path.join(self.working_dir, "prefix")
if not os.path.exists(working_prefix):
os.makedirs(working_prefix)
algorithm = Algorithm(prefix, algorithm_name)
assert(algorithm.valid)
self.assertTrue(algorithm.valid, algorithm.errors)
algorithm.export(working_prefix)
# -------------------------------------------------------------------------
loop_algorithm = Algorithm(prefix, loop_algorithm_name)
assert(loop_algorithm.valid)
self.assertTrue(loop_algorithm.valid, loop_algorithm.errors)
loop_algorithm.export(working_prefix)
# -------------------------------------------------------------------------
self.message_handler = MessageHandler('127.0.0.1')
self.message_handler = MessageHandler("127.0.0.1")
self.message_handler.start()
self.loop_message_handler = LoopMessageHandler('127.0.0.1')
self.loop_message_handler = LoopMessageHandler("127.0.0.1")
self.zmq_context = zmq.Context()
self.executor_socket = self.zmq_context.socket(zmq.PAIR)
......@@ -208,12 +187,19 @@ class TestExecution(unittest.TestCase):
self.loop_socket = self.zmq_context.socket(zmq.PAIR)
self.loop_socket.connect(self.loop_message_handler.address)
self.loop_executor = LoopExecutor(self.loop_message_handler, self.working_dir, cache_root=self.cache_root)
self.loop_executor = LoopExecutor(
self.loop_message_handler, self.working_dir, cache_root=self.cache_root
)
self.assertTrue(self.loop_executor.setup())
self.assertTrue(self.loop_executor.prepare())
self.loop_executor.process()
executor = AlgorithmExecutor(self.executor_socket, self.working_dir, cache_root=self.cache_root, loop_socket=self.loop_socket)
executor = AlgorithmExecutor(
self.executor_socket,
self.working_dir,
cache_root=self.cache_root,
loop_socket=self.loop_socket,
)
self.assertTrue(executor.setup())
self.assertTrue(executor.prepare())
......@@ -221,7 +207,14 @@ class TestExecution(unittest.TestCase):
self.assertTrue(executor.process())
cached_file = CachedDataSource()
self.assertTrue(cached_file.setup(os.path.join(self.cache_root, CONFIGURATION['outputs']['out']['path'] + '.data'), prefix))
self.assertTrue(
cached_file.setup(
os.path.join(
self.cache_root, CONFIGURATION["outputs"]["out"]["path"] + ".data"
),
prefix,
)
)
for i in range(len(cached_file)):
data, start, end = cached_file[i]
......@@ -229,13 +222,12 @@ class TestExecution(unittest.TestCase):
self.assertEqual(start, i)
self.assertEqual(end, i)
def test_autonomous_loop_user(self):
self.process("autonomous/loop_user/1", "autonomous/loop/1")
def test_autonomous_loop(self):
self.process('autonomous/loop_user/1',
'autonomous/loop/1')
def test_sequential_loop_user(self):
self.process("sequential/loop_user/1", "autonomous/loop/1")
def test_autonomous_loop_invalid_output(self):
with self.assertRaises(RemoteException):
self.process('autonomous/loop_user/1',
'autonomous/invalid_loop_output/1')
\ No newline at end of file
self.process("autonomous/loop_user/1", "autonomous/invalid_loop_output/1")
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment