Commit 51c47c13 authored by Tiago de Freitas Pereira's avatar Tiago de Freitas Pereira
Browse files

[black]

[black]
parent 6438fe68
......@@ -12,99 +12,139 @@ from bob.extension.download import download_and_unzip
# based on: http://stackoverflow.com/questions/6796492/temporarily-redirect-stdout-stderr
class Quiet(object):
    """A class that supports the ``with`` statement to redirect any output of wrapped function calls to /dev/null

    While the ``with`` block is active, ``sys.stdout`` and ``sys.stderr`` both
    point at ``os.devnull``.  The previous streams are put back when the block
    is left, also when it is left through an exception.

    The sink is opened on entry and closed on exit, so an idle instance holds
    no open file and the same instance can be used for several ``with`` blocks.
    """

    def __init__(self):
        # Nothing is opened here; the sink only exists inside a ``with`` block.
        self._stdout = None
        self._stderr = None

    def __enter__(self):
        self.old_stdout, self.old_stderr = sys.stdout, sys.stderr
        # Write out whatever is still buffered before the streams are swapped.
        self.old_stdout.flush()
        self.old_stderr.flush()
        # One handle serves both streams.
        devnull = open(os.devnull, "w")
        self._stdout = devnull
        self._stderr = devnull
        sys.stdout, sys.stderr = self._stdout, self._stderr

    def __exit__(self, exc_type, exc_value, traceback):
        try:
            self._stdout.flush()
        finally:
            # Restore first, then release the sink, so the real streams come
            # back even if flushing the sink failed.
            sys.stdout = self.old_stdout
            sys.stderr = self.old_stderr
            self._stdout.close()
            self._stdout = None
            self._stderr = None
import logging
logger = logging.getLogger("bob.bio.base")
def random_array(shape, minimum=0, maximum=1, seed=42):
    """Returns a reproducible random array of the given shape.

    numpy's global random generator is re-seeded with ``seed``; the samples it
    then produces are rescaled by ``maximum - minimum`` and shifted by
    ``minimum``.
    """
    numpy.random.seed(seed)
    samples = numpy.random.random(shape)
    span = maximum - minimum
    return samples * span + minimum
def random_training_set(shape, count, minimum=0, maximum=1, seed=42):
    """Returns a random training set with the given shape and the given number of elements.

    numpy's global random generator is re-seeded with ``seed`` once, then
    ``count`` arrays of ``shape`` are drawn one after the other, each rescaled
    by ``maximum - minimum`` and shifted by ``minimum``.
    """
    numpy.random.seed(seed)
    span = maximum - minimum
    features = []
    for _ in range(count):
        features.append(numpy.random.random(shape) * span + minimum)
    return features
def random_training_set_by_id(shape, count=50, minimum=0, maximum=1, seed=42):
    """Returns a ``count`` x ``count`` nested list of random arrays.

    numpy's global random generator is re-seeded with ``seed`` once.  The
    outer list has ``count`` entries and each entry is a list of ``count``
    arrays of ``shape``, rescaled by ``maximum - minimum`` and shifted by
    ``minimum``.
    """
    numpy.random.seed(seed)
    span = maximum - minimum

    def draw():
        # one feature array; consumes the global generator
        return numpy.random.random(shape) * span + minimum

    # Rows are filled one after the other, which keeps the draw order.
    return [[draw() for _ in range(count)] for _ in range(count)]
def grid_available(test):
    """Decorator to check if the gridtk is present, before running the test

    The decorated test raises ``SkipTest`` when ``gridtk`` cannot be imported;
    otherwise it is called with its original arguments and its result is
    returned.
    """

    @functools.wraps(test)
    def wrapper(*args, **kwargs):
        # Only the availability probe is guarded: an ImportError raised by the
        # test itself is a real failure and must not be turned into a skip.
        try:
            import gridtk  # noqa: F401 -- imported only to probe availability
        except ImportError as e:
            raise SkipTest("Skipping test since gridtk is not available: %s" % e)
        return test(*args, **kwargs)

    return wrapper
