From 73491015355e3d19608c1bf369f602bb9e136c1a Mon Sep 17 00:00:00 2001 From: Amir MOHAMMADI <amir.mohammadi@idiap.ch> Date: Mon, 10 Aug 2020 19:14:31 +0200 Subject: [PATCH] Integrate deeper with config and prefix --- .pre-commit-config.yaml | 1 - beat/backend/python/algorithm.py | 4 + beat/backend/python/baseformat.py | 2 + beat/backend/python/config.py | 193 +++++++++++++++++++ beat/backend/python/database.py | 137 +++++++------ beat/backend/python/dataformat.py | 148 +++++++++----- beat/backend/python/protocoltemplate.py | 122 +++++++----- beat/backend/python/test/test_interactive.py | 124 ++++++++++++ beat/backend/python/utils.py | 24 --- 9 files changed, 582 insertions(+), 173 deletions(-) create mode 100644 beat/backend/python/config.py create mode 100644 beat/backend/python/test/test_interactive.py diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index b8be7b7..b4a3432 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -22,7 +22,6 @@ repos: - id: debug-statements exclude: beat/backend/python/test/prefix/.*/syntax_error - id: check-added-large-files - - id: check-docstring-first - id: flake8 exclude: beat/backend/python/test/prefix/.*/(.*crash|syntax_error) - id: check-yaml diff --git a/beat/backend/python/algorithm.py b/beat/backend/python/algorithm.py index 2190b40..08b8128 100644 --- a/beat/backend/python/algorithm.py +++ b/beat/backend/python/algorithm.py @@ -1056,3 +1056,7 @@ class Algorithm(object): k.export(prefix) self.write(Storage(prefix, self.name, self.language)) + + +class Analyzer(Algorithm): + """docstring for Analyzer""" diff --git a/beat/backend/python/baseformat.py b/beat/backend/python/baseformat.py index 9e8267f..9626fd2 100644 --- a/beat/backend/python/baseformat.py +++ b/beat/backend/python/baseformat.py @@ -123,6 +123,8 @@ def setup_scalar(formatname, attrname, dtype, value, casting, add_defaults): return str(value) else: # it is a dataformat + if isinstance(value, baseformat): + return value return 
dtype().from_dict(value, casting=casting, add_defaults=add_defaults) diff --git a/beat/backend/python/config.py b/beat/backend/python/config.py new file mode 100644 index 0000000..0a04f6c --- /dev/null +++ b/beat/backend/python/config.py @@ -0,0 +1,193 @@ +#!/usr/bin/env python +# vim: set fileencoding=utf-8 : + +################################################################################### +# # +# Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/ # +# Contact: beat.support@idiap.ch # +# # +# Redistribution and use in source and binary forms, with or without # +# modification, are permitted provided that the following conditions are met: # +# # +# 1. Redistributions of source code must retain the above copyright notice, this # +# list of conditions and the following disclaimer. # +# # +# 2. Redistributions in binary form must reproduce the above copyright notice, # +# this list of conditions and the following disclaimer in the documentation # +# and/or other materials provided with the distribution. # +# # +# 3. Neither the name of the copyright holder nor the names of its contributors # +# may be used to endorse or promote products derived from this software without # +# specific prior written permission. # +# # +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND # +# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED # +# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE # +# DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE # +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL # +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR # +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER # +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, # +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. # +# # +################################################################################### + + +"""Global Configuration state and management""" + + +import getpass +import logging +import os + +from contextlib import contextmanager + +logger = logging.getLogger(__name__) + + +DEFAULTS = { + "user": getpass.getuser(), + "prefix": os.path.realpath(os.path.join(os.curdir, "prefix")), + "cache": "cache", +} +"""Default values of the configuration""" + + +DOC = { + "user": "User name for operations that create, delete or edit objects", + "prefix": "Directory containing BEAT objects", + "cache": "Directory to use for data caching (relative to prefix)", +} +"""Documentation for configuration parameters""" + +_global_config = DEFAULTS.copy() + + +def get_config(): + """Retrieve current values for configuration set by :func:`set_config` + + Returns + ------- + config : dict + Keys are parameter names that can be passed to :func:`set_config`. 
+
+    See Also
+    --------
+    config_context: Context manager for global configuration
+    set_config: Set global configuration
+    """
+    return _global_config.copy()
+
+
+def set_config(**kwargs):
+    """Set global configuration
+
+    Parameters
+    ----------
+    user : str, optional
+        The username used when creating objects
+    prefix : str, optional
+        The path to the current prefix
+    cache : str, optional
+        The path to the current cache
+
+    See Also
+    --------
+    config_context: Context manager for global configuration
+    get_config: Retrieve current values of the global configuration
+    """
+    supported_keys = set(DEFAULTS.keys())
+    set_keys = set(kwargs.keys())
+    if not set_keys.issubset(supported_keys):
+        raise ValueError(
+            f"Only {supported_keys} are valid configurations. "
+            f"Got these extra values: {set_keys - supported_keys}"
+        )
+    kwargs = {k: v for k, v in kwargs.items() if v is not None}
+    _global_config.update(kwargs)
+    # if a new prefix path is set, clear the prefix
+    if "prefix" in kwargs:
+        Prefix().clear()
+
+
+@contextmanager
+def config_context(**new_config):
+    """Context manager for global configuration
+
+    Parameters
+    ----------
+    user : str, optional
+        The username used when creating objects
+    prefix : str, optional
+        The path to the current prefix
+    cache : str, optional
+        The path to the current cache
+
+    Notes
+    -----
+    All settings, not just those presently modified, will be returned to
+    their previous values when the context manager is exited. This is not
+    thread-safe.
+ + See Also + -------- + set_config: Set global configuration + get_config: Retrieve current values of the global configuration + """ + old_config = get_config().copy() + # also backup prefix + old_prefix = Prefix().copy() + set_config(**new_config) + + try: + yield + finally: + set_config(**old_config) + prefix = Prefix() + prefix.clear() + prefix.update(old_prefix) + + +# ---------------------------------------------------------- + + +class Singleton(type): + """A Singleton metaclass + The singleton class calls the __init__ method each time the instance is requested. + From: https://stackoverflow.com/a/6798042/1286165 + """ + + _instances = {} + + def __call__(cls, *args, **kwargs): + if cls not in cls._instances: + cls._instances[cls] = super().__call__(*args, **kwargs) + return cls._instances[cls] + + +class Prefix(dict, metaclass=Singleton): + def __init__(self, path=None, *args, **kwargs): + super().__init__(*args, **kwargs) + + +class PrefixMeta(type): + def __contains__(cls, key): + return f"{cls.asset_folder}/{key}" in Prefix() + + def __getitem__(cls, key): + folder = f"{cls.asset_folder}/{key}" + prefix = Prefix() + + # if cached, return it + if folder in prefix: + return prefix[folder] + + # otherwise, load from prefix + asset = super().__call__(key) + return asset + + def __setitem__(cls, key, value): + folder = f"{cls.asset_folder}/{key}" + prefix = Prefix() + prefix[folder] = value diff --git a/beat/backend/python/database.py b/beat/backend/python/database.py index 1f5ce1d..2383c06 100644 --- a/beat/backend/python/database.py +++ b/beat/backend/python/database.py @@ -52,6 +52,7 @@ import numpy as np import simplejson as json import six +from . import config from . import loader from . import utils from .dataformat import DataFormat @@ -67,8 +68,6 @@ class Storage(utils.CodeStorage): Parameters: - prefix (str): Establishes the prefix of your installation. - name (str): The name of the database object in the format ``<name>/<version>``. 
@@ -77,16 +76,17 @@ class Storage(utils.CodeStorage): asset_type = "database" asset_folder = "databases" - def __init__(self, prefix, name): + def __init__(self, name): if name.count("/") != 1: raise RuntimeError("invalid database name: `%s'" % name) self.name, self.version = name.split("/") self.fullname = name - self.prefix = prefix - path = os.path.join(self.prefix, self.asset_folder, name + ".json") + path = os.path.join( + config.get_config()["prefix"], self.asset_folder, name + ".json" + ) path = path[:-5] # views are coded in Python super(Storage, self).__init__(path, "python") @@ -105,8 +105,6 @@ class Runner(object): module (:std:term:`module`): The preloaded module containing the database views as returned by :py:func:`.loader.load_module`. - prefix (str): Establishes the prefix of your installation. - root_folder (str): The path pointing to the root folder of this database exc (:std:term:`class`): The class to use as base exception when @@ -119,7 +117,7 @@ class Runner(object): """ - def __init__(self, module, definition, prefix, root_folder, exc=None): + def __init__(self, module, definition, root_folder, exc=None): try: class_ = getattr(module, definition["view"]) @@ -132,7 +130,6 @@ class Runner(object): self.obj = loader.run(class_, "__new__", exc) self.ready = False - self.prefix = prefix self.root_folder = root_folder self.definition = definition self.exc = exc or RuntimeError @@ -193,7 +190,7 @@ class Runner(object): self, output_name, output_format, - self.prefix, + config.get_config()["prefix"], start_index=start_index, end_index=end_index, pack=pack, @@ -226,18 +223,8 @@ class Database(object): Parameters: - prefix (str): Establishes the prefix of your installation. - name (str): The fully qualified database name (e.g. ``db/1``) - dataformat_cache (:py:class:`dict`, Optional): A dictionary mapping - dataformat names to loaded dataformats. 
This parameter is optional and, - if passed, may greatly speed-up database loading times as dataformats - that are already loaded may be re-used. If you use this parameter, you - must guarantee that the cache is refreshed as appropriate in case the - underlying dataformats change. - - Attributes: name (str): The full, valid name of this database @@ -247,58 +234,96 @@ class Database(object): """ - def __init__(self, prefix, name, dataformat_cache=None): - + def _init(self): self._name = None - self.prefix = prefix self.dataformats = {} # preloaded dataformats self.storage = None - self.errors = [] self.data = None - # if the user has not provided a cache, still use one for performance - dataformat_cache = dataformat_cache if dataformat_cache is not None else {} + def __init__(self, name): + + self._init() + self._load(name) + + @classmethod + def new( + cls, + code_path, + protocols, + name, + description=None, + schema_version=2, + root_folder="/foo/bar", + ): + self = cls.__new__(cls) + self._init() + + if not name: + raise ValueError(f"Invalid {name}. 
The name should be a non-empty string!") - self._load(name, dataformat_cache) + if "/" not in name: + name = f"{name}/1" - def _update_dataformat_cache(self, outputs, dataformat_cache): + self._name = name + + def protocoltemplate_name(v): + if hasattr(v, "name"): + v = v.name + return v + + for i, proto in enumerate(protocols): + protocols[i]["template"] = protocoltemplate_name(proto["template"]) + + data = dict(protocols=protocols) + if description is not None: + data["description"] = description + if schema_version is not None: + data["schema_version"] = schema_version + + self.data = data + + self.storage = Storage(name) + # save the code into storage + with open(code_path, "rt") as f: + self.storage.code.save(f.read()) + self.code_path = self.storage.code.path + self.code = self.storage.code.load() + + self._load_v2() + + return self + + def _update_dataformat_cache(self, outputs): for key, value in outputs.items(): if value in self.dataformats: continue - if value in dataformat_cache: - dataformat = dataformat_cache[value] - else: - dataformat = DataFormat(self.prefix, value) - dataformat_cache[value] = dataformat - + dataformat = DataFormat[value] self.dataformats[value] = dataformat - def _load_v1(self, dataformat_cache): + def _load_v1(self): """Loads a v1 database and fills the dataformat cache""" for protocol in self.data["protocols"]: for set_ in protocol["sets"]: - self._update_dataformat_cache(set_["outputs"], dataformat_cache) + self._update_dataformat_cache(set_["outputs"]) - def _load_v2(self, dataformat_cache): + def _load_v2(self): """Loads a v2 database and fills the dataformat cache""" for protocol in self.data["protocols"]: - protocol_template = ProtocolTemplate( - self.prefix, protocol["template"], dataformat_cache - ) + protocol_template = ProtocolTemplate[protocol["template"]] for set_ in protocol_template.sets(): - self._update_dataformat_cache(set_["outputs"], dataformat_cache) + self._update_dataformat_cache(set_["outputs"]) - def 
_load(self, data, dataformat_cache): + def _load(self, data): """Loads the database""" self._name = data - self.storage = Storage(self.prefix, self._name) + self.storage = Storage(self._name) json_path = self.storage.json.path if not self.storage.json.exists(): self.errors.append("Database declaration file not found: %s" % json_path) @@ -318,9 +343,9 @@ class Database(object): self.code = self.storage.code.load() if self.schema_version == 1: - self._load_v1(dataformat_cache) + self._load_v1() elif self.schema_version == 2: - self._load_v2(dataformat_cache) + self._load_v2() else: raise RuntimeError( "Invalid schema version {schema_version}".format( @@ -337,7 +362,7 @@ class Database(object): @name.setter def name(self, value): self._name = value - self.storage = Storage(self.prefix, value) + self.storage = Storage(value) @property def description(self): @@ -423,7 +448,7 @@ class Database(object): data = self.protocol(protocol)["sets"] else: protocol = self.protocol(protocol) - protocol_template = ProtocolTemplate(self.prefix, protocol["template"]) + protocol_template = ProtocolTemplate[protocol["template"]] if not protocol_template.valid: raise RuntimeError( "\n * {}".format("\n * ".join(protocol_template.errors)) @@ -444,7 +469,7 @@ class Database(object): data = self.protocol(protocol)["sets"] else: protocol = self.protocol(protocol) - protocol_template = ProtocolTemplate(self.prefix, protocol["template"]) + protocol_template = ProtocolTemplate[protocol["template"]] if not protocol_template.valid: raise RuntimeError( "\n * {}".format("\n * ".join(protocol_template.errors)) @@ -470,7 +495,7 @@ class Database(object): else: protocol = self.protocol(protocol_name) template_name = protocol["template"] - protocol_template = ProtocolTemplate(self.prefix, template_name) + protocol_template = ProtocolTemplate[template_name] view_definition = protocol_template.set(set_name) view_definition["view"] = protocol["views"][set_name]["view"] parameters = 
protocol["views"][set_name].get("parameters") @@ -534,11 +559,7 @@ class Database(object): root_folder = self.data["root_folder"] return Runner( - self._module, - self.view_definition(protocol, name), - self.prefix, - root_folder, - exc, + self._module, self.view_definition(protocol, name), root_folder, exc, ) def json_dumps(self, indent=4): @@ -598,7 +619,7 @@ class Database(object): Raises: - RuntimeError: If prefix and self.prefix point to the same directory. + RuntimeError: If prefix and prefix point to the same directory. """ @@ -608,7 +629,7 @@ class Database(object): if not self.valid: raise RuntimeError("database is not valid") - if prefix == self.prefix: + if prefix == config.get_config()["prefix"]: raise RuntimeError( "Cannot export database to the same prefix (" "%s)" % prefix ) @@ -618,10 +639,10 @@ class Database(object): if self.schema_version != 1: for protocol in self.protocols.values(): - protocol_template = ProtocolTemplate(self.prefix, protocol["template"]) + protocol_template = ProtocolTemplate[protocol["template"]] protocol_template.export(prefix) - self.write(Storage(prefix, self.name)) + self.write(Storage(self.name)) # ---------------------------------------------------------- diff --git a/beat/backend/python/dataformat.py b/beat/backend/python/dataformat.py index fba3199..f5774ae 100644 --- a/beat/backend/python/dataformat.py +++ b/beat/backend/python/dataformat.py @@ -42,6 +42,7 @@ dataformat Validation and parsing for dataformats """ +import collections import copy import re @@ -49,9 +50,13 @@ import numpy import simplejson as json import six +from . import config from . import utils from .baseformat import baseformat +DATA_FORMAT_TYPE = "dataformat" +DATA_FORMAT_FOLDER = "dataformats" + # ---------------------------------------------------------- @@ -60,28 +65,24 @@ class Storage(utils.Storage): Parameters: - prefix (str): Establishes the prefix of your installation. 
- name (str): The name of the dataformat object in the format ``<user>/<name>/<version>``. """ - asset_type = "dataformat" - asset_folder = "dataformats" + asset_type = DATA_FORMAT_TYPE + asset_folder = DATA_FORMAT_FOLDER - def __init__(self, prefix, name): + def __init__(self, name): if name.count("/") != 2: raise RuntimeError("invalid dataformat name: `%s'" % name) self.username, self.name, self.version = name.split("/") self.fullname = name - self.prefix = prefix + prefix = config.get_config()["prefix"] - path = utils.hashed_or_simple( - self.prefix, self.asset_folder, name, suffix=".json" - ) + path = utils.hashed_or_simple(prefix, self.asset_folder, name, suffix=".json") path = path[:-5] super(Storage, self).__init__(path) @@ -93,7 +94,7 @@ class Storage(utils.Storage): # ---------------------------------------------------------- -class DataFormat(object): +class DataFormat(metaclass=config.PrefixMeta): """Data formats define the chunks of data that circulate between blocks. Parameters: @@ -111,13 +112,6 @@ class DataFormat(object): object that is this object's parent and the name of the field on that object that points to this one. - dataformat_cache (:py:class:`dict`, Optional): A dictionary mapping - dataformat names to loaded dataformats. This parameter is optional and, - if passed, may greatly speed-up data format loading times as - dataformats that are already loaded may be re-used. If you use this - parameter, you must guarantee that the cache is refreshed as - appropriate in case the underlying dataformats change. 
- Attributes: name (str): The full, valid name of this dataformat @@ -147,28 +141,28 @@ class DataFormat(object): """ - def __init__(self, prefix, data, parent=None, dataformat_cache=None): + asset_type = DATA_FORMAT_TYPE + asset_folder = DATA_FORMAT_FOLDER + def _init(self): self._name = None self.storage = None self.resolved = None - self.prefix = prefix self.errors = [] self.data = None self.resolved = None self.referenced = {} - self.parent = parent + self.parent = None - # if the user has not provided a cache, still use one for performance - dataformat_cache = dataformat_cache if dataformat_cache is not None else {} + def __init__(self, data, parent=None): - try: - self._load(data, dataformat_cache) - finally: - if self._name is not None: # registers it into the cache, even if failed - dataformat_cache[self._name] = self + self._init() + self.parent = parent + self._load(data) + # cache in prefix + DataFormat[self.name] = self - def _load(self, data, dataformat_cache): + def _load(self, data): """Loads the dataformat""" if isinstance(data, dict): @@ -176,7 +170,7 @@ class DataFormat(object): self.data = data else: self._name = data - self.storage = Storage(self.prefix, data) + self.storage = Storage(data) json_path = self.storage.json.path if not self.storage.exists(): self.errors.append( @@ -196,7 +190,9 @@ class DataFormat(object): ) return - dataformat_cache[self._name] = self # registers itself into the cache + self._resolve() + + def _resolve(self): self.resolved = copy.deepcopy(self.data) @@ -212,31 +208,29 @@ class DataFormat(object): if is_reserved(key): del self.resolved[key] - def maybe_load_format(name, obj, dataformat_cache): + def maybe_load_format(name, obj): """Tries to load a given dataformat from its relative path""" if isinstance(obj, six.string_types) and obj.find("/") != -1: # load it - if obj in dataformat_cache: # reuse + if obj in DataFormat: # reuse - if dataformat_cache[obj] is None: # recursion detected + if DataFormat[obj] is None: # 
recursion detected return self - self.referenced[obj] = dataformat_cache[obj] + self.referenced[obj] = DataFormat[obj] else: # load it - self.referenced[obj] = DataFormat( - self.prefix, obj, (self, name), dataformat_cache - ) + self.referenced[obj] = DataFormat(obj, (self, name)) return self.referenced[obj] elif isinstance(obj, dict): # can cache it, must load from scratch - return DataFormat(self.prefix, obj, (self, name), dataformat_cache) + return DataFormat(obj, (self, name)) elif isinstance(obj, list): retval = copy.deepcopy(obj) - retval[-1] = maybe_load_format(field, obj[-1], dataformat_cache) + retval[-1] = maybe_load_format(field, obj[-1]) return retval return obj @@ -245,7 +239,7 @@ class DataFormat(object): for field, value in self.data.items(): if field in ("#description", "#schema_version"): continue # skip the description and schema version meta attributes - self.resolved[field] = maybe_load_format(field, value, dataformat_cache) + self.resolved[field] = maybe_load_format(field, value) # at this point, there should be no more external references in # ``self.resolved``. 
We treat the "#extends" property, which requires a @@ -253,13 +247,77 @@ class DataFormat(object): if "#extends" in self.resolved: ext = self.data["#extends"] - self.referenced[ext] = maybe_load_format(self._name, ext, dataformat_cache) + self.referenced[ext] = maybe_load_format(self._name, ext) basetype = self.resolved["#extends"] tmp = self.resolved self.resolved = basetype.resolved self.resolved.update(tmp) del self.resolved["#extends"] # avoids infinite recursion + @classmethod + def new( + cls, + definition, + name, + description=None, + extends=None, + schema_version=None, + parent=None, + ): + self = cls.__new__(cls) + self._init() + + def str_or_dtype_or_type(v): + # if it is a dict + if isinstance(v, collections.abc.Mapping): + return {k: str_or_dtype_or_type(v_) for k, v_ in v.items()} + # if it's a list + if isinstance(v, collections.abc.Sequence): + return [v_ for v_ in v[:-1]] + [str_or_dtype_or_type(v[-1])] + # if it is another dataformat or a numpy.dtype + elif hasattr(v, "name"): + v = v.name + # if it is a str + elif v is str: + v = "string" + # if none of the above, convert to a numpy dtype and then to its name + else: + v = numpy.dtype(v).name + return v + + data = str_or_dtype_or_type(definition) + + if description is not None: + data["#description"] = description + + if extends is not None: + data["#extends"] = extends.name + + if schema_version is not None: + data["#schema_version"] = schema_version + + self.data = data + + if not name: + raise ValueError(f"Invalid {name}. 
The name should be a non-empty string!")
+
+        if name != "analysis:result" and "/" not in name:
+            name = f"{config.get_config()['user']}/{name}/1"
+
+        self._name = name
+
+        self.parent = parent
+
+        if name != "analysis:result":
+            self.storage = Storage(name)
+
+        self._resolve()
+
+        # cache in prefix
+        DataFormat[self.name] = self
+
+        return self
+
     @property
     def name(self):
         """Returns the name of this object, either from the filename or composed
