Commit 300c9661 authored by Amir MOHAMMADI

Use a global prefix to keep things around, add an iris test

parent b756cd0c
Pipeline #42048 failed in 3 minutes and 31 seconds
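In rough terms, this commit stops threading an explicit `prefix` argument and `dataformat_cache`/`library_cache` dictionaries through every constructor: the prefix now comes from the global configuration (`config.get_config()["prefix"]` in the diff below) and loaded objects are kept around behind a class-level `[]` lookup (e.g. `library.Library[value]`). A minimal before/after sketch of what this means for callers; how the global prefix gets set is not shown in this diff, so that part is an assumption:

    from beat.core.algorithm import Algorithm

    # Before: the caller supplied the prefix and shared caches explicitly.
    #   algo = Algorithm(prefix, "user/myalgo/1", dataformat_cache, library_cache)

    # After: the prefix is read from the global config (set elsewhere by the
    # application bootstrap -- the setter is not part of this diff), and
    # repeated loads of the same name reuse a globally cached instance.
    algo = Algorithm("user/myalgo/1")
    assert algo.valid, "\n".join(algo.errors)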
@@ -50,8 +50,8 @@ import os
 import numpy
 import pkg_resources
 import simplejson as json
-import six

+from beat.backend.python import config
 from beat.backend.python.algorithm import Algorithm as BackendAlgorithm
 from beat.backend.python.algorithm import Runner  # noqa
 from beat.backend.python.algorithm import Storage
@@ -104,8 +104,6 @@ class Algorithm(BackendAlgorithm):

     Parameters:

-      prefix (str): Establishes the prefix of your installation.
-
       data (:py:class:`object`, Optional): The piece of data representing the
         algorithm. It must validate against the schema defined for algorithms.
         If a string is passed, it is supposed to be a valid path to an
@@ -115,17 +113,6 @@ class Algorithm(BackendAlgorithm):
         its source format or as a binary blob). If ``None`` is passed, loads
         our default prototype for algorithms (source code will be in Python).

-      dataformat_cache (:py:class:`dict`, Optional): A dictionary mapping
-        dataformat names to loaded dataformats. This parameter is optional and,
-        if passed, may greatly speed-up algorithm loading times as dataformats
-        that are already loaded may be re-used.
-
-      library_cache (:py:class:`dict`, Optional): A dictionary mapping library
-        names to loaded libraries. This parameter is optional and, if passed,
-        may greatly speed-up library loading times as libraries that are
-        already loaded may be re-used.
-

     Attributes:

       name (str): The algorithm name
@@ -183,10 +170,7 @@ class Algorithm(BackendAlgorithm):

     dataformat_klass = dataformat.DataFormat

-    def __init__(self, prefix, data, dataformat_cache=None, library_cache=None):
-        super(Algorithm, self).__init__(prefix, data, dataformat_cache, library_cache)
-
-    def _load(self, data, dataformat_cache, library_cache):
+    def _load(self, data):
         """Loads the algorithm"""

         self.errors = []
@@ -208,10 +192,10 @@ class Algorithm(BackendAlgorithm):
             data, code = data  # break down into two components

-        if isinstance(data, six.string_types):  # user has passed a file pointer
+        if isinstance(data, str):  # user has passed a file pointer

             self._name = data
-            self.storage = Storage(self.prefix, self._name)
+            self.storage = Storage(self._name)
             if not self.storage.json.exists():
                 self.errors.append("Algorithm declaration file not found: %s" % data)
                 return
@@ -220,7 +204,7 @@ class Algorithm(BackendAlgorithm):

         # At this point, `data' can be a dictionary or ``None``
         if data is None:  # loads the default declaration for an algorithm
-            algorithm_data = load_algorithm_prototype(self.prefix)
+            algorithm_data = load_algorithm_prototype(config.get_config()["prefix"])
             self.data, self.errors = schema.validate("algorithm", algorithm_data)
             assert not self.errors, "\n  * %s" % "\n  *".join(self.errors)  # nosec
         else:  # just assign it
@@ -272,11 +256,11 @@ class Algorithm(BackendAlgorithm):
             [(k, v["type"]) for g in self.groups for k, v in g.get("loop", {}).items()]
         )

-        self._validate_required_dataformats(dataformat_cache)
+        self._validate_required_dataformats()
         self._convert_parameter_types()

         # finally, the libraries
-        self._validate_required_libraries(library_cache)
+        self._validate_required_libraries()
         self._check_language_consistence()

     def _check_endpoint_uniqueness(self):
@@ -318,26 +302,26 @@ class Algorithm(BackendAlgorithm):
                     )
                 )

-    def _validate_dataformats(self, group, group_name, dataformat_cache):
+    def _validate_dataformats(self, group, group_name):
         for name, entry in group[group_name].items():
             type_name = entry["type"]
-            thisformat = self._update_dataformat_cache(type_name, dataformat_cache)
+            thisformat = self._update_dataformat_cache(type_name)
             self._validate_format(type_name, group_name, name, thisformat)

-    def _validate_required_dataformats(self, dataformat_cache):
+    def _validate_required_dataformats(self):
         """Makes sure we can load all requested formats
         """

         for group in self.groups:

             for name, input_ in group["inputs"].items():
-                self._validate_dataformats(group, "inputs", dataformat_cache)
+                self._validate_dataformats(group, "inputs")

             if "outputs" in group:
-                self._validate_dataformats(group, "outputs", dataformat_cache)
+                self._validate_dataformats(group, "outputs")

             if "loop" in group:
-                self._validate_dataformats(group, "loop", dataformat_cache)
+                self._validate_dataformats(group, "loop")

         if self.results:
@@ -346,9 +330,7 @@ class Algorithm(BackendAlgorithm):
                 # results can only contain base types and plots therefore, only
                 # process plots
                 if result_type.find("/") != -1:
-                    thisformat = self._update_dataformat_cache(
-                        result_type, dataformat_cache
-                    )
+                    thisformat = self._update_dataformat_cache(result_type)
                     self._validate_format(result_type, "result", name, thisformat)

     def _convert_parameter_types(self):
@@ -430,7 +412,7 @@ class Algorithm(BackendAlgorithm):
             )
         )

-    def _validate_required_libraries(self, library_cache):
+    def _validate_required_libraries(self):

         # all used libraries must be loadable; cannot use self as a library
@@ -438,9 +420,7 @@ class Algorithm(BackendAlgorithm):

         for name, value in self.uses.items():
-            self.libraries[value] = library_cache.setdefault(
-                value, library.Library(self.prefix, value, library_cache)
-            )
+            self.libraries[value] = library.Library[value]

             if not self.libraries[value].valid:
                 self.errors.append(
...
from collections import namedtuple
from .experiment import Experiment
from .toolchain import Toolchain
_Node = namedtuple("Node", ["block", "name", "dataformat"])
class _Edge:
"""An edge of a BEAT experiment/toolchain graph
"""
def __init__(self, block, name, dataformat, channel, **kwargs):
super().__init__(**kwargs)
self.before = _Node(block, name, dataformat)
self.channel = channel
def connect_to(self, block, name, dataformat):
after = _Node(block, name, dataformat)
if dataformat != self.before[-1]:
raise RuntimeError(f"Cannot connect the output of {self.before} to {after}")
self.after = after
def __repr__(self):
return f"{self.before} -> {self.after} in {self.channel}"
class _Block:
"""A block of a BEAT experiment/toolchain.
It holds information about its input/output connections.
"""
def __init__(self, component, name=None, **kwargs):
super().__init__(**kwargs)
self._component = component
self._edges = []
self._name = name or component.name.replace("/", "_")
class DatasetBlock(_Block):
def __init__(self, database, protocol, set, name=None, **kwargs):
super().__init__(component=database, name=name, **kwargs)
if protocol not in database.protocol_names:
raise ValueError(
f"Unknown protocol name ({protocol}) for database: {database}"
)
self._protocol = protocol
if set not in database.set_names(protocol):
raise ValueError(
f"Unknown set name ({set}) for "
f"database: {database} and protocol: {protocol}"
)
self._set = set
self._channel = f"{self._component.name}_{self._protocol}_{self._set}"
self._channel = self._channel.replace("/", "_")
if name is None:
self._name = self._channel
def __getattr__(self, name):
# check if name is in the set outputs
_set = self._component.set(self._protocol, self._set)
if name not in _set["outputs"]:
raise AttributeError
edge = _Edge(
block=self,
name=name,
dataformat=_set["outputs"][name],
channel=self._channel,
)
# keep the edge
setattr(self, name, edge)
self._edges.append(edge)
return edge
class _AlgorithmAnalyzerBlock(_Block):
def validate_inputs(self, **kwargs):
input_names = set(kwargs.keys())
valid_input_names = set(self._component.input_map.keys())
if input_names != valid_input_names:
raise ValueError(
f"Inputs {input_names} do not match "
f"algorithm's inputs: {valid_input_names}"
)
# set channel based on inputs if not set already
if self.channel is None:
edge_channels = set(v.channel for v in kwargs.values())
if not len(edge_channels) == 1:
raise ValueError(
f"The inputs are coming from more than 1 synchronization channel "
f"({edge_channels}). During the initialization of {self}, specify "
f"the dataset responsible for the main synchronization channel "
"of this block."
)
self.channel = edge_channels.pop()
def connect_inputs(self, **kwargs):
self.input_edges = []
# connect input edges to inputs
for input_name, edge in kwargs.items():
edge.connect_to(self, input_name, self._component.input_map[input_name])
self.input_edges.append(edge)
class AlgorithmBlock(_AlgorithmAnalyzerBlock):
def __init__(self, algorithm, sync_with=None, parameters=None, name=None, **kwargs):
super().__init__(component=algorithm, name=name, **kwargs)
self.channel = None if sync_with is None else sync_with._channel
self.parameters = parameters
def __call__(self, **kwargs):
self.validate_inputs(**kwargs)
self.connect_inputs(**kwargs)
# find the channel of output through inputs
grp = self._component.output_group
for input_name, edge in kwargs.items():
if input_name in grp["inputs"]:
output_channel = edge.channel
break
# construct output edges
output_edges = []
for output_name, dataformat in self._component.output_map.items():
edge = _Edge(
block=self,
name=output_name,
dataformat=dataformat,
channel=output_channel,
)
output_edges.append(edge)
self._edges.extend(output_edges)
if len(output_edges) == 1:
return output_edges[0]
return output_edges
class AnalyzerBlock(_AlgorithmAnalyzerBlock):
def __init__(
self, analyzer, name=None, toolchain_name=None, experiment_name=None, **kwargs
):
super().__init__(component=analyzer, name=name, **kwargs)
self.channel = None
self.toolchain_name = toolchain_name
self.experiment_name = experiment_name
def __call__(self, **kwargs):
self.validate_inputs(**kwargs)
self.connect_inputs(**kwargs)
def create_experiment(*analyzers, experiment_name, toolchain_name=None):
# construct the toolchain and experiment
toolchain_name = toolchain_name or experiment_name
experiment = _create_experiment(
[edge for analyzer in analyzers for edge in analyzer.input_edges],
toolchain_name=toolchain_name,
experiment_name=experiment_name,
)
return experiment
def _get_all_edges(edges, all_edges=None):
if all_edges is None:
all_edges = set()
for edge in edges:
all_edges.add(edge)
new_edges = [
e for e in getattr(edge.before.block, "input_edges", []) if e not in edges
]
all_edges.update(_get_all_edges(new_edges, all_edges))
return all_edges
def _create_experiment(edges, toolchain_name, experiment_name):
all_edges = list(_get_all_edges(edges))
# find all blocks
datasets = set()
algorithms = set()
analyzers = set()
for edge in all_edges:
before_block = edge.before.block
if isinstance(before_block, DatasetBlock):
datasets.add(before_block)
after_block = edge.after.block
if isinstance(after_block, AlgorithmBlock):
algorithms.add(after_block)
elif isinstance(after_block, AnalyzerBlock):
analyzers.add(after_block)
# create unique names for all blocks
all_blocks = list(datasets) + list(algorithms) + list(analyzers)
names, names_map = list(), dict()
for block in all_blocks:
name = block._name
# if the name already exists
while name in names:
name += "_2"
names.append(name)
names_map[block] = name
# find all connections
connections = []
for edge in all_edges:
before_name = names_map[edge.before.block]
after_name = names_map[edge.after.block]
connections.append(
{
"channel": edge.channel,
"from": f"{before_name}.{edge.before.name}",
"to": f"{after_name}.{edge.after.name}",
}
)
# create datasets
toolchain_datasets, experiment_datasets = list(), dict()
for block in datasets:
name = names_map[block]
outputs = [edge.before.name for edge in block._edges]
toolchain_datasets.append(dict(name=name, outputs=outputs))
experiment_datasets[name] = dict(
database=block._component.name, protocol=block._protocol, set=block._set,
)
# create blocks (algorithms)
toolchain_blocks, experiment_blocks = list(), dict()
for block in algorithms:
name = names_map[block]
channel = block.channel
inputs = [edge.after.name for edge in block.input_edges]
outputs = [edge.before.name for edge in block._edges]
toolchain_blocks.append(
dict(
name=name, synchronized_channel=channel, inputs=inputs, outputs=outputs
)
)
experiment_blocks[name] = dict(
algorithm=block._component.name,
inputs={edge.after.name: edge.before.name for edge in block.input_edges},
outputs={edge.before.name: edge.after.name for edge in block._edges},
)
# create analyzers
toolchain_analyzers, experiment_analyzers = list(), dict()
for block in analyzers:
name = names_map[block]
channel = block.channel
inputs = [edge.after.name for edge in block.input_edges]
toolchain_analyzers.append(
dict(name=name, synchronized_channel=channel, inputs=inputs)
)
experiment_analyzers[name] = dict(
algorithm=block._component.name,
inputs={edge.after.name: edge.before.name for edge in block.input_edges},
)
# create toolchain
data = dict(
analyzers=toolchain_analyzers,
blocks=toolchain_blocks,
datasets=toolchain_datasets,
connections=connections,
representation={"blocks": {}, "channel_colors": {}, "connections": {}},
)
toolchain = Toolchain.new(data=data, name=toolchain_name)
data = dict(
analyzers=experiment_analyzers,
blocks=experiment_blocks,
datasets=experiment_datasets,
schema_version=1,
globals={
"environment": {"name": "dummy", "version": "0.0.0"},
"queue": "queue",
},
)
experiment = Experiment.new(data=data, toolchain=toolchain, label=experiment_name,)
return experiment
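The new builder module above composes more naturally than it reads line by line. The following is a minimal usage sketch, loosely modeled on the iris test mentioned in the commit message; every database, protocol, set, and algorithm name here is a hypothetical placeholder, not the actual test fixture:

    from beat.core.algorithm import Algorithm
    from beat.core.database import Database

    # Hypothetical names -- the real iris test fixtures may differ.
    db = Database("iris/1")
    training = DatasetBlock(db, protocol="Main", set="training")
    testing = DatasetBlock(db, protocol="Main", set="testing")

    # Attribute access on a DatasetBlock goes through __getattr__ and creates
    # an output edge carrying the dataformat and synchronization channel.
    trainer = AlgorithmBlock(Algorithm("user/train/1"))
    model = trainer(measurements=training.measurements, species=training.species)

    # The tester mixes two channels (the model from training, the data from
    # testing), so sync_with= names the main synchronization channel explicitly.
    tester = AlgorithmBlock(Algorithm("user/test/1"), sync_with=testing)
    scores = tester(model=model, measurements=testing.measurements)

    analyzer = AnalyzerBlock(Algorithm("user/analyze/1"))
    analyzer(scores=scores, species=testing.species)

    # create_experiment() walks input_edges backwards from the analyzers and
    # emits a Toolchain plus a matching Experiment declaration.
    experiment = create_experiment(analyzer, experiment_name="iris_test")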
@@ -46,8 +46,6 @@ Forward importing from :py:mod:`beat.backend.python.dataformat`:
 """
 import copy

-import six
-
 from beat.backend.python.dataformat import DataFormat as BackendDataFormat
 from beat.backend.python.dataformat import Storage  # noqa
@@ -61,8 +59,6 @@ class DataFormat(BackendDataFormat):

     Parameters:

-      prefix (str): Establishes the prefix of your installation.
-
       data (:py:class:`object`, Optional): The piece of data representing the
         data format. It must validate against the schema defined for data
         formats. If a string is passed, it is supposed to be a valid path to an
@@ -76,13 +72,6 @@ class DataFormat(BackendDataFormat):
         object that is this object's parent and the name of the field on that
         object that points to this one.

-      dataformat_cache (:py:class:`dict`, Optional): A dictionary mapping
-        dataformat names to loaded dataformats. This parameter is optional and,
-        if passed, may greatly speed-up data format loading times as
-        dataformats that are already loaded may be re-used. If you use this
-        parameter, you must guarantee that the cache is refreshed as
-        appropriate in case the underlying dataformats change.
-

     Attributes:

       name (str): The full, valid name of this dataformat
@@ -112,10 +101,10 @@ class DataFormat(BackendDataFormat):

     """

-    def __init__(self, prefix, data, parent=None, dataformat_cache=None):
-        super(DataFormat, self).__init__(prefix, data, parent, dataformat_cache)
+    def __init__(self, data, parent=None):
+        super().__init__(data, parent)

-    def _load(self, data, dataformat_cache):
+    def _load(self, data):
         """Loads the dataformat"""

         self._name = None
@@ -134,10 +123,9 @@ class DataFormat(BackendDataFormat):
         if not isinstance(data, dict):  # user has passed a file pointer

             # make sure to log this into the cache (avoids recursion)
-            dataformat_cache[data] = None
             self._name = data
-            self.storage = Storage(self.prefix, data)
+            self.storage = Storage(data)
             data = self.storage.json.path
             if not self.storage.exists():
                 self.errors.append(
@@ -167,25 +155,12 @@ class DataFormat(BackendDataFormat):
             self.errors = utils.uniq(self.errors)
             return

-        def maybe_load_format(name, obj, dataformat_cache):
+        def maybe_load_format(name, obj):
             """Tries to load a given dataformat from its relative path"""

-            if isinstance(obj, six.string_types) and obj.find("/") != -1:  # load it
-
-                if obj in dataformat_cache:  # reuse
-
-                    if dataformat_cache[obj] is None:  # recursion detected
-                        self.errors.append(
-                            "recursion for dataformat `%s' detected" % obj
-                        )
-                        return self
-
-                    self.referenced[obj] = dataformat_cache[obj]
-
-                else:  # load it
-                    self.referenced[obj] = DataFormat(
-                        self.prefix, obj, (self, name), dataformat_cache
-                    )
+            if isinstance(obj, str) and obj.find("/") != -1:  # load it
+                self.referenced[obj] = DataFormat(obj, (self, name))

                 if not self.referenced[obj].valid:
                     self.errors.append("referred dataformat `%s' is invalid" % obj)
@@ -193,11 +168,11 @@ class DataFormat(BackendDataFormat):
                 return self.referenced[obj]

             elif isinstance(obj, dict):  # can cache it, must load from scratch
-                return DataFormat(self.prefix, obj, (self, name), dataformat_cache)
+                return DataFormat(obj, (self, name))

             elif isinstance(obj, list):
                 retval = copy.deepcopy(obj)
-                retval[-1] = maybe_load_format(field, obj[-1], dataformat_cache)
+                retval[-1] = maybe_load_format(field, obj[-1])
                 return retval

             return obj
@@ -207,7 +182,7 @@ class DataFormat(BackendDataFormat):
         for field, value in self.data.items():
             if field in ("#description", "#schema_version"):
                 continue  # skip the description and schema version meta attributes
-            self.resolved[field] = maybe_load_format(field, value, dataformat_cache)
+            self.resolved[field] = maybe_load_format(field, value)
             if isinstance(self.resolved[field], DataFormat):
                 if not self.resolved[field].valid:
                     self.errors.append("referred dataformat `%s' is invalid" % value)
@@ -218,7 +193,7 @@ class DataFormat(BackendDataFormat):
         if "#extends" in self.resolved:
             ext = self.data["#extends"]
-            self.referenced[ext] = maybe_load_format(self.name, ext, dataformat_cache)
+            self.referenced[ext] = maybe_load_format(self.name, ext)
             basetype = self.resolved["#extends"]
             # before updating, checks there is no name clash
             if basetype.valid:
...
@@ -142,15 +142,7 @@ class BaseExecutor(object):
     """

     def __init__(
-        self,
-        prefix,
-        data,
-        cache=None,
-        dataformat_cache=None,
-        database_cache=None,
-        algorithm_cache=None,
-        library_cache=None,
-        custom_root_folders=None,
+        self, prefix, data, cache=None, custom_root_folders=None,
     ):

         # Initialisations
@@ -178,12 +170,6 @@ class BaseExecutor(object):
         else:
             custom_root_folders = {}

-        # Temporary caches, if the user has not set them, for performance
-        database_cache = database_cache if database_cache is not None else {}
-        dataformat_cache = dataformat_cache if dataformat_cache is not None else {}
-        algorithm_cache = algorithm_cache if algorithm_cache is not None else {}
-        library_cache = library_cache if library_cache is not None else {}
-
         # Basic validation of the data declaration, including JSON loading if required
         if not isinstance(data, dict):
             if not os.path.exists(data):
@@ -195,13 +181,7 @@ class BaseExecutor(object):
             return

         # Load the algorithm (using the algorithm cache if possible)
-        if self.data["algorithm"] in algorithm_cache:
-            self.algorithm = algorithm_cache[self.data["algorithm"]]
-        else:
-            self.algorithm = algorithm.Algorithm(
-                self.prefix, self.data["algorithm"], dataformat_cache, library_cache
-            )
-            algorithm_cache[self.algorithm.name] = self.algorithm
+        self.algorithm = algorithm.Algorithm[self.data["algorithm"]]
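The `algorithm.Algorithm[self.data["algorithm"]]` lookup above (and `library.Library[value]` earlier) replaces the explicit cache dictionaries. The diff does not show how the backend implements class indexing, but the executor code only needs a lazily constructed, memoized instance per name; a minimal sketch of one way to get that behavior, assuming roughly a caching metaclass:

    class _Cached(type):
        # Sketch only: Klass[name] loads once, then returns the same instance.

        def __init__(cls, name, bases, namespace):
            super().__init__(name, bases, namespace)
            cls._instances = {}  # one cache per class, kept for the process lifetime

        def __getitem__(cls, name):
            if name not in cls._instances:
                cls._instances[name] = cls(name)  # the new single-argument constructor
            return cls._instances[name]

    class Algorithm(metaclass=_Cached):
        def __init__(self, name):
            self.name = name  # the real class loads the declaration from the prefix

    assert Algorithm["user/myalgo/1"] is Algorithm["user/myalgo/1"]  # "kept around"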