def db_available(dbname):
'''Decorator that checks if a given bob.db database is available.
This is a double-indirect decorator, see http://thecodeship.com/patterns/guide-to-python-function-decorators'''
def wrapped_function(test):
@functools.wraps(test)
def wrapper(*args, **kwargs):
try:
__import__('bob.db.%s' % dbname)
return test(*args, **kwargs)
except ImportError as e:
raise SkipTest("Skipping test since the database bob.db.%s seems not to be available: %s" % (dbname,e))
try:
import gridtk
return test(*args, **kwargs)
except ImportError as e:
raise SkipTest("Skipping test since gridtk is not available: %s" % e)
return wrapper
return wrapped_function
atnt_default_directory = os.environ['ATNT_DATABASE_DIRECTORY'] if 'ATNT_DATABASE_DIRECTORY' in os.environ else "/idiap/group/biometric/databases/orl/"
def db_available(dbname):
    """Decorator that checks if a given bob.db database is available.

    This is a double-indirect decorator, see http://thecodeship.com/patterns/guide-to-python-function-decorators

    ``dbname`` is the module name below ``bob.db``.  The decorated test raises
    ``SkipTest`` when ``bob.db.<dbname>`` cannot be imported; otherwise it is
    called with its original arguments and its result is returned.
    """

    def wrapped_function(test):
        @functools.wraps(test)
        def wrapper(*args, **kwargs):
            # Only the availability probe is guarded: an ImportError raised by
            # the test itself is a real failure and must not become a skip.
            try:
                __import__("bob.db.%s" % dbname)
            except ImportError as e:
                raise SkipTest(
                    "Skipping test since the database bob.db.%s seems not to be available: %s"
                    % (dbname, e)
                )
            return test(*args, **kwargs)

        return wrapper

    return wrapped_function
# Where the AT&T database is looked up first: the ATNT_DATABASE_DIRECTORY
# environment variable when it is set, the fixed path otherwise.
atnt_default_directory = os.environ.get(
    "ATNT_DATABASE_DIRECTORY", "/idiap/group/biometric/databases/orl/"
)

# Module-level cache for a downloaded copy of the database; None means that
# nothing has been downloaded.
atnt_downloaded_directory = None
def atnt_database_directory():
    """Returns the directory of the AT&T database, fetching the data if necessary.

    The lookup order is:

    1. the directory cached in ``atnt_downloaded_directory`` by an earlier call,
    2. ``atnt_default_directory``, if that path exists,
    3. a new temporary directory (prefix ``atnt_db_``), which is passed to
       ``download_and_unzip`` together with the list of source URLs.

    In the third case the ``ATNT_DATABASE_DIRECTORY`` environment variable is
    set to the temporary directory and two warnings are logged.
    """
    global atnt_downloaded_directory

    if atnt_downloaded_directory:
        return atnt_downloaded_directory

    if os.path.exists(atnt_default_directory):
        return atnt_default_directory

    # TODO: THIS SHOULD BE A CLASS METHOD OF bob.db.atnt database
    source_url = [
        "http://bobconda.lab.idiap.ch/public/data/bob/att_faces.zip",
        "http://www.idiap.ch/software/bob/data/bob/att_faces.zip",
    ]

    import tempfile

    atnt_downloaded_directory = tempfile.mkdtemp(prefix="atnt_db_")
    # ``Logger.warn`` is a deprecated alias; ``warning`` is the supported name.
    logger.warning(
        "Downloading the AT&T database from '%s' to '%s' ...",
        source_url,
        atnt_downloaded_directory,
    )
    logger.warning(
        "To avoid this, please download the database manually, extract the data and set the ATNT_DATABASE_DIRECTORY environment variable to this directory."
    )

    # to avoid re-downloading in parallel test execution
    os.environ["ATNT_DATABASE_DIRECTORY"] = atnt_downloaded_directory

    # mkdtemp has already created the directory; the check is kept as a guard.
    if not os.path.exists(atnt_downloaded_directory):
        os.mkdir(atnt_downloaded_directory)
    download_and_unzip(
        source_url, os.path.join(atnt_downloaded_directory, "att_faces.zip")
    )

    return atnt_downloaded_directory
