Commit f00077a6 authored by Samuel GAIST's avatar Samuel GAIST
Browse files

[experiment] Implement support for loop blocks handling

parent c8d4eda4
......@@ -185,6 +185,7 @@ class Experiment(object):
self.storage = None
self.datasets = {}
self.blocks = {}
self.loops = {}
self.analyzers = {}
self.databases = {}
......@@ -235,6 +236,7 @@ class Experiment(object):
# checks all internal aspects of the experiment
self._check_datasets(database_cache, dataformat_cache)
self._check_blocks(algorithm_cache, dataformat_cache, library_cache)
self._check_loops(algorithm_cache, dataformat_cache, library_cache)
self._check_analyzers(algorithm_cache, dataformat_cache, library_cache)
self._check_global_parameters()
self._load_toolchain(toolchain_data)
......@@ -324,7 +326,7 @@ class Experiment(object):
self.algorithms[algoname] = thisalgo
if thisalgo.errors:
self.errors.append("/blocks/%s: algorithm `%s' is invalid: %s" % \
(blockname, algoname, str(thisalgo.errors)))
(blockname, algoname, "\n".join(thisalgo.errors)))
else:
thisalgo = self.algorithms[algoname]
if thisalgo.errors: continue # already done
......@@ -362,6 +364,63 @@ class Experiment(object):
self.blocks[blockname] = block
def _check_loops(self, algorithm_cache, dataformat_cache, library_cache):
"""checks all loops are valid"""
if 'loops' not in self.data:
return
for loopname, loop in self.data['loops'].items():
algoname = loop['algorithm']
if algoname not in self.algorithms:
# loads the algorithm
if algoname in algorithm_cache:
thisalgo = algorithm_cache[algoname]
else:
thisalgo = algorithm.Algorithm(self.prefix,
algoname,
dataformat_cache,
library_cache)
algorithm_cache[algoname] = thisalgo
self.algorithms[algoname] = thisalgo
if thisalgo.errors:
self.errors.append("/loops/%s: algorithm `%s' is invalid:\n%s" % \
(loopname, algoname, "\n".join(thisalgo.errors)))
continue
else:
thisalgo = self.algorithms[algoname]
if thisalgo.errors: continue # already done
# checks all inputs correspond
for algoin, loop_input in loop['inputs'].items():
if algoin not in thisalgo.input_map:
self.errors.append("/analyzers/%s/inputs/%s: algorithm `%s' does " \
"not have an input named `%s' - valid algorithm inputs " \
"are %s" % (loopname, loop_input, algoname, algoin,
", ".join(thisalgo.input_map.keys())))
# checks if parallelization makes sense
if loop.get('nb_slots', 1) > 1 and not thisalgo.splittable:
self.errors.append("/loop/%s/nb_slots: you have set the number " \
"of slots for algorithm `%s' to %d, but it is not " \
"splittable" % (analyzername, thisalgo.name,
loop['nb_slots']))
# check parameter consistence
for parameter, value in loop.get('parameters', {}).items():
try:
thisalgo.clean_parameter(parameter, value)
except Exception as e:
self.errors.append("/loop/%s/parameters/%s: cannot convert " \
"value `%s' to required type: %s" % \
(loopname, parameter, value, e))
self.loops[loopname] = loop
def _check_analyzers(self, algorithm_cache, dataformat_cache, library_cache):
"""checks all analyzers are valid"""
......@@ -557,13 +616,16 @@ class Experiment(object):
)['outputs'][from_endpt[1]]
from_name = "dataset"
else: # it is a block
elif from_endpt[0] in self.blocks: # it is a block
block = self.blocks[from_endpt[0]]
mapping = block['outputs']
imapping = dict(zip(mapping.values(), mapping.keys()))
algout = imapping[from_endpt[1]] #name of output on algorithm
from_dtype = self.algorithms[block['algorithm']].output_map[algout]
from_name = "block"
else:
self.errors.append("Unknown endpoint %s" % to_endpt[0])
continue
to_endpt = connection['to'].split('.', 1)
......@@ -575,13 +637,24 @@ class Experiment(object):
to_dtype = self.algorithms[block['algorithm']].input_map[algoin]
to_name = "block"
else: # it is an analyzer
elif to_endpt[0] in self.loops:
loop = self.loops[to_endpt[0]]
mapping = loop['inputs']
imapping = dict(zip(mapping.values(), mapping.keys()))
algoin = imapping[to_endpt[1]] #name of input on algorithm
to_dtype = self.algorithms[loop['algorithm']].input_map[algoin]
to_name = "loop"
elif to_endpt[0] in self.analyzers: # it is an analyzer
analyzer = self.analyzers[to_endpt[0]]
mapping = analyzer['inputs']
imapping = dict(zip(mapping.values(), mapping.keys()))
algoin = imapping[to_endpt[1]] #name of input on algorithm
to_dtype = self.algorithms[analyzer['algorithm']].input_map[algoin]
to_name = "analyzer"
else:
self.errors.append("Unknown endpoint %s" % to_endpt[0])
continue
if from_dtype == to_dtype: continue #OK
......@@ -714,7 +787,12 @@ class Experiment(object):
retval = dict()
config_data = self.blocks.get(name, self.analyzers.get(name))
# config_data = self.blocks.get(name, self.analyzers.get(name))
for item in [self.blocks, self.loops, self.analyzers]:
if name in item:
config_data = item[name]
break
if config_data is None:
raise KeyError("did not find `%s' among blocks or analyzers" % name)
......@@ -806,10 +884,10 @@ class Experiment(object):
KeyError: if the block name does not exist in this experiment.
"""
if name in self.blocks:
config_data = self.blocks[name]
else:
config_data = self.analyzers[name]
for item in [self.blocks, self.loops, self.analyzers]:
if name in item:
config_data = item[name]
break
# resolve parameters taking globals in consideration
parameters = self.data['globals'].get(config_data['algorithm'])
......@@ -826,10 +904,10 @@ class Experiment(object):
self.data['globals']['environment'])
nb_slots = config_data.get('nb_slots', 1)
toolchain_data = self.toolchain.blocks.get(name,
self.toolchain.analyzers.get(name))
toolchain_data = self.toolchain.algorithm_item(name)
if toolchain_data is None:
raise KeyError("did not find `%s' among blocks or analyzers" % name)
raise KeyError("did not find `%s' among blocks, loops or analyzers" % name)
retval = dict(
inputs=self._inputs(name),
......@@ -841,6 +919,25 @@ class Experiment(object):
nb_slots=nb_slots,
)
loop = self.toolchain.get_loop_for_block(name)
if loop is not None:
loop_name = loop['name']
loop_toolchain_data = self.toolchain.algorithm_item(loop_name)
loop_config_data = self.data['loops'][loop_name]
loop_algorithm = loop_config_data['algorithm']
parameters = self.data['globals'].get(loop_algorithm, dict())
parameters.update(loop_config_data.get('parameters', dict()))
loop_data = dict(
inputs=self._inputs(loop_name),
channel=loop_toolchain_data['synchronized_channel'],
algorithm=loop_algorithm,
parameters=parameters
)
retval['loop'] = loop_data
if name in self.blocks:
retval['outputs'] = self._block_outputs(name)
else:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment