Commit 8c1c935f authored by Samuel GAIST

[databases] pre-commit cleanup

parent 324a3801
@@ -59,8 +59,8 @@ from .click_helper import AliasedGroup

 logger = logging.getLogger(__name__)

-CMD_DB_INDEX = 'index'
-CMD_VIEW_OUTPUTS = 'databases_provider'
+CMD_DB_INDEX = "index"
+CMD_VIEW_OUTPUTS = "databases_provider"


 # ----------------------------------------------------------
@@ -68,7 +68,7 @@ CMD_VIEW_OUTPUTS = 'databases_provider'
 def load_database_sets(configuration, database_name):

     # Process the name of the database
-    parts = database_name.split('/')
+    parts = database_name.split("/")

     if len(parts) == 2:
         db_name = os.path.join(*parts[:2])
@@ -86,19 +86,21 @@ def load_database_sets(configuration, database_name):
         set_filter = parts[3]
     else:
-        logger.error("Database specification should have the format "
-                     "`<database>/<version>/[<protocol>/[<set>]]', the value "
-                     "you passed (%s) is not valid", database_name)
+        logger.error(
+            "Database specification should have the format "
+            "`<database>/<version>/[<protocol>/[<set>]]', the value "
+            "you passed (%s) is not valid",
+            database_name,
+        )
         return (None, None)

     # Load the dataformat
     dataformat_cache = {}
-    database = Database(configuration.path,
-                        db_name, dataformat_cache)
+    database = Database(configuration.path, db_name, dataformat_cache)

     if not database.valid:
         logger.error("Failed to load the database `%s':", db_name)
         for e in database.errors:
-            logger.error(' * %s', e)
+            logger.error(" * %s", e)
         return (None, None, None)

     # Filter the protocols
@@ -106,9 +108,13 @@ def load_database_sets(configuration, database_name):
     if protocol_filter is not None:
         if protocol_filter not in protocols:
-            logger.error("The database `%s' does not have the protocol `%s' - "
-                         "choose one of `%s'", db_name, protocol_filter,
-                         ', '.join(protocols))
+            logger.error(
+                "The database `%s' does not have the protocol `%s' - "
+                "choose one of `%s'",
+                db_name,
+                protocol_filter,
+                ", ".join(protocols),
+            )
             return (None, None, None)
@@ -122,17 +128,24 @@ def load_database_sets(configuration, database_name):
         if set_filter is not None:
             if set_filter not in sets:
-                logger.error("The database/protocol `%s/%s' does not have the "
-                             "set `%s' - choose one of `%s'",
-                             db_name, protocol_name, set_filter,
-                             ', '.join(sets))
+                logger.error(
+                    "The database/protocol `%s/%s' does not have the "
+                    "set `%s' - choose one of `%s'",
+                    db_name,
+                    protocol_name,
+                    set_filter,
+                    ", ".join(sets),
+                )
                 return (None, None, None)

             sets = [z for z in sets if z == set_filter]

-        loaded_sets.extend([(protocol_name, set_name,
-                             database.set(protocol_name, set_name))
-                            for set_name in sets])
+        loaded_sets.extend(
+            [
+                (protocol_name, set_name, database.set(protocol_name, set_name))
+                for set_name in sets
+            ]
+        )

     return (db_name, database, loaded_sets)
@@ -140,84 +153,92 @@ def load_database_sets(configuration, database_name):
 # ----------------------------------------------------------


-def start_db_container(configuration, cmd, host,
-                       db_name, protocol_name, set_name, database, db_set,
-                       excluded_outputs=None, uid=None, db_root=None):
+def start_db_container(
+    configuration,
+    cmd,
+    host,
+    db_name,
+    protocol_name,
+    set_name,
+    database,
+    db_set,
+    excluded_outputs=None,
+    uid=None,
+    db_root=None,
+):

     input_list = inputs.InputList()

     input_group = inputs.InputGroup(set_name, restricted_access=False)
     input_list.add(input_group)

-    db_configuration = {
-        'inputs': {},
-        'channel': set_name,
-    }
+    db_configuration = {"inputs": {}, "channel": set_name}

     if uid is not None:
-        db_configuration['datasets_uid'] = uid
+        db_configuration["datasets_uid"] = uid

     if db_root is not None:
-        db_configuration['datasets_root_path'] = db_root
+        db_configuration["datasets_root_path"] = db_root

