Commit 300c9661 authored by Amir MOHAMMADI

Use a global prefix to keep things around, add an iris test

parent b756cd0c
Pipeline #42048 failed in 3 minutes and 31 seconds
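The commit message above refers to keeping a single, process-wide prefix around. Judging from the changes below, the installation prefix is no longer threaded through every constructor but is read from beat.backend.python's configuration. A minimal sketch of the assumed lookup (the prefix value shown is hypothetical):

    from beat.backend.python import config

    # the global configuration is assumed to expose the installation prefix
    prefix = config.get_config()["prefix"]  # e.g. "~/beat_prefix" (hypothetical value)
    print("BEAT objects will be looked up under:", prefix)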
......@@ -50,8 +50,8 @@ import os
import numpy
import pkg_resources
import simplejson as json
import six
from beat.backend.python import config
from beat.backend.python.algorithm import Algorithm as BackendAlgorithm
from beat.backend.python.algorithm import Runner # noqa
from beat.backend.python.algorithm import Storage
......@@ -104,8 +104,6 @@ class Algorithm(BackendAlgorithm):
Parameters:
prefix (str): Establishes the prefix of your installation.
data (:py:class:`object`, Optional): The piece of data representing the
algorithm. It must validate against the schema defined for algorithms.
If a string is passed, it is supposed to be a valid path to an
......@@ -115,17 +113,6 @@ class Algorithm(BackendAlgorithm):
its source format or as a binary blob). If ``None`` is passed, loads
our default prototype for algorithms (source code will be in Python).
dataformat_cache (:py:class:`dict`, Optional): A dictionary mapping
dataformat names to loaded dataformats. This parameter is optional and,
if passed, may greatly speed-up algorithm loading times as dataformats
that are already loaded may be re-used.
library_cache (:py:class:`dict`, Optional): A dictionary mapping library
names to loaded libraries. This parameter is optional and, if passed,
may greatly speed-up library loading times as libraries that are
already loaded may be re-used.
Attributes:
name (str): The algorithm name
......@@ -183,10 +170,7 @@ class Algorithm(BackendAlgorithm):
dataformat_klass = dataformat.DataFormat
def __init__(self, prefix, data, dataformat_cache=None, library_cache=None):
super(Algorithm, self).__init__(prefix, data, dataformat_cache, library_cache)
def _load(self, data, dataformat_cache, library_cache):
def _load(self, data):
"""Loads the algorithm"""
self.errors = []
......@@ -208,10 +192,10 @@ class Algorithm(BackendAlgorithm):
data, code = data # break down into two components
if isinstance(data, six.string_types): # user has passed a file pointer
if isinstance(data, str): # user has passed a file pointer
self._name = data
self.storage = Storage(self.prefix, self._name)
self.storage = Storage(self._name)
if not self.storage.json.exists():
self.errors.append("Algorithm declaration file not found: %s" % data)
return
......@@ -220,7 +204,7 @@ class Algorithm(BackendAlgorithm):
# At this point, `data' can be a dictionary or ``None``
if data is None: # loads the default declaration for an algorithm
algorithm_data = load_algorithm_prototype(self.prefix)
algorithm_data = load_algorithm_prototype(config.get_config()["prefix"])
self.data, self.errors = schema.validate("algorithm", algorithm_data)
assert not self.errors, "\n * %s" % "\n *".join(self.errors) # nosec
else: # just assign it
......@@ -272,11 +256,11 @@ class Algorithm(BackendAlgorithm):
[(k, v["type"]) for g in self.groups for k, v in g.get("loop", {}).items()]
)
self._validate_required_dataformats(dataformat_cache)
self._validate_required_dataformats()
self._convert_parameter_types()
# finally, the libraries
self._validate_required_libraries(library_cache)
self._validate_required_libraries()
self._check_language_consistence()
def _check_endpoint_uniqueness(self):
......@@ -318,26 +302,26 @@ class Algorithm(BackendAlgorithm):
)
)
def _validate_dataformats(self, group, group_name, dataformat_cache):
def _validate_dataformats(self, group, group_name):
for name, entry in group[group_name].items():
type_name = entry["type"]
thisformat = self._update_dataformat_cache(type_name, dataformat_cache)
thisformat = self._update_dataformat_cache(type_name)
self._validate_format(type_name, group_name, name, thisformat)
def _validate_required_dataformats(self, dataformat_cache):
def _validate_required_dataformats(self):
"""Makes sure we can load all requested formats
"""
for group in self.groups:
for name, input_ in group["inputs"].items():
self._validate_dataformats(group, "inputs", dataformat_cache)
self._validate_dataformats(group, "inputs")
if "outputs" in group:
self._validate_dataformats(group, "outputs", dataformat_cache)
self._validate_dataformats(group, "outputs")
if "loop" in group:
self._validate_dataformats(group, "loop", dataformat_cache)
self._validate_dataformats(group, "loop")
if self.results:
......@@ -346,9 +330,7 @@ class Algorithm(BackendAlgorithm):
# results can only contain base types and plots; therefore, only
# process plots
if result_type.find("/") != -1:
thisformat = self._update_dataformat_cache(
result_type, dataformat_cache
)
thisformat = self._update_dataformat_cache(result_type)
self._validate_format(result_type, "result", name, thisformat)
def _convert_parameter_types(self):
......@@ -430,7 +412,7 @@ class Algorithm(BackendAlgorithm):
)
)
def _validate_required_libraries(self, library_cache):
def _validate_required_libraries(self):
# all used libraries must be loadable; cannot use self as a library
......@@ -438,9 +420,7 @@ class Algorithm(BackendAlgorithm):
for name, value in self.uses.items():
self.libraries[value] = library_cache.setdefault(
value, library.Library(self.prefix, value, library_cache)
)
self.libraries[value] = library.Library[value]
if not self.libraries[value].valid:
self.errors.append(
......
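Throughout this commit, explicit constructor calls such as library.Library(self.prefix, value, library_cache) are replaced by class subscription, e.g. library.Library[value] (and, further down, algorithm.Algorithm[...] and database.Database[...]). Below is a minimal sketch of how such subscription could keep instances around; the caching metaclass is an assumption, not necessarily what beat.backend.python actually does:

    # Hypothetical caching metaclass -- an assumption about how Library[name]
    # could keep instances around between calls.
    class _CachedMeta(type):
        def __init__(cls, name, bases, namespace):
            super().__init__(name, bases, namespace)
            cls._cache = {}

        def __getitem__(cls, key):
            # build the object on first access, then hand back the same instance
            if key not in cls._cache:
                cls._cache[key] = cls(key)
            return cls._cache[key]

    class Library(metaclass=_CachedMeta):
        def __init__(self, name):
            self.name = name

    assert Library["plot/baselib/1"] is Library["plot/baselib/1"]

The commit also adds a new experiment/toolchain builder module, shown next.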
from collections import namedtuple
from .experiment import Experiment
from .toolchain import Toolchain
_Node = namedtuple("Node", ["block", "name", "dataformat"])
class _Edge:
"""An edge of a BEAT experiment/toolchain graph
"""
def __init__(self, block, name, dataformat, channel, **kwargs):
super().__init__(**kwargs)
self.before = _Node(block, name, dataformat)
self.channel = channel
def connect_to(self, block, name, dataformat):
after = _Node(block, name, dataformat)
if dataformat != self.before[-1]:
raise RuntimeError(f"Cannot connect the output of {self.before} to {after}")
self.after = after
def __repr__(self):
return f"{self.before} -> {self.after} in {self.channel}"
class _Block:
"""A block of a BEAT experiment/toolchain.
It holds information about its input/output connections.
"""
def __init__(self, component, name=None, **kwargs):
super().__init__(**kwargs)
self._component = component
self._edges = []
self._name = name or component.name.replace("/", "_")
class DatasetBlock(_Block):
def __init__(self, database, protocol, set, name=None, **kwargs):
super().__init__(component=database, name=name, **kwargs)
if protocol not in database.protocol_names:
raise ValueError(
f"Unknown protocol name ({protocol}) for database: {database}"
)
self._protocol = protocol
if set not in database.set_names(protocol):
raise ValueError(
f"Unknown set name ({set}) for "
f"database: {database} and protocol: {protocol}"
)
self._set = set
self._channel = f"{self._component.name}_{self._protocol}_{self._set}"
self._channel = self._channel.replace("/", "_")
if name is None:
self._name = self._channel
def __getattr__(self, name):
# check if name is in the set outputs
_set = self._component.set(self._protocol, self._set)
if name not in _set["outputs"]:
raise AttributeError
edge = _Edge(
block=self,
name=name,
dataformat=_set["outputs"][name],
channel=self._channel,
)
# keep the edge
setattr(self, name, edge)
self._edges.append(edge)
return edge
class _AlgorithmAnalyzerBlock(_Block):
def validate_inputs(self, **kwargs):
input_names = set(kwargs.keys())
valid_input_names = set(self._component.input_map.keys())
if input_names != valid_input_names:
raise ValueError(
f"Inputs {input_names} do not match "
f"algorithm's inputs: {valid_input_names}"
)
# set channel based on inputs if not set already
if self.channel is None:
edge_channels = set(v.channel for v in kwargs.values())
if len(edge_channels) != 1:
raise ValueError(
f"The inputs are coming from more than 1 synchronization channel "
f"({edge_channels}). During the initialization of {self}, specify "
f"the dataset responsible for the main synchronization channel "
"of this block."
)
self.channel = edge_channels.pop()
def connect_inputs(self, **kwargs):
self.input_edges = []
# connect input edges to inputs
for input_name, edge in kwargs.items():
edge.connect_to(self, input_name, self._component.input_map[input_name])
self.input_edges.append(edge)
class AlgorithmBlock(_AlgorithmAnalyzerBlock):
def __init__(self, algorithm, sync_with=None, parameters=None, name=None, **kwargs):
super().__init__(component=algorithm, name=name, **kwargs)
self.channel = None if sync_with is None else sync_with._channel
self.parameters = parameters
def __call__(self, **kwargs):
self.validate_inputs(**kwargs)
self.connect_inputs(**kwargs)
# find the channel of output through inputs
grp = self._component.output_group
for input_name, edge in kwargs.items():
if input_name in grp["inputs"]:
output_channel = edge.channel
break
# construct output edges
output_edges = []
for output_name, dataformat in self._component.output_map.items():
edge = _Edge(
block=self,
name=output_name,
dataformat=dataformat,
channel=output_channel,
)
output_edges.append(edge)
self._edges.extend(output_edges)
if len(output_edges) == 1:
return output_edges[0]
return output_edges
class AnalyzerBlock(_AlgorithmAnalyzerBlock):
def __init__(
self, analyzer, name=None, toolchain_name=None, experiment_name=None, **kwargs
):
super().__init__(component=analyzer, name=name, **kwargs)
self.channel = None
self.toolchain_name = toolchain_name
self.experiment_name = experiment_name
def __call__(self, **kwargs):
self.validate_inputs(**kwargs)
self.connect_inputs(**kwargs)
def create_experiment(*analyzers, experiment_name, toolchain_name=None):
# construct the toolchain and experiment
toolchain_name = toolchain_name or experiment_name
experiment = _create_experiment(
[edge for analyzer in analyzers for edge in analyzer.input_edges],
toolchain_name=toolchain_name,
experiment_name=experiment_name,
)
return experiment
def _get_all_edges(edges, all_edges=None):
if all_edges is None:
all_edges = set()
for edge in edges:
all_edges.add(edge)
new_edges = [
e for e in getattr(edge.before.block, "input_edges", []) if e not in edges
]
all_edges.update(_get_all_edges(new_edges, all_edges))
return all_edges
def _create_experiment(edges, toolchain_name, experiment_name):
all_edges = list(_get_all_edges(edges))
# find all blocks
datasets = set()
algorithms = set()
analyzers = set()
for edge in all_edges:
before_block = edge.before.block
if isinstance(before_block, DatasetBlock):
datasets.add(before_block)
after_block = edge.after.block
if isinstance(after_block, AlgorithmBlock):
algorithms.add(after_block)
elif isinstance(after_block, AnalyzerBlock):
analyzers.add(after_block)
# create unique names for all blocks
all_blocks = list(datasets) + list(algorithms) + list(analyzers)
names, names_map = list(), dict()
for block in all_blocks:
name = block._name
# if the name already exists
while name in names:
name += "_2"
names.append(name)
names_map[block] = name
# find all connections
connections = []
for edge in all_edges:
before_name = names_map[edge.before.block]
after_name = names_map[edge.after.block]
connections.append(
{
"channel": edge.channel,
"from": f"{before_name}.{edge.before.name}",
"to": f"{after_name}.{edge.after.name}",
}
)
# create datasets
toolchain_datasets, experiment_datasets = list(), dict()
for block in datasets:
name = names_map[block]
outputs = [edge.before.name for edge in block._edges]
toolchain_datasets.append(dict(name=name, outputs=outputs))
experiment_datasets[name] = dict(
database=block._component.name, protocol=block._protocol, set=block._set,
)
# create blocks (algorithms)
toolchain_blocks, experiment_blocks = list(), dict()
for block in algorithms:
name = names_map[block]
channel = block.channel
inputs = [edge.after.name for edge in block.input_edges]
outputs = [edge.before.name for edge in block._edges]
toolchain_blocks.append(
dict(
name=name, synchronized_channel=channel, inputs=inputs, outputs=outputs
)
)
experiment_blocks[name] = dict(
algorithm=block._component.name,
inputs={edge.after.name: edge.before.name for edge in block.input_edges},
outputs={edge.before.name: edge.after.name for edge in block._edges},
)
# create analyzers
toolchain_analyzers, experiment_analyzers = list(), dict()
for block in analyzers:
name = names_map[block]
channel = block.channel
inputs = [edge.after.name for edge in block.input_edges]
toolchain_analyzers.append(
dict(name=name, synchronized_channel=channel, inputs=inputs)
)
experiment_analyzers[name] = dict(
algorithm=block._component.name,
inputs={edge.after.name: edge.before.name for edge in block.input_edges},
)
# create toolchain
data = dict(
analyzers=toolchain_analyzers,
blocks=toolchain_blocks,
datasets=toolchain_datasets,
connections=connections,
representation={"blocks": {}, "channel_colors": {}, "connections": {}},
)
toolchain = Toolchain.new(data=data, name=toolchain_name)
data = dict(
analyzers=experiment_analyzers,
blocks=experiment_blocks,
datasets=experiment_datasets,
schema_version=1,
globals={
"environment": {"name": "dummy", "version": "0.0.0"},
"queue": "queue",
},
)
experiment = Experiment.new(data=data, toolchain=toolchain, label=experiment_name,)
return experiment
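With the classes above, a toolchain and experiment can presumably be assembled programmatically and serialized through create_experiment. A hedged usage sketch follows; the database, algorithm, protocol, set and edge names are hypothetical, loosely echoing the iris test mentioned in the commit message:

    from beat.core.algorithm import Algorithm
    from beat.core.database import Database

    # all names below ("iris/1", "user/lda_train/1", measurements, species,
    # scores, ...) are hypothetical
    iris = DatasetBlock(Database["iris/1"], protocol="Main", set="training")
    trainer = AlgorithmBlock(Algorithm["user/lda_train/1"])
    analyzer = AnalyzerBlock(Algorithm["user/iris_analyzer/1"])

    # calling a block with keyword edges connects its inputs and returns its
    # output edge(s)
    model = trainer(measurements=iris.measurements, species=iris.species)
    analyzer(scores=model)

    experiment = create_experiment(analyzer, experiment_name="iris_test")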
......@@ -46,8 +46,6 @@ Forward importing from :py:mod:`beat.backend.python.dataformat`:
"""
import copy
import six
from beat.backend.python.dataformat import DataFormat as BackendDataFormat
from beat.backend.python.dataformat import Storage # noqa
......@@ -61,8 +59,6 @@ class DataFormat(BackendDataFormat):
Parameters:
prefix (str): Establishes the prefix of your installation.
data (:py:class:`object`, Optional): The piece of data representing the
data format. It must validate against the schema defined for data
formats. If a string is passed, it is supposed to be a valid path to an
......@@ -76,13 +72,6 @@ class DataFormat(BackendDataFormat):
object that is this object's parent and the name of the field on that
object that points to this one.
dataformat_cache (:py:class:`dict`, Optional): A dictionary mapping
dataformat names to loaded dataformats. This parameter is optional and,
if passed, may greatly speed-up data format loading times as
dataformats that are already loaded may be re-used. If you use this
parameter, you must guarantee that the cache is refreshed as
appropriate in case the underlying dataformats change.
Attributes:
name (str): The full, valid name of this dataformat
......@@ -112,10 +101,10 @@ class DataFormat(BackendDataFormat):
"""
def __init__(self, prefix, data, parent=None, dataformat_cache=None):
super(DataFormat, self).__init__(prefix, data, parent, dataformat_cache)
def __init__(self, data, parent=None):
super().__init__(data, parent)
def _load(self, data, dataformat_cache):
def _load(self, data):
"""Loads the dataformat"""
self._name = None
......@@ -134,10 +123,9 @@ class DataFormat(BackendDataFormat):
if not isinstance(data, dict): # user has passed a file pointer
# make sure to log this into the cache (avoids recursion)
dataformat_cache[data] = None
self._name = data
self.storage = Storage(self.prefix, data)
self.storage = Storage(data)
data = self.storage.json.path
if not self.storage.exists():
self.errors.append(
......@@ -167,25 +155,12 @@ class DataFormat(BackendDataFormat):
self.errors = utils.uniq(self.errors)
return
def maybe_load_format(name, obj, dataformat_cache):
def maybe_load_format(name, obj):
"""Tries to load a given dataformat from its relative path"""
if isinstance(obj, six.string_types) and obj.find("/") != -1: # load it
if isinstance(obj, str) and obj.find("/") != -1: # load it
if obj in dataformat_cache: # reuse
if dataformat_cache[obj] is None: # recursion detected
self.errors.append(
"recursion for dataformat `%s' detected" % obj
)
return self
self.referenced[obj] = dataformat_cache[obj]
else: # load it
self.referenced[obj] = DataFormat(
self.prefix, obj, (self, name), dataformat_cache
)
self.referenced[obj] = DataFormat(obj, (self, name))
if not self.referenced[obj].valid:
self.errors.append("referred dataformat `%s' is invalid" % obj)
......@@ -193,11 +168,11 @@ class DataFormat(BackendDataFormat):
return self.referenced[obj]
elif isinstance(obj, dict): # can cache it, must load from scratch
return DataFormat(self.prefix, obj, (self, name), dataformat_cache)
return DataFormat(obj, (self, name))
elif isinstance(obj, list):
retval = copy.deepcopy(obj)
retval[-1] = maybe_load_format(field, obj[-1], dataformat_cache)
retval[-1] = maybe_load_format(field, obj[-1])
return retval
return obj
......@@ -207,7 +182,7 @@ class DataFormat(BackendDataFormat):
for field, value in self.data.items():
if field in ("#description", "#schema_version"):
continue # skip the description and schema version meta attributes
self.resolved[field] = maybe_load_format(field, value, dataformat_cache)
self.resolved[field] = maybe_load_format(field, value)
if isinstance(self.resolved[field], DataFormat):
if not self.resolved[field].valid:
self.errors.append("referred dataformat `%s' is invalid" % value)
......@@ -218,7 +193,7 @@ class DataFormat(BackendDataFormat):
if "#extends" in self.resolved:
ext = self.data["#extends"]
self.referenced[ext] = maybe_load_format(self.name, ext, dataformat_cache)
self.referenced[ext] = maybe_load_format(self.name, ext)
basetype = self.resolved["#extends"]
# before updating, checks there is no name clash
if basetype.valid:
......
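With the prefix and cache arguments gone from DataFormat as well, a data format is now constructed from its name alone. A minimal hedged sketch ("user/format/1" is a hypothetical name):

    from beat.core.dataformat import DataFormat

    fmt = DataFormat("user/format/1")  # hypothetical dataformat name
    if not fmt.valid:
        print("\n".join(fmt.errors))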
......@@ -142,15 +142,7 @@ class BaseExecutor(object):
"""
def __init__(
self,
prefix,
data,
cache=None,
dataformat_cache=None,
database_cache=None,
algorithm_cache=None,
library_cache=None,
custom_root_folders=None,
self, prefix, data, cache=None, custom_root_folders=None,
):
# Initialisations
......@@ -178,12 +170,6 @@ class BaseExecutor(object):
else:
custom_root_folders = {}
# Temporary caches, if the user has not set them, for performance
database_cache = database_cache if database_cache is not None else {}
dataformat_cache = dataformat_cache if dataformat_cache is not None else {}
algorithm_cache = algorithm_cache if algorithm_cache is not None else {}
library_cache = library_cache if library_cache is not None else {}
# Basic validation of the data declaration, including JSON loading if required
if not isinstance(data, dict):
if not os.path.exists(data):
......@@ -195,13 +181,7 @@ class BaseExecutor(object):
return
# Load the algorithm (using the algorithm cache if possible)
if self.data["algorithm"] in algorithm_cache:
self.algorithm = algorithm_cache[self.data["algorithm"]]
else:
self.algorithm = algorithm.Algorithm(
self.prefix, self.data["algorithm"], dataformat_cache, library_cache
)
algorithm_cache[self.algorithm.name] = self.algorithm
self.algorithm = algorithm.Algorithm[self.data["algorithm"]]
if not self.algorithm.valid:
self.errors += self.algorithm.errors
......@@ -209,36 +189,30 @@ class BaseExecutor(object):
if "loop" in self.data:
loop = self.data["loop"]
if loop["algorithm"] in algorithm_cache:
self.loop_algorithm = algorithm_cache[loop["algorithm"]]
else:
self.loop_algorithm = algorithm.Algorithm(
self.prefix, loop["algorithm"], dataformat_cache, library_cache
self.loop_algorithm = algorithm.Algorithm[loop["algorithm"]]
if len(loop["inputs"]) != len(self.loop_algorithm.input_map):
self.errors.append(
"The number of inputs of the loop algorithm doesn't correspond"
)
algorithm_cache[self.loop_algorithm.name] = self.loop_algorithm
if len(loop["inputs"]) != len(self.loop_algorithm.input_map):
for name in self.data["inputs"].keys():
if name not in self.algorithm.input_map.keys():
self.errors.append(
"The number of inputs of the loop algorithm doesn't correspond"
"The input '%s' doesn't exist in the loop algorithm" % name
)
for name in self.data["inputs"].keys():
if name not in self.algorithm.input_map.keys():
self.errors.append(
"The input '%s' doesn't exist in the loop algorithm" % name
)
if len(loop["outputs"]) != len(self.loop_algorithm.output_map):
self.errors.append(
"The number of outputs of the loop algorithm doesn't correspond"
)
if len(loop["outputs"]) != len(self.loop_algorithm.output_map):
for name in self.data["outputs"].keys():
if name not in self.algorithm.output_map.keys():
self.errors.append(
"The number of outputs of the loop algorithm doesn't correspond"
"The output '%s' doesn't exist in the loop algorithm" % name
)
for name in self.data["outputs"].keys():
if name not in self.algorithm.output_map.keys():
self.errors.append(
"The output '%s' doesn't exist in the loop algorithm" % name
)
# Check that the mapping is coherent
if len(self.data["inputs"]) != len(self.algorithm.input_map):
self.errors.append(
......@@ -276,16 +250,11 @@ class BaseExecutor(object):
return
# Load the databases (if any is required)
self._update_db_cache(
self.data["inputs"], custom_root_folders, database_cache, dataformat_cache
)
self._update_db_cache(self.data["inputs"], custom_root_folders)
if "loop" in self.data:
self._update_db_cache(
self.data["loop"]["inputs"],
custom_root_folders,
database_cache,
dataformat_cache,
self.data["loop"]["inputs"], custom_root_folders,
)
def __enter__(self):
......@@ -316,9 +285,7 @@ class BaseExecutor(object):
self.output_list = None
self.data_sinks = []
def _update_db_cache(
self, inputs, custom_root_folders, database_cache, dataformat_cache
):
def _update_db_cache(self, inputs, custom_root_folders):
""" Update the database cache based on the input list given"""
for name, details in inputs.items():
......@@ -326,18 +293,11 @@ class BaseExecutor(object):
if details["database"] not in self.databases:
if details["database"] in database_cache:
db = database_cache[details["database"]]
else:
db = database.Database(
self.prefix, details["database"], dataformat_cache
)
name = "database/%s" % db.name
if name in custom_root_folders:
db.data["root_folder"] = custom_root_folders[name]
db = database.Database[details["database"]]
database_cache[db.name] = db
name = "database/%s" % db.name
if name in custom_root_folders:
db.data["root_folder"] = custom_root_folders[name]
self.databases[db.name] = db
......
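The executors lose their per-call caches in the same way; below is a hedged sketch of the trimmed BaseExecutor signature (the module path and file paths are assumptions):

    from beat.core.execution.base import BaseExecutor

    executor = BaseExecutor(
        prefix="/home/user/beat_prefix",         # hypothetical installation prefix
        data="path/to/block_declaration.json",   # hypothetical block declaration
        cache="/tmp/beat-cache",                 # hypothetical cache directory
    )
    if executor.errors:
        print("\n".join(executor.errors))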
......@@ -154,27 +154,11 @@ class LocalExecutor(BaseExecutor):
"""
def __init__(
self,
prefix,
data,
cache=None,
dataformat_cache=None,
database_cache=None,