diff --git a/bob/extension/config.py b/bob/extension/config.py index 0587336bfc7a292a9cc9a8e95e291fbdf4917433..e0c356ceb8d9dd5e3b18bf76f4e3447dd6321bb8 100644 --- a/bob/extension/config.py +++ b/bob/extension/config.py @@ -68,7 +68,15 @@ def _get_module_filename(module_name): return loader.filename -def _resolve_entry_point_or_modules(paths, entry_point_group): +def _object_name(path, common_name): + path = path.rsplit(':', 1) + name = path[1] if len(path) > 1 else common_name + path = path[0] + return path, name + + +def _resolve_entry_point_or_modules(paths, entry_point_group, + common_name=None): """Resolves a mixture of paths, entry point names, and module names to just paths. For example paths can be: ``paths = ['/tmp/config.py', 'config1', 'bob.extension.config2']``. @@ -80,6 +88,9 @@ def _resolve_entry_point_or_modules(paths, entry_point_group): names, or are module names. entry_point_group : str The entry point group name to search in entry points. + common_name : None or str + It will be used as a default name for object names. See the + common_keyword parameter from :any:`load`. Raises ------ @@ -90,20 +101,24 @@ def _resolve_entry_point_or_modules(paths, entry_point_group): ------- paths : [str] The resolved paths pointing to existing files. - names : [str] + module_names : [str] The valid python module names to bind each of the files to - + object_names : [str] + The name of objects that are supposed to be picked from paths. """ entries = {e.name: e for e in pkg_resources.iter_entry_points(entry_point_group)} + files = [] - names = [] + module_names = [] + object_names = [] for i, path in enumerate(paths): old_path = path module_name = 'user_config' # fixed module name for files with full paths + path, object_name = _object_name(path, common_name) # if it already points to a file if isfile(path): @@ -111,7 +126,9 @@ def _resolve_entry_point_or_modules(paths, entry_point_group): # If it is an entry point name, collect path and module name elif path in entries: - module_name = entries[path].module_name + entry = entries[path] + module_name = entry.module_name + object_name = entry.attrs[0] if entry.attrs else common_name path = _get_module_filename(module_name) if not isfile(path): raise ValueError( @@ -131,12 +148,13 @@ def _resolve_entry_point_or_modules(paths, entry_point_group): old_path, path, entry_point_group or '')) files.append(path) - names.append(module_name) + module_names.append(module_name) + object_names.append(object_name) - return files, names + return files, module_names, object_names -def load(paths, context=None, entry_point_group=None): +def load(paths, context=None, entry_point_group=None, common_keyword=None): '''Loads a set of configuration files, in sequence This method will load one or more configuration files. Every time a @@ -157,17 +175,37 @@ def load(paths, context=None, entry_point_group=None): entry_point_group : :py:class:`str`, optional If provided, it will treat non-existing file paths as entry point names under the ``entry_point_group`` name. + common_keyword : None or str + If provided, will look for the common_keyword variable inside the loaded + files. Paths ending with `some_path:variable_name` can override the + common_keyword. The entry_point_group must provided as well + common_keyword is not None. Returns ------- - mod : :any:`module` + mod : :any:`module` or object A module representing the resolved context, after loading the provided - modules and resolving all variables. + modules and resolving all variables. If common_keyword is given, the + object with the common_keyword name (or the name provided by user) is + returned instead of the module. + + Raises + ------ + ImportError + If common_keyword is given but the object does not exist in the paths. + ValueError + If common_keyword is given but entry_point_group is not given. ''' + if common_keyword and not entry_point_group: + raise ValueError( + "entry_point_group must be provided when using the " + "common_keyword parameter.") + # resolve entry points to paths if entry_point_group is not None: - paths, names = _resolve_entry_point_or_modules(paths, entry_point_group) + paths, names, object_names = _resolve_entry_point_or_modules( + paths, entry_point_group, common_keyword) else: names = len(paths) * ['user_config'] @@ -192,7 +230,18 @@ def load(paths, context=None, entry_point_group=None): LOADED_CONFIGS.append(mod) ctxt = _load_context(k, mod) - return mod + if not common_keyword: + return mod + + # We pick the last object_name here. Normally users should provide just one + # path when enabling the common_keyword parameter. + common_keyword = object_names[-1] + if not hasattr(mod, common_keyword): + raise ImportError( + "The desired variable '%s' does not exist in any of " + "your configuration files: %s" % (common_keyword, ', '.join(paths))) + + return getattr(mod, common_keyword) def mod_to_context(mod): diff --git a/bob/extension/data/resource_config2.py b/bob/extension/data/resource_config2.py new file mode 100644 index 0000000000000000000000000000000000000000..ca505fd6eea4b47b59ed0f381bc99a73f8f6d9d3 --- /dev/null +++ b/bob/extension/data/resource_config2.py @@ -0,0 +1,2 @@ +test_config_load = 1 +b = 2 diff --git a/bob/extension/data/test_dump_config2.py b/bob/extension/data/test_dump_config2.py index 06d28fe6d19f4571744e1578b965062ac77e33e6..67be3154990347a5ee65d375571ace8f59d714dd 100644 --- a/bob/extension/data/test_dump_config2.py +++ b/bob/extension/data/test_dump_config2.py @@ -14,12 +14,12 @@ yyy : callable # database = None '''Required parameter: database (--database, -d) -bla bla bla Can be a ``bob.extension.test_config_load`` entry point, a module name, or a path to a Python file which contains a variable named `database`. +bla bla bla Can be a ``bob.extension.test_dump_config`` entry point, a module name, or a path to a Python file which contains a variable named `database`. Registered entries are: ['basic_config', 'resource_config', 'subpackage_config']''' # annotator = None '''Required parameter: annotator (--annotator, -a) -bli bli bli Can be a ``bob.extension.test_config_load`` entry point, a module name, or a path to a Python file which contains a variable named `annotator`. +bli bli bli Can be a ``bob.extension.test_dump_config`` entry point, a module name, or a path to a Python file which contains a variable named `annotator`. Registered entries are: ['basic_config', 'resource_config', 'subpackage_config']''' # output_dir = None diff --git a/bob/extension/scripts/click_helper.py b/bob/extension/scripts/click_helper.py index 2b2aa84613019a8e33abca7f9c8a9bbcc1610ade..ee0fac5c477c92a6da350fc8272470744cada5c0 100644 --- a/bob/extension/scripts/click_helper.py +++ b/bob/extension/scripts/click_helper.py @@ -339,10 +339,10 @@ class ResourceOption(click.Option): self).full_process_value(ctx, value) if self.entry_point_group is not None: - keyword = self.entry_point_group.split('.')[-1] + common_keyword = self.entry_point_group.split('.')[-1] while isinstance(value, basestring): - value = load([value], entry_point_group=self.entry_point_group) - value = getattr(value, keyword) + value = load([value], entry_point_group=self.entry_point_group, + common_keyword=common_keyword) return value diff --git a/bob/extension/test_click_helper.py b/bob/extension/test_click_helper.py index 540dff9de7c4a509b150a79c95260ed9beaf7f63..41e10f40d58f8fa341e525aeb3847eef77e25550 100644 --- a/bob/extension/test_click_helper.py +++ b/bob/extension/test_click_helper.py @@ -216,11 +216,11 @@ def test_config_dump2(): def cli(): pass - @cli.command(cls=ConfigCommand, entry_point_group='bob.extension.test_config_load') + @cli.command(cls=ConfigCommand, entry_point_group='bob.extension.test_dump_config') @click.option('--database', '-d', required=True, cls=ResourceOption, - entry_point_group='bob.extension.test_config_load', help="bla bla bla") + entry_point_group='bob.extension.test_dump_config', help="bla bla bla") @click.option('--annotator', '-a', required=True, cls=ResourceOption, - entry_point_group='bob.extension.test_config_load', help="bli bli bli") + entry_point_group='bob.extension.test_dump_config', help="bli bli bli") @click.option('--output-dir', '-o', required=True, cls=ResourceOption, help="blo blo blo") @click.option('--force', '-f', is_flag=True, cls=ResourceOption, diff --git a/bob/extension/test_config.py b/bob/extension/test_config.py index ebaaeb80914e21647063ae8de293bf9225f66d9a..13691a222a01961627863103a9c80be5119da931 100644 --- a/bob/extension/test_config.py +++ b/bob/extension/test_config.py @@ -58,3 +58,27 @@ def test_entry_point_configs(): assert hasattr(c, "a") and c.a == 1 assert hasattr(c, "b") and c.b == 3 assert hasattr(c, "rc") + + +def test_load_resource(): + for p, ref in [ + (os.path.join(path, 'resource_config2.py'), 1), + (os.path.join(path, 'resource_config2.py:test_config_load'), 1), + (os.path.join(path, 'resource_config2.py:b'), 2), + ('resource1', 1), + ('resource2', 2), + ('bob.extension.data.resource_config2', 1), + ('bob.extension.data.resource_config2:test_config_load', 1), + ('bob.extension.data.resource_config2:b', 2), + ]: + c = load([p], entry_point_group='bob.extension.test_config_load', + common_keyword='test_config_load') + assert c == ref, c + + try: + load(['bob.extension.data.resource_config2:c'], + entry_point_group='bob.extension.test_config_load', + common_keyword='test_config_load') + assert False, 'The code above should have raised an ImportError' + except ImportError: + pass diff --git a/doc/framework.rst b/doc/framework.rst index 715e76432f07d6b0caa1b2e4386606791634b76f..bf8f45cc7232bbf7343b6805ed496a534ed20f42 100644 --- a/doc/framework.rst +++ b/doc/framework.rst @@ -123,6 +123,32 @@ to provide the group name of the entry points: b = 6 +Resource Loading +================ + +The function :py:func:`bob.extension.config.load` can also only return +variables from paths. To do this, you need provide a common_keyword. For +example, given the following config file: + +.. literalinclude:: ../bob/extension/data/resource_config2.py + :caption: "resource_config2.py" with two variables inside + :language: python + :linenos: + +The loaded value can be either 1 or 2: + +.. doctest:: load_resource + + >>> group = 'bob.extension.test_config_load' # the group name of entry points + >>> common_keyword = 'test_config_load' # the common variable name + >>> value = load(['bob.extension.data.resource_config2'], entry_point_group=group, common_keyword=common_keyword) + >>> value == 1 + True + >>> value = load(['bob.extension.data.resource_config2:b'], entry_point_group=group, common_keyword=common_keyword) + >>> value == 2 + True + + .. _bob.extension.processors: Stacked Processing diff --git a/setup.py b/setup.py index aed3695c4121f840c9b24a0bc8127477c6e9a9b8..ac552cc47513eb2ad9b4ede68240a21c75601e50 100644 --- a/setup.py +++ b/setup.py @@ -37,6 +37,13 @@ setup( 'basic_config = bob.extension.data.basic_config', 'resource_config = bob.extension.data.resource_config', 'subpackage_config = bob.extension.data.subpackage.config', + 'resource1 = bob.extension.data.resource_config2', + 'resource2 = bob.extension.data.resource_config2:b', + ], + 'bob.extension.test_dump_config': [ + 'basic_config = bob.extension.data.basic_config', + 'resource_config = bob.extension.data.resource_config', + 'subpackage_config = bob.extension.data.subpackage.config', ], }, classifiers=[