@@ -273,7 +331,7 @@ class DataFormat(object):
     @name.setter
     def name(self, value):
         self._name = value
-        self.storage = Storage(self.prefix, value)
+        self.storage = Storage(value)
 
     @property
     def schema_version(self):
@@ -426,7 +484,7 @@ class DataFormat(object):
         obj = self.type()
         if isinstance(data, dict):
             obj.from_dict(data, casting="safe", add_defaults=False)
-        elif isinstance(data, six.string_types):
+        elif isinstance(data, bytes):
             obj.unpack(data)
         else:
             obj.unpack_from(data)
@@ -510,7 +568,7 @@ class DataFormat(object):
 
         Raises:
 
-            RuntimeError: If prefix and self.prefix point to the same directory.
+            RuntimeError: If prefix and the configured prefix point to the same directory.
 
         """
 
@@ -520,7 +578,7 @@ class DataFormat(object):
         if not self.valid:
             raise RuntimeError("dataformat is not valid:\n{}".format(self.errors))
 
-        if prefix == self.prefix:
+        if prefix == config.get_config()["prefix"]:
             raise RuntimeError(
                 "Cannot export dataformat to the same prefix (" "%s)" % prefix
             )
diff --git a/beat/backend/python/protocoltemplate.py b/beat/backend/python/protocoltemplate.py
index a5ea0bd..e14b1d9 100644
--- a/beat/backend/python/protocoltemplate.py
+++ b/beat/backend/python/protocoltemplate.py
@@ -44,47 +44,45 @@ Validation of database protocol templates
 
 import simplejson as json
 
+from . import config
 from . 
import utils from .dataformat import DataFormat -# ---------------------------------------------------------- +PROTOCOL_TEMPLATE_TYPE = "protocoltemplate" +PROTOCOL_TEMPLATE_FOLDER = "protocoltemplates" +# ---------------------------------------------------------- class Storage(utils.Storage): """Resolves paths for protocol templates Parameters: - prefix (str): Establishes the prefix of your installation. - name (str): The name of the protocol template object in the format ``<name>/<version>``. """ - asset_type = "protocoltemplate" - asset_folder = "protocoltemplates" + asset_type = PROTOCOL_TEMPLATE_TYPE + asset_folder = PROTOCOL_TEMPLATE_FOLDER - def __init__(self, prefix, name): + def __init__(self, name): if name.count("/") != 1: raise RuntimeError("invalid protocol template name: `%s'" % name) self.name, self.version = name.split("/") self.fullname = name - self.prefix = prefix path = utils.hashed_or_simple( - self.prefix, self.asset_folder, name, suffix=".json" + config.get_config()["prefix"], self.asset_folder, name, suffix=".json" ) path = path[:-5] super(Storage, self).__init__(path) # ---------------------------------------------------------- - - -class ProtocolTemplate(object): +class ProtocolTemplate(metaclass=config.PrefixMeta): """Protocol template define the design of the database. @@ -94,14 +92,6 @@ class ProtocolTemplate(object): name (str): The fully qualified protocol template name (e.g. ``db/1``) - dataformat_cache (:py:class:`dict`, Optional): A dictionary mapping - dataformat names to loaded dataformats. This parameter is optional and, - if passed, may greatly speed-up database loading times as dataformats - that are already loaded may be re-used. If you use this parameter, you - must guarantee that the cache is refreshed as appropriate in case the - underlying dataformats change. 
- - Attributes: name (str): The full, valid name of this database @@ -111,27 +101,83 @@ class ProtocolTemplate(object): """ - def __init__(self, prefix, name, dataformat_cache=None): + asset_type = PROTOCOL_TEMPLATE_TYPE + asset_folder = PROTOCOL_TEMPLATE_FOLDER + def _init(self): self._name = None - self.prefix = prefix self.dataformats = {} # preloaded dataformats self.storage = None - self.errors = [] self.data = None - # if the user has not provided a cache, still use one for performance - dataformat_cache = dataformat_cache if dataformat_cache is not None else {} + def __init__(self, name): + + self._init() + self._load(name) + + # cache in prefix + ProtocolTemplate[self.name] = self + + @classmethod + def new( + cls, sets, name, description=None, schema_version=None, + ): + self = cls.__new__(cls) + self._init() + + if not name: + raise ValueError(f"Invalid {name}. The name should be a non-empty string!") + + if "/" not in name: + name = f"{name}/1" + + self._name = name + + def dataformat_name(v): + if hasattr(v, "name"): + v = v.name + return v - self._load(name, dataformat_cache) + for i, set_ in enumerate(sets): + sets[i]["outputs"] = { + k: dataformat_name(v) for k, v in set_["outputs"].items() + } - def _load(self, data, dataformat_cache): + data = dict(sets=sets) + if description is not None: + data["description"] = description + if schema_version is not None: + data["schema_version"] = schema_version + + self.data = data + + self.storage = Storage(name) + + self._resolve() + + # cache in prefix + ProtocolTemplate[self.name] = self + + return self + + def _resolve(self): + for set_ in self.data["sets"]: + + for key, value in set_["outputs"].items(): + + if value in self.dataformats: + continue + + dataformat = DataFormat[value] + self.dataformats[value] = dataformat + + def _load(self, data): """Loads the protocol template""" self._name = data - self.storage = Storage(self.prefix, self._name) + self.storage = Storage(self._name) json_path = 
self.storage.json.path
         if not self.storage.json.exists():
             self.errors.append(
@@ -149,21 +195,7 @@ class ProtocolTemplate(object):
                 "Protocol template declaration file invalid: %s" % error
             )
             return
-
-        for set_ in self.data["sets"]:
-
-            for key, value in set_["outputs"].items():
-
-                if value in self.dataformats:
-                    continue
-
-                if value in dataformat_cache:
-                    dataformat = dataformat_cache[value]
-                else:
-                    dataformat = DataFormat(self.prefix, value)
-                    dataformat_cache[value] = dataformat
-
-                self.dataformats[value] = dataformat
+        self._resolve()
 
     @property
     def name(self):
@@ -174,7 +206,7 @@ class ProtocolTemplate(object):
     @name.setter
     def name(self, value):
         self._name = value
-        self.storage = Storage(self.prefix, value)
+        self.storage = Storage(value)
 
     @property
     def description(self):
@@ -285,7 +317,7 @@ class ProtocolTemplate(object):
 
         Raises:
 
-            RuntimeError: If prefix and self.prefix point to the same directory.
+            RuntimeError: If prefix and the configured prefix point to the same directory.
 
         """
 