......@@ -9,22 +9,34 @@ import os
import pkg_resources
import bob.extension.config
import sys
if sys.version_info[0] == 2:
from string import letters as ascii_letters
from string import letters as ascii_letters
else:
from string import ascii_letters
from string import ascii_letters
import six
import functools
import logging
logger = logging.getLogger("bob.bio.base")
#: Keywords for which resources are defined.
valid_keywords = ('database', 'preprocessor', 'extractor', 'algorithm', 'grid', 'config', 'annotator', 'baseline', 'pipeline')
valid_keywords = (
"database",
"preprocessor",
"extractor",
"algorithm",
"grid",
"config",
"annotator",
"baseline",
"pipeline",
)
def _collect_config(paths):
'''Collect all python file resources into a module
"""Collect all python file resources into a module
This function recursively loads python modules (in a Python 3-compatible way)
so the last loaded module corresponds to the final state of the loading. In
......@@ -43,13 +55,13 @@ def _collect_config(paths):
A valid Python module you can use to configure your tool
'''
"""
return bob.extension.config.load(paths, entry_point_group="bob.bio.config")
return bob.extension.config.load(paths, entry_point_group="bob.bio.config")
def read_config_file(filenames, keyword = None):
"""read_config_file(filenames, keyword = None) -> config
def read_config_file(filenames, keyword=None):
"""read_config_file(filenames, keyword = None) -> config
Use this function to read the given configuration file.
If a keyword is specified, only the configuration according to this keyword is returned.
......@@ -72,29 +84,42 @@ def read_config_file(filenames, keyword = None):
Otherwise, the whole configuration is returned (as a local namespace).
"""
if not filenames:
raise RuntimeError("At least one configuration file, resource or " \
"module name must be passed")
if not filenames:
raise RuntimeError(
"At least one configuration file, resource or " "module name must be passed"
)
config = _collect_config(filenames)
config = _collect_config(filenames)
if not keyword:
return config
if not keyword:
return config
if not hasattr(config, keyword):
raise ImportError("The desired keyword '%s' does not exist in any of " \
"your configuration files: %s" %(keyword, ', '.join(filenames)))
if not hasattr(config, keyword):
raise ImportError(
"The desired keyword '%s' does not exist in any of "
"your configuration files: %s" % (keyword, ", ".join(filenames))
)
return getattr(config, keyword)
return getattr(config, keyword)
def _get_entry_points(keyword, strip=[], package_prefix="bob.bio."):
    """Returns the list of entry points for registered resources with the given keyword.

    The entry point group that is searched is ``package_prefix + keyword``.
    Entry points whose name starts with one of the prefixes in ``strip`` are
    left out.
    """
    group = package_prefix + keyword
    selected = []
    for entry_point in pkg_resources.iter_entry_points(group):
        if entry_point.name.startswith(tuple(strip)):
            continue
        selected.append(entry_point)
    return selected