-    for output_name, dataformat_name in db_set['outputs'].items():
+    for output_name, dataformat_name in db_set["outputs"].items():
         if excluded_outputs is not None and output_name in excluded_outputs:
             continue

         dataset_hash = hashDataset(db_name, protocol_name, set_name)
-        db_configuration['inputs'][output_name] = dict(
+        db_configuration["inputs"][output_name] = dict(
             database=db_name,
             protocol=protocol_name,
             set=set_name,
             output=output_name,
             channel=set_name,
             hash=dataset_hash,
-            path=toPath(dataset_hash, '.db')
+            path=toPath(dataset_hash, ".db"),
         )

     db_tempdir = utils.temporary_directory()

-    with open(os.path.join(db_tempdir, 'configuration.json'), 'wt') as f:
+    with open(os.path.join(db_tempdir, "configuration.json"), "wt") as f:
         simplejson.dump(db_configuration, f, indent=4)

-    tmp_prefix = os.path.join(db_tempdir, 'prefix')
+    tmp_prefix = os.path.join(db_tempdir, "prefix")
     if not os.path.exists(tmp_prefix):
         os.makedirs(tmp_prefix)

     database.export(tmp_prefix)

     if db_root is None:
-        json_path = os.path.join(tmp_prefix, 'databases', db_name + '.json')
+        json_path = os.path.join(tmp_prefix, "databases", db_name + ".json")

-        with open(json_path, 'r') as f:
+        with open(json_path, "r") as f:
             db_data = simplejson.load(f)

-        database_path = db_data['root_folder']
-        db_data['root_folder'] = os.path.join('/databases', db_name)
+        database_path = db_data["root_folder"]
+        db_data["root_folder"] = os.path.join("/databases", db_name)

-        with open(json_path, 'w') as f:
+        with open(json_path, "w") as f:
             simplejson.dump(db_data, f, indent=4)

     try:
         db_envkey = host.db2docker([db_name])
-    except:
-        raise RuntimeError("No environment found for the database `%s' "
-                           "- available environments are %s" % (
-                               db_name,
-                               ", ".join(host.db_environments.keys())))
+    except Exception:
+        raise RuntimeError(
+            "No environment found for the database `%s' "
+            "- available environments are %s"
+            % (db_name, ", ".join(host.db_environments.keys()))
+        )

     # Creation of the container
     # Note: we only support one databases image loaded at the same time
-    CONTAINER_PREFIX = '/beat/prefix'
-    CONTAINER_CACHE = '/beat/cache'
+    CONTAINER_PREFIX = "/beat/prefix"
+    CONTAINER_CACHE = "/beat/cache"

-    database_port = random.randint(51000, 60000)
+    database_port = random.randint(51000, 60000)  # nosec just getting a free port

     if cmd == CMD_VIEW_OUTPUTS:
         db_cmd = [
             cmd,
-            '0.0.0.0:{}'.format(database_port),
+            "0.0.0.0:{}".format(database_port),
             CONTAINER_PREFIX,
-            CONTAINER_CACHE
+            CONTAINER_CACHE,
         ]
     else:
         db_cmd = [
@@ -226,28 +247,32 @@ def start_db_container(configuration, cmd, host,
             CONTAINER_CACHE,
             db_name,
             protocol_name,
-            set_name
+            set_name,
         ]

     databases_container = host.create_container(db_envkey, db_cmd)
     databases_container.uid = uid

     if cmd == CMD_VIEW_OUTPUTS:
-        databases_container.add_port(
-            database_port, database_port, host_address=host.ip)
-        databases_container.add_volume(db_tempdir, '/beat/prefix')
-        databases_container.add_volume(configuration.cache, '/beat/cache')
+        databases_container.add_port(database_port, database_port, host_address=host.ip)
+        databases_container.add_volume(db_tempdir, "/beat/prefix")
+        databases_container.add_volume(configuration.cache, "/beat/cache")
     else:
-        databases_container.add_volume(tmp_prefix, '/beat/prefix')
-        databases_container.add_volume(configuration.cache, '/beat/cache', read_only=False)
+        databases_container.add_volume(tmp_prefix, "/beat/prefix")
+        databases_container.add_volume(
+            configuration.cache, "/beat/cache", read_only=False
+        )

     # Specify the volumes to mount inside the container
-    if 'datasets_root_path' not in db_configuration:
+    if "datasets_root_path" not in db_configuration:
         databases_container.add_volume(
-            database_path, os.path.join('/databases', db_name))
+            database_path, os.path.join("/databases", db_name)
+        )
     else:
-        databases_container.add_volume(db_configuration['datasets_root_path'],
-                                       db_configuration['datasets_root_path'])
+        databases_container.add_volume(
+            db_configuration["datasets_root_path"],
+            db_configuration["datasets_root_path"],
+        )

     # Start the container
     host.start(databases_container)
@@ -256,21 +281,21 @@ def start_db_container(configuration, cmd, host,
     # Communicate with container
     zmq_context = zmq.Context()
     db_socket = zmq_context.socket(zmq.PAIR)

-    db_address = 'tcp://{}:{}'.format(host.ip, database_port)
+    db_address = "tcp://{}:{}".format(host.ip, database_port)
     db_socket.connect(db_address)

-    for output_name, dataformat_name in db_set['outputs'].items():
-        if excluded_outputs is not None and \
-           output_name in excluded_outputs:
+    for output_name, dataformat_name in db_set["outputs"].items():
+        if excluded_outputs is not None and output_name in excluded_outputs:
             continue

         data_source = RemoteDataSource()
-        data_source.setup(db_socket, output_name,
-                          dataformat_name, configuration.path)
+        data_source.setup(
+            db_socket, output_name, dataformat_name, configuration.path
+        )

-        input_ = inputs.Input(output_name,
-                              database.dataformats[dataformat_name],
-                              data_source)
+        input_ = inputs.Input(
+            output_name, database.dataformats[dataformat_name], data_source
+        )
         input_group.add(input_)

     return (databases_container, db_socket, zmq_context, input_list)
@@ -319,9 +344,15 @@ def pull_impl(webapi, prefix, names, force, indentation, format_cache):
     from .dataformats import pull_impl as dataformats_pull

-    status, names = common.pull(webapi, prefix, 'database', names,
-                                ['declaration', 'code', 'description'],
-                                force, indentation)
+    status, names = common.pull(
+        webapi,
+        prefix,
+        "database",
+        names,
+        ["declaration", "code", "description"],
+        force,
+        indentation,
+    )

     # see what dataformats one needs to pull
     dataformats = []
@@ -330,8 +361,9 @@ def pull_impl(webapi, prefix, names, force, indentation, format_cache):
         dataformats.extend(obj.dataformats.keys())

     # downloads any formats to which we depend on
-    df_status = dataformats_pull(webapi, prefix, dataformats, force,
-                                 indentation + 2, format_cache)
+    df_status = dataformats_pull(
+        webapi, prefix, dataformats, force, indentation + 2, format_cache
+    )

     return status + df_status
@@ -341,7 +373,7 @@ def pull_impl(webapi, prefix, names, force, indentation, format_cache):
 def index_outputs(configuration, names, uid=None, db_root=None, docker=False):

-    names = common.make_up_local_list(configuration.path, 'database', names)
+    names = common.make_up_local_list(configuration.path, "database", names)

     retcode = 0

     if docker:
@@ -350,8 +382,7 @@ def index_outputs(configuration, names, uid=None, db_root=None, docker=False):
     for database_name in names:
         logger.info("Indexing database %s...", database_name)

-        (db_name, database, sets) = load_database_sets(
-            configuration, database_name)
+        (db_name, database, sets) = load_database_sets(configuration, database_name)
         if database is None:
             retcode += 1
             continue
@@ -361,9 +392,8 @@ def index_outputs(configuration, names, uid=None, db_root=None, docker=False):
             try:
                 view = database.view(protocol_name, set_name)
             except SyntaxError as error:
-                logger.error("Failed to load the database `%s':",
-                             database_name)
-                logger.error(' * Syntax error: %s', error)
+                logger.error("Failed to load the database `%s':", database_name)
+                logger.error(" * Syntax error: %s", error)
                 view = None

             if view is None:
@@ -372,22 +402,28 @@ def index_outputs(configuration, names, uid=None, db_root=None, docker=False):
                 dataset_hash = hashDataset(db_name, protocol_name, set_name)

                 try:
-                    view.index(os.path.join(configuration.cache,
-                                            toPath(dataset_hash, '.db')))
+                    view.index(
+                        os.path.join(configuration.cache, toPath(dataset_hash, ".db"))
+                    )
                 except RuntimeError as error:
-                    logger.error("Failed to load the database `%s':",
-                                 database_name)
-                    logger.error(' * Runtime error %s', error)
-                    retcode += 1
-                    continue
+                    logger.error("Failed to load the database `%s':", database_name)
+                    logger.error(" * Runtime error %s", error)
+                    retcode += 1
+                    continue

             else:
-                databases_container = \
-                    start_db_container(configuration, CMD_DB_INDEX,
-                                       host, db_name, protocol_name, set_name,
-                                       database, db_set,
-                                       uid=uid, db_root=db_root
-                                       )
+                databases_container = start_db_container(
+                    configuration,
+                    CMD_DB_INDEX,
+                    host,
+                    db_name,
+                    protocol_name,
+                    set_name,
+                    database,
+                    db_set,
+                    uid=uid,
+                    db_root=db_root,
+                )
                 status = host.wait(databases_container)
                 logs = host.logs(databases_container)
                 host.rm(databases_container)
def list_index_files(configuration, names):
names = common.make_up_local_list(configuration.path, 'database', names)
names = common.make_up_local_list(configuration.path, "database", names)
retcode = 0
for database_name in names:
logger.info("Listing database %s indexes...", database_name)
(db_name, database, sets) = load_database_sets(
configuration, database_name)
(db_name, database, sets) = load_database_sets(configuration, database_name)
if database is None:
retcode += 1
continue
@@ -421,7 +456,7 @@ def list_index_files(configuration, names):
             dataset_hash = hashDataset(db_name, protocol_name, set_name)
             index_filename = toPath(dataset_hash)
             basename = os.path.splitext(index_filename)[0]
-            for g in glob.glob(basename + '.*'):
+            for g in glob.glob(basename + ".*"):
                 logger.info(g)

     return retcode
@@ -432,32 +467,33 @@ def list_index_files(configuration, names):
 def delete_index_files(configuration, names):

-    names = common.make_up_local_list(configuration.path, 'database', names)
+    names = common.make_up_local_list(configuration.path, "database", names)

     retcode = 0

     for database_name in names:
         logger.info("Deleting database %s indexes...", database_name)

-        (db_name, database, sets) = load_database_sets(
-            configuration, database_name)
+        (db_name, database, sets) = load_database_sets(configuration, database_name)
         if database is None:
             retcode += 1
             continue

         for protocol_name, set_name, db_set in sets:
-            for output_name in db_set['outputs'].keys():
+            for output_name in db_set["outputs"].keys():
                 dataset_hash = hashDataset(db_name, protocol_name, set_name)
                 index_filename = toPath(dataset_hash)
-                basename = os.path.join(configuration.cache,
-                                        os.path.splitext(index_filename)[0])
+                basename = os.path.join(
+                    configuration.cache, os.path.splitext(index_filename)[0]
+                )

-                for g in glob.glob(basename + '.*'):
+                for g in glob.glob(basename + ".*"):
                     logger.info("removing `%s'...", g)
                     os.unlink(g)

-                common.recursive_rmdir_if_empty(os.path.dirname(basename),
-                                                configuration.cache)
+                common.recursive_rmdir_if_empty(
+                    os.path.dirname(basename), configuration.cache
+                )

     return retcode
@@ -465,20 +501,27 @@ def delete_index_files(configuration, names):
 # ----------------------------------------------------------


-def view_outputs(configuration, dataset_name, excluded_outputs=None, uid=None,
-                 db_root=None, docker=False):
+def view_outputs(
+    configuration,
+    dataset_name,
+    excluded_outputs=None,
+    uid=None,
+    db_root=None,
+    docker=False,
+):
     def data_to_json(data, indent):
         value = common.stringify(data.as_dict())

-        value = simplejson.dumps(value, indent=4, cls=NumpyJSONEncoder) \
-            .replace('"BEAT_LIST_DELIMITER[', '[') \
-            .replace(']BEAT_LIST_DELIMITER"', ']') \
-            .replace('"...",', '...') \
-            .replace('"BEAT_LIST_SIZE(', '(') \
-            .replace(')BEAT_LIST_SIZE"', ')')
+        value = (
+            simplejson.dumps(value, indent=4, cls=NumpyJSONEncoder)
+            .replace('"BEAT_LIST_DELIMITER[', "[")
+            .replace(']BEAT_LIST_DELIMITER"', "]")
+            .replace('"...",', "...")
+            .replace('"BEAT_LIST_SIZE(', "(")
+            .replace(')BEAT_LIST_SIZE"', ")")
+        )

-        return ('\n' + ' ' * indent).join(value.split('\n'))
+        return ("\n" + " " * indent).join(value.split("\n"))

     # Load the infos about the database set
     (db_name, database, sets) = load_database_sets(configuration, dataset_name)
@@ -488,8 +531,7 @@ def view_outputs(configuration, dataset_name, excluded_outputs=None, uid=None,
     (protocol_name, set_name, db_set) = sets[0]

     if excluded_outputs is not None:
-        excluded_outputs = map(lambda x: x.strip(),
-                               excluded_outputs.split(','))
+        excluded_outputs = map(lambda x: x.strip(), excluded_outputs.split(","))

     # Setup the view so the outputs can be used
     if not docker:
return 1
dataset_hash = hashDataset(db_name, protocol_name, set_name)
view.setup(os.path.join(configuration.cache,
toPath(dataset_hash, '.db')), pack=False)
view.setup(
os.path.join(configuration.cache, toPath(dataset_hash, ".db")), pack=False
)
input_group = inputs.InputGroup(set_name, restricted_access=False)
for output_name, dataformat_name in db_set['outputs'].items():
if excluded_outputs is not None and \
output_name in excluded_outputs:
for output_name, dataformat_name in db_set["outputs"].items():
if excluded_outputs is not None and output_name in excluded_outputs:
continue
input = inputs.Input(output_name,
database.dataformats[dataformat_name],
view.data_sources[output_name])
input = inputs.Input(
output_name,
database.dataformats[dataformat_name],
view.data_sources[output_name],
)
input_group.add(input)
else:
host = dock.Host(raise_on_errors=False)
(databases_container, db_socket, zmq_context, input_list) = \
start_db_container(configuration, CMD_VIEW_OUTPUTS,
host, db_name, protocol_name,
set_name, database, db_set,
excluded_outputs=excluded_outputs,
uid=uid, db_root=db_root)
(databases_container, db_socket, zmq_context, input_list) = start_db_container(
configuration,
CMD_VIEW_OUTPUTS,
host,
db_name,
protocol_name,
set_name,
database,
db_set,
excluded_outputs=excluded_outputs,
uid=uid,
db_root=db_root,
)
input_group = input_list.group(set_name)
@@ -538,25 +589,28 @@ def view_outputs(configuration, dataset_name, excluded_outputs=None, uid=None,
             end = input_group.data_index_end

             if start != previous_start:
-                print(80 * '-')
+                print(80 * "-")

-                print('FROM %d TO %d' % (start, end))
+                print("FROM %d TO %d" % (start, end))

-                whole_inputs = [input_ for input_ in input_group
-                                if input_.data_index == start and
-                                input_.data_index_end == end]
+                whole_inputs = [
+                    input_
+                    for input_ in input_group
+                    if input_.data_index == start and input_.data_index_end == end
+                ]

                 for input in whole_inputs:
-                    label = ' - ' + str(input.name) + ': '
+                    label = " - " + str(input.name) + ": "
                     print(label + data_to_json(input.data, len(label)))

                 previous_start = start

-            selected_inputs = \
-                [input_ for input_ in input_group
-                 if input_.data_index == input_group.first_data_index and
-                 (input_.data_index != start or
-                  input_.data_index_end != end)]
+            selected_inputs = [
+                input_
+                for input_ in input_group
+                if input_.data_index == input_group.first_data_index
+                and (input_.data_index != start or input_.data_index_end != end)
+            ]

             grouped_inputs = {}
             for input_ in selected_inputs:
for key in sorted_keys:
print
print(' FROM %d TO %d' % key)
print(" FROM %d TO %d" % key)
for input in grouped_inputs[key]:
label = ' - ' + str(input.name) + ': '
label = " - " + str(input.name) + ": "
print(label + data_to_json(input.data, len(label)))
except Exception as e:
@@ -587,7 +641,6 @@ def view_outputs(configuration, dataset_name, excluded_outputs=None, uid=None,
     if status != 0:
         logger.error("Docker error: %s", logs)