@@ -295,7 +327,7 @@ class ProtocolTemplate(object):
         if not self.valid:
             raise RuntimeError("protocol template is not valid")
 
-        if prefix == self.prefix:
+        if prefix == config.get_config()["prefix"]:
             raise RuntimeError(
                 "Cannot export protocol template to the same prefix (" "%s)" % prefix
             )
diff --git a/beat/backend/python/test/test_interactive.py b/beat/backend/python/test/test_interactive.py
new file mode 100644
index 0000000..4f2a1da
--- /dev/null
+++ b/beat/backend/python/test/test_interactive.py
@@ -0,0 +1,124 @@
+import io
+import unittest
+
+import numpy as np
+
+from ..dataformat import DataFormat
+
+coordinates = DataFormat.new(
+    definition={"x": int, "y": int, "name": str},
+    name="coordinates",
+    description="coordinates in an image",
+)
+
+sizes = DataFormat.new(
+    definition={"width": int, "height": int, "array": [0, float]}, name="sizes"
+)
+rectangles = DataFormat.new(
+    definition={"coords": coordinates, "size": sizes}, name="rectangles"
+)
+
+
+class CoordinatesToSizes:
+    def process(self, inputs, 
data_loaders, outputs): + # converts coordinates to sizes + coord = inputs["coordinates"] + new_size = sizes.type().from_dict( + dict(width=coord.x, height=coord.y, array=[coord.x, coord.y]) + ) + outputs["sizes"].write(new_size) + return True + + +class DataFormatTest(unittest.TestCase): + def test_dataformat_creation(self): + def assert_hash(h, oracle): + self.assertEqual(h, oracle) + + # test hash functionality + assert_hash( + coordinates.hash(), + "196214412087efe517820a27b6442cfb9fc2a843e30bc33c38bc74db48ee2bf1", + ) + + def assert_data_roundtrip(data, data_format, oracle=None): + if oracle is None: + oracle = data + beat_type = data_format.type + beat_data = beat_type().from_dict(data) + obtained_data = beat_data.as_dict() + self.assertEqual(obtained_data, oracle) + # test validate + data_format.validate(data) + # validate from string + data_format.validate(beat_data.pack()) + # validate from file + with io.BytesIO() as fd: + beat_data.pack_into(fd) + fd.seek(0) + data_format.validate(fd) + + # test data creation + data = dict(x=1, y=2, name="test") + assert_data_roundtrip(data, coordinates) + + # test safe data conversion + data = dict(x=np.int32(1), y=np.int32(2), name="test") + assert_data_roundtrip(data, coordinates) + + # test unsafe data conversion + with self.assertRaises(TypeError): + coordinates.type().from_dict(dict(x=1.0, y=2.0, name="test")) + + # test extra data attributes + with self.assertRaises(AttributeError): + coordinates.type().from_dict(dict(x=1, y=2, name="test", extra="five")) + + # test missing data attributes + with self.assertRaises(AttributeError): + coordinates.type().from_dict(dict(x=1)) + + # test hierarchy + data = dict( + coords=dict(x=1, y=2, name="test"), size=dict(width=3, height=4, array=[1]) + ) + assert_data_roundtrip(data, rectangles) + + # test when data is already in BEAT format + data_ = dict( + coords=coordinates.type().from_dict(dict(x=1, y=2, name="test")), + size=sizes.type().from_dict( + dict(width=3, height=4, 
array=np.asarray([1], dtype="int32")) + ), + ) + assert_data_roundtrip(data_, rectangles, oracle=data) + + # test hash functionality + assert_hash( + sizes.hash(), + "afcc84a5ea2a6dcdcd531352691a6358d1bcbd6c28e4d41ce15328dd7e5f314e", + ) + + assert_hash( + rectangles.hash(), + "a09ebe6da8e4b812c7d1e0ba0a841e35c988280687960ca5a622ee142e690316", + ) + + # name cannot be invalid + with self.assertRaises(ValueError): + DataFormat.new({"value": int}, name=None) + + with self.assertRaises(ValueError): + DataFormat.new({"value": int}, name="") + + +# class AlgorithmTest(unittest.TestCase): +# """docstring for AlgorithmTest""" + +# def test_algorithm_creation(self): +# alg = CoordinatesToSizes() +# algorithm = Algorithm_( +# alg, +# {"main": {"inputs": ["coordinates"], "outputs": ["sizes"]}}, +# type="sequential", +# ) diff --git a/beat/backend/python/utils.py b/beat/backend/python/utils.py index 3c74478..4468a3f 100644 --- a/beat/backend/python/utils.py +++ b/beat/backend/python/utils.py @@ -125,30 +125,6 @@ def extension_for_language(language): # ---------------------------------------------------------- -class Prefix(object): - def __init__(self, paths=None): - if isinstance(paths, list): - self.paths = paths - elif paths is not None: - self.paths = [paths] - else: - self.paths = [] - - def add(self, path): - self.paths.append(path) - - def path(self, filename): - for p in self.paths: - fullpath = os.path.join(p, filename) - if os.path.exists(fullpath): - return fullpath - - return os.path.join(self.paths[0], filename) - - -# ---------------------------------------------------------- - - class File(object): """User helper to read and write file objects""" -- GitLab