def load_resource(resource, keyword, imports = ['bob.bio.base'], package_prefix='bob.bio.', preferred_package=None):
"""load_resource(resource, keyword, imports = ['bob.bio.base'], package_prefix='bob.bio.', preferred_package = None) -> resource
def load_resource(
resource,
keyword,
imports=["bob.bio.base"],
package_prefix="bob.bio.",
preferred_package=None,
):
"""load_resource(resource, keyword, imports = ['bob.bio.base'], package_prefix='bob.bio.', preferred_package = None) -> resource
Loads the given resource that is registered with the given keyword.
The resource can be:
......@@ -127,62 +152,81 @@ def load_resource(resource, keyword, imports = ['bob.bio.base'], package_prefix=
The resulting resource object is returned, either read from file or resource, or created newly.
"""
# first, look if the resource is a file name
if os.path.isfile(resource):
return read_config_file([resource], keyword)
if keyword not in valid_keywords:
logger.warning("The given keyword '%s' is not valid. Please use one of %s!", keyword, valid_keywords)
# first, look if the resource is a file name
if os.path.isfile(resource):
return read_config_file([resource], keyword)
# now, we check if the resource is registered as an entry point in the resource files
entry_points = [entry_point for entry_point in _get_entry_points(keyword, package_prefix=package_prefix) if entry_point.name == resource]
if len(entry_points):
if len(entry_points) == 1:
return entry_points[0].load()
else:
# TODO: extract current package name and use this one, if possible
# Now: check if there are only two entry points, and one is from the bob.bio.base, then use the other one
index = -1
if preferred_package is not None:
for i,p in enumerate(entry_points):
if p.dist.project_name == preferred_package:
index = i
break
if index == -1:
# by default, use the first one that is not from bob.bio
for i,p in enumerate(entry_points):
if not p.dist.project_name.startswith(package_prefix):
index = i
break
if index != -1:
logger.debug("RESOURCES: Using the resource '%s' from '%s', and ignoring the one from '%s'", resource, entry_points[index].module_name, entry_points[1-index].module_name)
return entry_points[index].load()
else:
logger.warn("Under the desired name '%s', there are multiple entry points defined, we return the first one: %s", resource, [entry_point.module_name for entry_point in entry_points])
return entry_points[0].load()
# if the resource is neither a config file nor an entry point,
# just execute it as a command
try:
# first, execute all import commands that are required
for i in imports:
exec ("import %s"%i)
# now, evaluate the resource (re-evaluate if the resource is still a string)
while isinstance(resource, six.string_types):
resource = eval(resource)
return resource
if keyword not in valid_keywords:
logger.warning(
"The given keyword '%s' is not valid. Please use one of %s!",
keyword,
valid_keywords,
)
except Exception as e:
raise ImportError("The given command line option '%s' is neither a resource for a '%s', nor an existing configuration file, nor could be interpreted as a command (error: %s)"%(resource, keyword, str(e)))
# now, we check if the resource is registered as an entry point in the resource files
entry_points = [
entry_point
for entry_point in _get_entry_points(keyword, package_prefix=package_prefix)
if entry_point.name == resource
]
if len(entry_points):
if len(entry_points) == 1:
return entry_points[0].load()
else:
# TODO: extract current package name and use this one, if possible
# Now: check if there are only two entry points, and one is from the bob.bio.base, then use the other one
index = -1
if preferred_package is not None:
for i, p in enumerate(entry_points):
if p.dist.project_name == preferred_package:
index = i
break
if index == -1:
# by default, use the first one that is not from bob.bio
for i, p in enumerate(entry_points):
if not p.dist.project_name.startswith(package_prefix):
index = i
break
if index != -1:
logger.debug(
"RESOURCES: Using the resource '%s' from '%s', and ignoring the one from '%s'",
resource,
entry_points[index].module_name,
entry_points[1 - index].module_name,
)
return entry_points[index].load()
else:
logger.warn(
"Under the desired name '%s', there are multiple entry points defined, we return the first one: %s",
resource,
[entry_point.module_name for entry_point in entry_points],
)
return entry_points[0].load()
# if the resource is neither a config file nor an entry point,
# just execute it as a command
try:
# first, execute all import commands that are required
for i in imports:
exec("import %s" % i)
# now, evaluate the resource (re-evaluate if the resource is still a string)
while isinstance(resource, six.string_types):
resource = eval(resource)
return resource
except Exception as e:
raise ImportError(
"The given command line option '%s' is neither a resource for a '%s', nor an existing configuration file, nor could be interpreted as a command (error: %s)"
% (resource, keyword, str(e))
)
def extensions(keywords=valid_keywords, package_prefix='bob.bio.'):
"""extensions(keywords=valid_keywords, package_prefix='bob.bio.') -> extensions
def extensions(keywords=valid_keywords, package_prefix="bob.bio."):
"""extensions(keywords=valid_keywords, package_prefix='bob.bio.') -> extensions
Returns a list of packages that define extensions using the given keywords.
......@@ -195,63 +239,100 @@ def extensions(keywords=valid_keywords, package_prefix='bob.bio.'):
package_prefix : str
Package namespace, in which we search for entry points, e.g., ``bob.bio``.
"""
entry_points = [entry_point for keyword in keywords for entry_point in _get_entry_points(keyword, package_prefix=package_prefix)]
return sorted(list(set(entry_point.dist.project_name for entry_point in entry_points)))
def resource_keys(keyword, exclude_packages=[], package_prefix='bob.bio.', strip=['dummy']):
"""Reads and returns all resources that are registered with the given keyword.
entry_points = [
entry_point
for keyword in keywords
for entry_point in _get_entry_points(keyword, package_prefix=package_prefix)
]
return sorted(
list(set(entry_point.dist.project_name for entry_point in entry_points))
)
def resource_keys(
    keyword, exclude_packages=[], package_prefix="bob.bio.", strip=["dummy"]
):
    """Reads and returns all resources that are registered with the given keyword.
    Entry points from the given ``exclude_packages`` are ignored.

    The result is the sorted list of entry point names."""
    names = []
    for entry_point in _get_entry_points(
        keyword, strip=strip, package_prefix=package_prefix
    ):
        if entry_point.dist.project_name in exclude_packages:
            continue
        names.append(entry_point.name)
    names.sort()
    return names
def list_resources(
    keyword, strip=["dummy"], package_prefix="bob.bio.", verbose=False, packages=None
):
    """Returns a string containing a detailed list of resources that are registered with the given keyword.

    Entry points are grouped by distribution and sorted by project name and
    entry point name.  When ``packages`` is given, only entry points of those
    projects are listed.  With ``verbose``, each entry point is loaded and its
    string representation is appended.

    Raises ``ValueError`` if ``keyword`` is not one of ``valid_keywords``.
    """
    if keyword not in valid_keywords:
        raise ValueError(
            "The given keyword '%s' is not valid. Please use one of %s!"
            % (str(keyword), str(valid_keywords))
        )

    candidates = _get_entry_points(keyword, strip, package_prefix=package_prefix)

    # The column width is taken before the ``packages`` filter is applied.
    width = max(len(candidate.name) for candidate in candidates) if candidates else 1

    if packages is not None:
        candidates = [
            candidate
            for candidate in candidates
            if candidate.dist.project_name in packages
        ]
    ordered = sorted(candidates, key=lambda p: (p.dist.project_name, p.name))

    chunks = []
    current_dist = None
    for entry_point in ordered:
        dist_label = str(entry_point.dist)
        if current_dist != dist_label:
            # first entry point of a distribution: emit the header line
            chunks.append(
                "\n- %s @ %s: \n" % (dist_label, str(entry_point.dist.location))
            )
            current_dist = dist_label

        padded_name = entry_point.name.ljust(width)
        if len(entry_point.attrs):
            chunks.append(
                " + %s --> %s: %s\n"
                % (padded_name, entry_point.module_name, entry_point.attrs[0])
            )
        else:
            chunks.append(" + %s --> %s\n" % (padded_name, entry_point.module_name))

        if verbose:
            chunks.append(" ==> " + str(entry_point.load()) + "\n\n")

    return "".join(chunks)
def database_directories(strip=['dummy'], replacements = None, package_prefix='bob.bio.'):
"""Returns a dictionary of original directories for all registered databases."""
entry_points = _get_entry_points('database', strip, package_prefix=package_prefix)
dirs = {}
for entry_point in sorted(entry_points, key=lambda entry_point: entry_point.name):
try:
db = load_resource(entry_point.name, 'database')
db.replace_directories(replacements)
dirs[entry_point.name] = [db.original_directory]
if db.annotation_directory is not None:
dirs[entry_point.name].append(db.annotation_directory)
except (AttributeError, ValueError, ImportError):
pass
ret_list = [
entry_point.name
for entry_point in _get_entry_points(
keyword, strip=strip, package_prefix=package_prefix
)
if entry_point.dist.project_name not in exclude_packages
]
return sorted(ret_list)
def list_resources(
keyword, strip=["dummy"], package_prefix="bob.bio.", verbose=False, packages=None
):
"""Returns a string containing a detailed list of resources that are registered with the given keyword."""
if keyword not in valid_keywords:
raise ValueError(
"The given keyword '%s' is not valid. Please use one of %s!"
% (str(keyword), str(valid_keywords))
)
return dirs
entry_points = _get_entry_points(keyword, strip, package_prefix=package_prefix)
last_dist = None
retval = ""
length = (
max(len(entry_point.name) for entry_point in entry_points)
if entry_points
else 1
)
if packages is not None:
entry_points = [
entry_point
for entry_point in entry_points
if entry_point.dist.project_name in packages
]
for entry_point in sorted(
entry_points, key=lambda p: (p.dist.project_name, p.name)
):
if last_dist != str(entry_point.dist):
retval += "\n- %s @ %s: \n" % (
str(entry_point.dist),
str(entry_point.dist.location),
)
last_dist = str(entry_point.dist)
if len(entry_point.attrs):
retval += " + %s --> %s: %s\n" % (
entry_point.name + " " * (length - len(entry_point.name)),
entry_point.module_name,
entry_point.attrs[0],
)
else:
retval += " + %s --> %s\n" % (
entry_point.name + " " * (length - len(entry_point.name)),
entry_point.module_name,
)
if verbose:
retval += " ==> " + str(entry_point.load()) + "\n\n"
return retval
def database_directories(strip=["dummy"], replacements=None, package_prefix="bob.bio."):
"""Returns a dictionary of original directories for all registered databases."""
entry_points = _get_entry_points("database", strip, package_prefix=package_prefix)
dirs = {}
for entry_point in sorted(entry_points, key=lambda entry_point: entry_point.name):
try:
db = load_resource(entry_point.name, "database")
db.replace_directories(replacements)
dirs[entry_point.name] = [db.original_directory]