......@@ -31,9 +31,9 @@
.. OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
===========================================================
Authors of the Biometrics Evaluation and Testing Platform
===========================================================
==============================
Authors of the BEAT Platform
==============================
Andre Anjos <andre.anjos@idiap.ch>
Flavio Tarsetti <flavio.tarsetti@idiap.ch>
......
......@@ -182,6 +182,8 @@ class Algorithm(BackendAlgorithm):
"""
dataformat_klass = dataformat.DataFormat
def __init__(self, prefix, data, dataformat_cache=None, library_cache=None):
super(Algorithm, self).__init__(prefix, data, dataformat_cache, library_cache)
......@@ -303,89 +305,52 @@ class Algorithm(BackendAlgorithm):
"declaration: %s" % (self.name, ", ".join(all_output_names))
)
def _validate_format(self, type_name, group_name, entry_name, dataformat):
if dataformat.errors:
self.errors.append(
"found error validating data format `%s' "
"for %s `%s' on algorithm `%s': %s"
% (
type_name,
group_name,
entry_name,
self.name,
"\n".join(dataformat.errors),
)
)
def _validate_dataformats(self, group, group_name, dataformat_cache):
for name, entry in group[group_name].items():
type_name = entry["type"]
thisformat = self._update_dataformat_cache(type_name, dataformat_cache)
self._validate_format(type_name, group_name, name, thisformat)
def _validate_required_dataformats(self, dataformat_cache):
"""Makes sure we can load all requested formats
"""
for group in self.groups:
for name, input in group["inputs"].items():
if input["type"] in self.dataformats:
continue
if dataformat_cache and input["type"] in dataformat_cache: # reuse
thisformat = dataformat_cache[input["type"]]
else: # load it
thisformat = dataformat.DataFormat(self.prefix, input["type"])
if dataformat_cache is not None: # update it
dataformat_cache[input["type"]] = thisformat
self.dataformats[input["type"]] = thisformat
if thisformat.errors:
self.errors.append(
"found error validating data format `%s' "
"for input `%s' on algorithm `%s': %s"
% (input["type"], name, self.name, "\n".join(thisformat.errors))
)
if "outputs" not in group:
continue
for name, output in group["outputs"].items():
if output["type"] in self.dataformats:
continue
for name, input_ in group["inputs"].items():
self._validate_dataformats(group, "inputs", dataformat_cache)
if dataformat_cache and output["type"] in dataformat_cache: # reuse
thisformat = dataformat_cache[output["type"]]
else: # load it
thisformat = dataformat.DataFormat(self.prefix, output["type"])
if dataformat_cache is not None: # update it
dataformat_cache[output["type"]] = thisformat
if "outputs" in group:
self._validate_dataformats(group, "outputs", dataformat_cache)
self.dataformats[output["type"]] = thisformat
if thisformat.errors:
self.errors.append(
"found error validating data format `%s' "
"for output `%s' on algorithm `%s': %s"
% (
output["type"],
name,
self.name,
"\n".join(thisformat.errors),
)
)
if "loop" in group:
self._validate_dataformats(group, "loop", dataformat_cache)
if self.results:
for name, result in self.results.items():
if result["type"].find("/") != -1:
if result["type"] in self.dataformats:
continue
if dataformat_cache and result["type"] in dataformat_cache: # reuse
thisformat = dataformat_cache[result["type"]]
else:
thisformat = dataformat.DataFormat(self.prefix, result["type"])
if dataformat_cache is not None: # update it
dataformat_cache[result["type"]] = thisformat
self.dataformats[result["type"]] = thisformat
if thisformat.errors:
self.errors.append(
"found error validating data format `%s' "
"for result `%s' on algorithm `%s': %s"
% (
result["type"],
name,
self.name,
"\n".join(thisformat.errors),
)
)
result_type = result["type"]
# results can only contain base types and plots; therefore, only process
# plot entries (those whose type references a dataformat, i.e. contains a '/')
if result_type.find("/") != -1:
thisformat = self._update_dataformat_cache(
result_type, dataformat_cache
)
self._validate_format(result_type, "result", name, thisformat)
def _convert_parameter_types(self):
"""Converts types to numpy equivalents, checks defaults, ranges and
......
......@@ -52,6 +52,8 @@ import six
from . import schema
from .dataformat import DataFormat
from .protocoltemplate import ProtocolTemplate
from . import prototypes
from beat.backend.python.database import Storage
......@@ -61,17 +63,27 @@ from beat.backend.python.protocoltemplate import Storage as PTStorage
def get_first_procotol_template(prefix):
pt_root_folder = os.path.join(prefix, PTStorage.asset_folder)
pts_available = os.listdir(pt_root_folder)
pts_available = sorted(os.listdir(pt_root_folder))
if not pts_available:
raise RuntimeError("Invalid prefix content, no protocol template available")
procotol_template_folder = pts_available[0]
protocol_template_versions = sorted(
os.listdir(os.path.join(pt_root_folder, procotol_template_folder))
)
version = protocol_template_versions[-1].split(".")[0]
return "{}/{}".format(procotol_template_folder, version)
selected_protocol_template = None
for procotol_template_folder in pts_available:
protocol_template_versions = sorted(
os.listdir(os.path.join(pt_root_folder, procotol_template_folder))
)
version = protocol_template_versions[-1].split(".")[0]
protocol_template_name = "{}/{}".format(procotol_template_folder, version)
protocol_template = ProtocolTemplate(prefix, protocol_template_name)
if protocol_template.valid:
selected_protocol_template = protocol_template_name
break
if selected_protocol_template is None:
raise RuntimeError("No valid protocol template found")
return selected_protocol_template
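A hypothetical usage sketch of the revised selection logic (the module path and prefix are assumptions, the returned value depends on the prefix contents): the function now walks the available protocol templates in sorted order and returns the first one whose latest version validates, instead of blindly returning the first directory entry.

from beat.core.database import get_first_procotol_template  # assumed module path

template_name = get_first_procotol_template("/path/to/prefix")
print(template_name)  # e.g. "some_template/2": first valid template, latest version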
class Database(BackendDatabase):
......
......@@ -43,6 +43,7 @@ Docker helper classes
"""
import ast
import os
import simplejson as json
import socket
......@@ -52,6 +53,8 @@ import docker
import subprocess as sp # nosec
import logging
from packaging import version
from beat.core import stats
logger = logging.getLogger(__name__)
......@@ -85,7 +88,12 @@ class Host(object):
(
self.processing_environments,
self.db_environments,
) = self._discover_environments()
) = self._discover_environments_using_labels()
if not self.db_environments and not self.processing_environments:
self.processing_environments, self.db_environments = (
self._discover_environments_using_describe()
)
# (If necessary) Save the known information about the images
if self.images_cache_filename is not None:
......@@ -157,7 +165,7 @@ class Host(object):
s.connect(("8.8.8.8", 1)) # connecting to a UDP address doesn't send packets
return s.getsockname()[0]
def _discover_environments(self):
def _discover_environments_using_describe(self):
"""Returns a dictionary containing information about docker environments
Raises:
......@@ -299,7 +307,9 @@ class Host(object):
for image in images:
# Call the "describe" application on each existing image
description = _describe(image)
if not description:
logger.debug("Description not found for", image)
continue
key = description["name"] + " (" + description["version"] + ")"
......@@ -331,6 +341,95 @@ class Host(object):
return (environments, db_environments)
def _discover_environments_using_labels(self):
"""Search BEAT runtime environments using docker labels"""
def _must_replace(key, image, environments):
environment = environments[key]
if environment["image"] not in image.tags:
logger.warn(
"Different images providing the same environment: {} VS {}".format(
environment["image"], image.tags
)
)
if self.raise_on_errors:
raise RuntimeError(
"Environments at '%s' and '%s' have the "
"same name ('%s'). Distinct environments must be "
"uniquely named. Fix this and re-start."
% (image.tags[0], environments[key]["image"], key)
)
else:
logger.debug("Keeping more recent")
current_version = "{}{}".format(
environment["version"], environment["revision"]
)
new_version = "{}{}".format(
image.labels["beat.env.version"], image.labels["beat.env.revision"]
)
current_version = version.parse(current_version)
new_version = version.parse(new_version)
return new_version > current_version
def _parse_image_info(image):
labels = image.labels
data = {
"image": image.tags[0],
"name": labels["beat.env.name"],
"version": labels["beat.env.version"],
"revision": labels["beat.env.revision"],
}
database_list = labels.get("beat.env.databases")
if database_list:
data["databases"] = ast.literal_eval(database_list)
capabilities = labels.get("beat.env.capabilities")
if capabilities:
data["capabilities"] = ast.literal_eval(capabilities)
return data
def _process_image_list(image_list):
environments = {}
for image in image_list:
if not len(image.tags):
logger.warn("Untagged image, skipping")
continue
image_info = _parse_image_info(image)
key = "{} {}".format(image_info["name"], image_info["version"])
image_name = image_info["image"]
if key in environments:
if _must_replace(key, image, environments):
environments[key] = image_info
logger.info("Updated '%s' -> '%s'", key, image_name)
else:
environments[key] = image_info
Host.images_cache[image_name] = environments[key]
logger.info("Registered '%s' -> '%s'", key, image_name)
return environments
client = docker.from_env()
databases = client.images.list(filters={"label": ["beat.env.type=database"]})
db_environments = _process_image_list(databases)
executors = client.images.list(filters={"label": ["beat.env.type=execution"]})
environments = _process_image_list(executors)
logger.debug(
"Found %d environments and %d database environments",
len(environments),
len(db_environments),
)
return environments, db_environments
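Outside of the class, the same label-based lookup can be sketched directly with the Docker SDK (a minimal, hypothetical example; the filter and label keys are the ones used by _process_image_list and _parse_image_info above, everything else is illustrative):

import docker

client = docker.from_env()
for image in client.images.list(filters={"label": ["beat.env.type=execution"]}):
    if not image.tags:
        continue  # untagged images are skipped, as in _process_image_list
    labels = image.labels
    key = "{} ({})".format(labels["beat.env.name"], labels["beat.env.version"])
    print(key, "->", image.tags[0])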
def create_container(self, image, command):
if image in self: # Replace by a real image name
......
......@@ -230,6 +230,17 @@ class BaseExecutor(object):
"The input '%s' doesn't exist in the loop algorithm" % name
)
if len(loop["outputs"]) != len(self.loop_algorithm.output_map):
self.errors.append(
"The number of outputs of the loop algorithm doesn't correspond"
)
for name in self.data["outputs"].keys():
if name not in self.algorithm.output_map.keys():
self.errors.append(
"The output '%s' doesn't exist in the loop algorithm" % name
)
# Check that the mapping is coherent
if len(self.data["inputs"]) != len(self.algorithm.input_map):
self.errors.append(
......
......@@ -437,7 +437,7 @@ class DockerExecutor(RemoteExecutor):
if self.loop_algorithm is not None:
cmd.append(
"tcp://%s:%d"
"--loop=tcp://%s:%d"
% (loop_algorithm_container_ip, loop_algorithm_container_port)
)
......
......@@ -188,9 +188,14 @@ class LocalExecutor(BaseExecutor):
self.zmq_context = None
def __cleanup(self):
def __cleanup(self, early=False):
if self.loop_executor:
if early:
self.loop_socket.send_string("don")
self.loop_socket.recv() # ack
self.loop_executor.wait()
self.loop_executor.close()
for handler in [self.message_handler, self.loop_message_handler]:
if handler:
......@@ -310,15 +315,35 @@ class LocalExecutor(BaseExecutor):
cache_root=self.cache,
)
retval = self.loop_executor.setup()
try:
retval = self.loop_executor.setup()
except Exception as e:
message = _process_exception(e, self.prefix, "algorithms")
retval = False
else:
message = None
if not retval:
self.__cleanup()
raise RuntimeError("Loop algorithm setup failed")
error = "Loop algorithm {} setup failed".format(self.algorithm.name)
if message:
error += ": {}".format(message)
raise RuntimeError(error)
try:
prepared = self.loop_executor.prepare()
except Exception as e:
message = _process_exception(e, self.prefix, "algorithms")
prepared = False
else:
message = None
prepared = self.loop_executor.prepare()
if not prepared:
self.__cleanup()
raise RuntimeError("Loop algorithm prepare failed")
error = "Loop algorithm {} prepare failed".format(self.algorithm.name)
if message:
error += ": {}".format(message)
raise RuntimeError(error)
self.loop_executor.process()
......@@ -330,28 +355,50 @@ class LocalExecutor(BaseExecutor):
loop_socket=self.loop_socket,
)
retval = self.executor.setup()
if not retval:
self.__cleanup()
raise RuntimeError("Algorithm setup failed")
try:
status = self.executor.setup()
except Exception as e:
message = _process_exception(e, self.prefix, "algorithms")
status = 0
else:
message = None
if not status:
self.__cleanup(early=True)
error = "Algorithm {} setup failed".format(self.algorithm.name)
if message:
error += ": {}".format(message)
raise RuntimeError(error)
try:
prepared = self.executor.prepare()
except Exception as e:
message = _process_exception(e, self.prefix, "algorithms")
prepared = 0
else:
message = None
prepared = self.executor.prepare()
if not prepared:
self.__cleanup()
raise RuntimeError("Algorithm prepare failed")
self.__cleanup(early=True)
error = "Algorithm {} prepare failed".format(self.algorithm.name)
if message:
error += ": {}".format(message)
raise RuntimeError(error)
_start = time.time()
try:
processed = self.executor.process()
except Exception as e:
message = _process_exception(e, self.prefix, "databases")
message = _process_exception(e, self.prefix, "algorithms")
self.__cleanup()
return _create_result(1, message)
if not processed:
self.__cleanup()
raise RuntimeError("Algorithm process failed")
raise RuntimeError(
"Algorithm {} process failed".format(self.algorithm.name)
)
proc_time = time.time() - _start
......
......@@ -123,6 +123,12 @@ class SubprocessExecutor(RemoteExecutor):
guarantee that the cache is refreshed as appropriate in case the
underlying libraries change.
custom_root_folders (dict): A dictionary mapping database names to
their location on disk
ip_address (str): IP address of the machine to connect to for the database
execution and message handlers.
python_path (str): Path to the python executable of the environment to use
for experiment execution.
Attributes:
......@@ -172,8 +178,8 @@ class SubprocessExecutor(RemoteExecutor):
library_cache=None,
custom_root_folders=None,
ip_address="127.0.0.1",
python_path=None,
):
super(SubprocessExecutor, self).__init__(
prefix,
data,
......@@ -186,14 +192,30 @@ class SubprocessExecutor(RemoteExecutor):
custom_root_folders=custom_root_folders,
)
# We need three apps to run this function: databases_provider and execute
self.EXECUTE_BIN = _which(os.path.join(os.path.dirname(sys.argv[0]), "execute"))
self.LOOP_EXECUTE_BIN = _which(
os.path.join(os.path.dirname(sys.argv[0]), "loop_execute")
)
self.DBPROVIDER_BIN = _which(
os.path.join(os.path.dirname(sys.argv[0]), "databases_provider")
)
if python_path is None:
base_path = os.path.dirname(sys.argv[0])
# We need three applications to run this executor: execute, loop_execute and databases_provider
self.EXECUTE_BIN = _which(os.path.join(base_path, "execute"))
self.LOOP_EXECUTE_BIN = _which(os.path.join(base_path, "loop_execute"))
self.DBPROVIDER_BIN = _which(os.path.join(base_path, "databases_provider"))
else:
base_path = os.path.dirname(python_path)
self.EXECUTE_BIN = os.path.join(base_path, "execute")
self.LOOP_EXECUTE_BIN = os.path.join(base_path, "loop_execute")
self.DBPROVIDER_BIN = os.path.join(base_path, "databases_provider")
if any(
[
not os.path.exists(executable)
for executable in [
self.EXECUTE_BIN,
self.LOOP_EXECUTE_BIN,
self.DBPROVIDER_BIN,
]
]
):
raise RuntimeError("Invalid environment")
def __create_db_process(self, configuration_name=None):
databases_process = None
......@@ -384,7 +406,9 @@ class SubprocessExecutor(RemoteExecutor):
)
if self.loop_algorithm is not None:
cmd.append("tcp://" + self.ip_address + (":%d" % loop_algorithm_port))
cmd.append(
"--loop=tcp://" + self.ip_address + (":%d" % loop_algorithm_port)
)
if logger.getEffectiveLevel() <= logging.DEBUG:
cmd.insert(1, "--debug")
......
......@@ -44,6 +44,7 @@ Validation for experiments
import os
import collections
import itertools
import simplejson as json
from . import utils
......@@ -54,6 +55,9 @@ from . import database
from . import toolchain
from . import hash
EVALUATOR_PREFIX = "evaluator_"
PROCESSOR_PREFIX = "processor_"
class Storage(utils.Storage):
"""Resolves paths for experiments
......@@ -275,6 +279,10 @@ class Experiment(object):
if self.errors:
return
self._crosscheck_toolchain_loops()
if self.errors:
return
self._crosscheck_toolchain_analyzers()
if self.errors:
return
......@@ -284,6 +292,10 @@ class Experiment(object):
return
self._crosscheck_block_algorithm_pertinence()
if self.errors:
return
self._crosscheck_loop_algorithm_pertinence()
def _check_datasets(self, database_cache, dataformat_cache):
"""checks all datasets are valid"""
......@@ -427,52 +439,69 @@ class Experiment(object):
def _check_loops(self, algorithm_cache, dataformat_cache, library_cache):
"""checks all loops are valid"""
loops = self.data.get("loops", {})
for loopname, loop in loops.items():
for key in [PROCESSOR_PREFIX, EVALUATOR_PREFIX]:
algoname = loop[key + "algorithm"]
if algoname not in self.algorithms:
# loads the algorithm
if algoname in algorithm_cache:
thisalgo = algorithm_cache[algoname]
else:
thisalgo = algorithm.Algorithm(
self.prefix, algoname, dataformat_cache, library_cache
)
algorithm_cache[algoname] = thisalgo
if "loops" not in self.data:
return
for loopname, loop in self.data["loops"].items():
algoname = loop["algorithm"]
if algoname not in self.algorithms:
self.algorithms[algoname] = thisalgo
if thisalgo.errors:
self.errors.append(
"/loops/%s: algorithm `%s' is invalid:\n%s"
% (loopname, algoname, "\n".join(thisalgo.errors))
)
continue
# loads the algorithm
if algoname in algorithm_cache:
thisalgo = algorithm_cache[algoname]
else:
thisalgo = algorithm.Algorithm(
self.prefix, algoname, dataformat_cache, library_cache
)
algorithm_cache[algoname] = thisalgo
self.algorithms[algoname] = thisalgo
if thisalgo.errors:
self.errors.append(
"/loops/%s: algorithm `%s' is invalid:\n%s"
% (loopname, algoname, "\n".join(thisalgo.errors))
)
continue
else:
thisalgo = self.algorithms[algoname]
if thisalgo.errors:
continue # already done
thisalgo = self.algorithms[algoname]
if thisalgo.errors:
continue # already done
# checks all inputs correspond
for algoin, loop_input in loop[key + "inputs"].items():
if algoin not in thisalgo.input_map:
self.errors.append(
"/loop/%s/inputs/%s: algorithm `%s' does "
"not have an input named `%s' - valid algorithm inputs "
"are %s"
% (
loopname,
loop_input,
algoname,
algoin,
", ".join(thisalgo.input_map.keys()),
)
)
# checks all inputs correspond
for algoin, loop_input in loop["inputs"].items():
if algoin not in thisalgo.input_map:
self.errors.append(
"/analyzers/%s/inputs/%s: algorithm `%s' does "
"not have an input named `%s' - valid algorithm inputs "
"are %s"
% (
loopname,
loop_input,
algoname,
algoin,
", ".join(thisalgo.input_map.keys()),
# checks all outputs correspond
for algout, loop_output in loop[key + "outputs"].items():
if (
hasattr(thisalgo, "output_map")
and algout not in thisalgo.output_map
):
self.errors.append(
"/loops/%s/outputs/%s: algorithm `%s' does not "
"have an output named `%s' - valid algorithm outputs are "
"%s"
% (
loopname,
loop_output,
algoname,
algout,
", ".join(thisalgo.output_map.keys()),
)
)
)
# checks if parallelization makes sense
if loop.get("nb_slots", 1) > 1 and not thisalgo.splittable:
......@@ -685,6 +714,38 @@ class Experiment(object):
)
)
def _crosscheck_toolchain_loops(self):
"""There must exist a 1-to-1 relation to existing loops"""
toolchain_loops = self.toolchain.loops
if sorted(toolchain_loops.keys()) != sorted(self.loops.keys()):
self.errors.append(
"mismatch between the toolchain loop names (%s)"
" and the experiment's (%s)"
% (
", ".join(sorted(toolchain_loops.keys())),
", ".join(sorted(self.loops.keys())),
)
)
# the number of block endpoints and the toolchain's must match
for block_name, block in self.loops.items():
for prefix in [PROCESSOR_PREFIX, EVALUATOR_PREFIX]:
block_input_count = len(block[prefix + "inputs"])
toolchain_input_block = len(
toolchain_loops[block_name][prefix + "inputs"]
)
if block_input_count != toolchain_input_block:
self.errors.append(
"/loops/{}: toolchain loops has {} {}inputs "
"while the experiment has {} inputs".format(
block_name, toolchain_input_block, prefix, block_input_count
)
)
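To make the prefixed keys concrete, here is a hypothetical loop entry from an experiment declaration, written as a Python literal. Only the key names are taken from the code above (PROCESSOR_PREFIX/EVALUATOR_PREFIX combined with "algorithm", "inputs", "outputs" and "nb_slots"); all values are placeholders.

loop_declaration = {
    "processor_algorithm": "user/my_processor/1",
    "processor_inputs": {"in_data": "in_data"},      # algorithm input -> loop endpoint
    "processor_outputs": {"out_data": "out_data"},   # algorithm output -> loop endpoint
    "evaluator_algorithm": "user/my_evaluator/1",
    "evaluator_inputs": {"in_result": "in_result"},
    "evaluator_outputs": {"out_result": "out_result"},
    "nb_slots": 1,
}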
def _crosscheck_toolchain_analyzers(self):
"""There must exist a 1-to-1 relation to existing analyzers"""
......@@ -741,6 +802,20 @@ class Experiment(object):
algout = imapping[from_endpt[1]] # name of output on algorithm
from_dtype = self.algorithms[block["algorithm"]].output_map[algout]
from_name = "block"
elif from_endpt[0] in self.loops:
loop = self.loops[from_endpt[0]]
for prefix in [PROCESSOR_PREFIX, EVALUATOR_PREFIX]:
mapping = loop[prefix + "outputs"]
imapping = dict(zip(mapping.values(), mapping.keys()))
if from_endpt[1] in imapping:
algout = imapping[from_endpt[1]] # name of output on algorithm
from_dtype = self.algorithms[
loop[prefix + "algorithm"]
].output_map[algout]
break
from_name = "loop"
else:
self.errors.append("Unknown endpoint %s" % from_endpt[0])
continue
......@@ -757,10 +832,15 @@ class Experiment(object):
elif to_endpt[0] in self.loops:
loop = self.loops[to_endpt[0]]
mapping = loop["inputs"]
imapping = dict(zip(mapping.values(), mapping.keys()))
algoin = imapping[to_endpt[1]] # name of input on algorithm
to_dtype = self.algorithms[loop["algorithm"]].input_map[algoin]
for prefix in [PROCESSOR_PREFIX, EVALUATOR_PREFIX]:
mapping = loop[prefix + "inputs"]
imapping = dict(zip(mapping.values(), mapping.keys()))
if to_endpt[1] in imapping:
algoin = imapping[to_endpt[1]] # name of input on algorithm
to_dtype = self.algorithms[
loop[prefix + "algorithm"]
].input_map[algoin]
break
to_name = "loop"
elif to_endpt[0] in self.analyzers: # it is an analyzer
......@@ -852,6 +932,101 @@ class Experiment(object):
% (name, self.blocks[name]["algorithm"])
)
def _crosscheck_loop_algorithm_pertinence(self):
"""The number of groups and the input-output connectivity must respect
the individual synchronization channels and the block's.
"""
loops = self.data.get("loops", {})
for name, loop in loops.items():
# filter connections that end on the visited block - remember, each
# input is checked for receiving a single input connection. It is
# illegal to connect an input multiple times. At this point, you
# already know that is not the case.
input_connections = [
k["channel"]
for k in self.toolchain.connections
if k["to"].startswith(name + ".")
]
# filter connections that start on the visited block, retain output
# name so we can check synchronization and then group
output_connections = set(
[
(k["from"].replace(name + ".", ""), k["channel"])
for k in self.toolchain.connections
if k["from"].startswith(name + ".")
]
)
output_connections = [k[1] for k in output_connections]
# note: dataformats have already been checked - only need to check
# for the grouping properties between inputs and outputs
# create channel groups
chain_in = collections.Counter(input_connections)
chain_out = collections.Counter(output_connections)
chain_groups_count = [(v, chain_out.get(k, 0)) for k, v in chain_in.items()]
# now check the algorithms for conformance
processor_algorithm_name = loop[PROCESSOR_PREFIX + "algorithm"]
evaluator_algorithm_name = loop[EVALUATOR_PREFIX + "algorithm"]
processor_algo_groups_list = self.algorithms[
processor_algorithm_name
].groups
evaluator_algo_groups_list = self.algorithms[
evaluator_algorithm_name
].groups
groups_count = []
for processor_algo_groups, evaluator_algo_groups in itertools.zip_longest(
processor_algo_groups_list, evaluator_algo_groups_list
):
inputs = 0
outputs = 0
if processor_algo_groups:
inputs = len(processor_algo_groups["inputs"])
outputs = len(processor_algo_groups.get("outputs", []))
if evaluator_algo_groups:
inputs += len(evaluator_algo_groups["inputs"])
outputs += len(evaluator_algo_groups.get("outputs", []))
groups_count.append((inputs, outputs))
if collections.Counter(chain_groups_count) != collections.Counter(
groups_count
):
self.errors.append(
"synchronization mismatch in input/output "
"grouping between loop `{}', algorithm `{}' "
"and loop algorithm `{}'".format(
name, processor_algorithm_name, evaluator_algorithm_name
)
)
for processor_algo_groups, evaluator_algo_groups in zip(
processor_algo_groups_list, evaluator_algo_groups_list
):
processor_algo_loop = processor_algo_groups["loop"]
evaluator_algo_loop = evaluator_algo_groups["loop"]
for channel in ["request", "answer"]:
if (
processor_algo_loop[channel]["type"]
!= evaluator_algo_loop[channel]["type"]
):
self.errors.append(
"{} loop channel type incompatible between {} and {}".format(
channel,
processor_algorithm_name,
evaluator_algorithm_name,
)
)
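For illustration, a hypothetical pair of algorithm groups that would pass the channel-type check above. The nested "loop"/"request"/"answer"/"type" keys are the ones read by the check; the dataformat names are placeholders.

processor_group = {
    "inputs": {"in_data": {"type": "user/sample/1"}},
    "outputs": {"out_data": {"type": "user/score/1"}},
    "loop": {
        "request": {"type": "user/loop_request/1"},
        "answer": {"type": "user/loop_answer/1"},
    },
}
evaluator_group = {
    "inputs": {"in_score": {"type": "user/score/1"}},
    "loop": {
        "request": {"type": "user/loop_request/1"},  # must match the processor's request type
        "answer": {"type": "user/loop_answer/1"},    # must match the processor's answer type
    },
}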
def _crosscheck_analyzer_algorithm_pertinence(self):
"""
The number of groups and the input-output connectivity must respect the
......@@ -908,7 +1083,7 @@ class Experiment(object):
return not bool(self.errors)
def _inputs(self, name):
def _inputs(self, name, input_prefix=""):
"""Calculates and returns the inputs for a given block"""
# filter connections that end on the visited block
......@@ -942,9 +1117,15 @@ class Experiment(object):
break
if config_data is None:
raise KeyError("did not find `%s' among blocks or analyzers" % name)
raise KeyError("did not find `%s' among blocks, loops or analyzers" % name)
for algo_endpt, block_endpt in config_data["inputs"].items():
# if get_loop_data:
# inputs = config_data[EVALUATOR_PREFIX + "inputs"]
# else:
# inputs = config_data[PROCESSOR_PREFIX + "inputs"]
inputs = config_data[input_prefix + "inputs"]
for algo_endpt, block_endpt in inputs.items():
block, output, channel = connections[block_endpt]
if block in self.toolchain.datasets:
......@@ -985,9 +1166,17 @@ class Experiment(object):
return retval
def _block_outputs(self, name):
def _block_outputs(self, name, output_prefix=""):
"""Calculates and returns the outputs for a given block"""
config_data = None
for item in [self.blocks, self.loops]:
if name in item:
config_data = item[name]
break
if config_data is None:
raise KeyError("did not find `%s' among blocks or loops" % name)
# filter connections that end on the visited block
connections = [
k for k in self.toolchain.connections if k["from"].startswith(name + ".")
......@@ -1007,7 +1196,13 @@ class Experiment(object):
retval = dict()
# notice: there can be multiply connected outputs
for algo_endpt, block_endpt in self.blocks[name]["outputs"].items():
# if get_loop_data:
# outputs = config_data[EVALUATOR_PREFIX + "outputs"]
# else:
# outputs = config_data[PROCESSOR_PREFIX + "outputs"]
outputs = config_data[output_prefix + "outputs"]
for algo_endpt, block_endpt in outputs.items():
block, input, channel = connections[block_endpt]
retval[algo_endpt] = dict(
channel=channel, endpoint=block_endpt # the block outtake name
......@@ -1041,18 +1236,8 @@ class Experiment(object):
config_data = item[name]
break
# resolve parameters taking globals in consideration
parameters = self.data["globals"].get(config_data["algorithm"])
if parameters is None:
parameters = dict()
else:
parameters = dict(parameters) # copy
parameters.update(config_data.get("parameters", {}))
# resolve the execution information
queue = config_data.get("queue", self.data["globals"]["queue"])
env = config_data.get("environment", self.data["globals"]["environment"])
nb_slots = config_data.get("nb_slots", 1)
toolchain_data = self.toolchain.algorithm_item(name)
......@@ -1060,44 +1245,70 @@ class Experiment(object):
if toolchain_data is None:
raise KeyError("did not find `%s' among blocks, loops or analyzers" % name)
retval = dict(
inputs=self._inputs(name),
channel=toolchain_data["synchronized_channel"],
algorithm=config_data["algorithm"],
parameters=parameters,
queue=queue,
environment=env,
nb_slots=nb_slots,
)
if name in self.loops:
def build_block_data(name, config_data, algorithm_prefix):
# resolve parameters taking globals in consideration
algorithm_name = config_data[algorithm_prefix + "algorithm"]
loop = self.toolchain.get_loop_for_block(name)
parameters = self.data["globals"].get(algorithm_name)
if parameters is None:
parameters = dict()
else:
parameters = dict(parameters) # copy
if loop is not None:
loop_name = loop["name"]
loop_toolchain_data = self.toolchain.algorithm_item(loop_name)
loop_config_data = self.data["loops"][loop_name]
loop_algorithm = loop_config_data["algorithm"]
parameters.update(config_data.get(algorithm_prefix + "parameters", {}))
parameters = self.data["globals"].get(loop_algorithm, dict())
parameters.update(loop_config_data.get("parameters", dict()))
environment = config_data.get(
algorithm_prefix + "environment",
self.data["globals"]["environment"],
)
loop_data = dict(
inputs=self._inputs(loop_name),
channel=loop_toolchain_data["synchronized_channel"],
algorithm=loop_algorithm,
return dict(
inputs=self._inputs(name, algorithm_prefix),
outputs=self._block_outputs(name, algorithm_prefix),
channel=toolchain_data["synchronized_channel"],
algorithm=algorithm_name,
parameters=parameters,
queue=queue,
environment=environment,
)
retval = build_block_data(name, config_data, PROCESSOR_PREFIX)
retval["nb_slots"] = nb_slots
retval["loop"] = build_block_data(name, config_data, EVALUATOR_PREFIX)
else:
env = config_data.get("environment", self.data["globals"]["environment"])
# resolve parameters taking globals in consideration
parameters = self.data["globals"].get(config_data["algorithm"])
if parameters is None:
parameters = dict()
else:
parameters = dict(parameters) # copy
parameters.update(config_data.get("parameters", {}))
retval = dict(
inputs=self._inputs(name),
channel=toolchain_data["synchronized_channel"],
algorithm=config_data["algorithm"],
parameters=parameters,
queue=queue,
environment=env,
nb_slots=nb_slots,
)
retval["loop"] = loop_data
if name in self.blocks:
retval["outputs"] = self._block_outputs(name)
else:
# Analyzers have only 1 output file/cache. This is the result of an
# optimization as most of the outputs are single numbers.
# Furthermore, given we need to read it out on beat.web, having a
# single file optimizes resource usage. The synchronization channel
# for the analyzer itself is respected.
retval["result"] = dict() # missing the hash/path
if name in self.blocks:
retval["outputs"] = self._block_outputs(name)
else:
# Analyzers have only 1 output file/cache. This is the result of an
# optimization as most of the outputs are single numbers.
# Furthermore, given we need to read it out on beat.web, having a
# single file optimizes resource usage. The synchronization channel
# for the analyzer itself is respected.
retval["result"] = dict() # missing the hash/path
return retval
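The configuration returned for a loop block therefore takes roughly the following shape (a sketch with placeholder values; "inputs" and "outputs" are filled by _inputs() and _block_outputs(), and the exact queue/environment contents come from the experiment globals):

loop_block_configuration = {
    "inputs": {},                        # from _inputs(name, PROCESSOR_PREFIX)
    "outputs": {},                       # from _block_outputs(name, PROCESSOR_PREFIX)
    "channel": "train",                  # synchronized channel (placeholder)
    "algorithm": "user/my_processor/1",  # processor algorithm (placeholder)
    "parameters": {},
    "queue": "queue",
    "environment": {},                   # resolved per-prefix or from globals
    "nb_slots": 1,
    "loop": {                            # evaluator side, built the same way
        "inputs": {},
        "outputs": {},
        "channel": "train",
        "algorithm": "user/my_evaluator/1",
        "parameters": {},
        "queue": "queue",
        "environment": {},
    },
}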
......@@ -1125,47 +1336,63 @@ class Experiment(object):
exec_order[key] = dict(
dependencies=exec_order[key], configuration=self._configuration(key)
)
# import ipdb; ipdb.set_trace()
for key, value in exec_order.items():
# now compute missing hashes - because we're in execution order,
# there should be no missing input hashes in any of the blocks.
config = value["configuration"]
if "outputs" in config: # it is a block
block_outputs = {}
for output, output_value in config["outputs"].items():
output_value["hash"] = hash.hashBlockOutput(
key,
config["algorithm"],
self.algorithms[config["algorithm"]].hash(),
config["parameters"],
config["environment"],
dict([(k, v["hash"]) for k, v in config["inputs"].items()]),
output,
)
output_value["path"] = hash.toPath(output_value["hash"], "")
# set the inputs for the following blocks