Commit 6d8f1bb1 authored by Samuel GAIST

[database] Refactored start_db_container for new Docker implementation

This makes it easier to start containers when new commands are added.
parent 969d8c74
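Usage sketch (not part of the diff): after this refactoring the same helper serves both commands, and only the output-viewing case still hands back ZeroMQ plumbing. All names below (configuration, host, db_name, ...) are assumed to be prepared as in the callers shown further down.

# Sketch only -- assumes the setup done in index_outputs()/view_outputs().

# Indexing: the container does all the work; the caller just waits on it.
databases_container = start_db_container(
    configuration, CMD_DB_INDEX,
    host, db_name, protocol_name, set_name, database, db_set,
    uid=uid, db_root=db_root)
status = host.wait(databases_container)

# Viewing outputs: the container serves data over ZeroMQ, so the caller
# also receives the socket, context and input list needed to read from it.
(databases_container, db_socket, zmq_context, input_list) = \
    start_db_container(
        configuration, CMD_VIEW_OUTPUTS,
        host, db_name, protocol_name, set_name, database, db_set,
        excluded_outputs=excluded_outputs, uid=uid, db_root=db_root)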
@@ -91,6 +91,7 @@ Examples:
 import os
 import glob
+import random
 import zmq
 import logging
@@ -102,7 +103,7 @@ from beat.core.hash import hash, toPath, hashFileContents, hashDataset
 from beat.core.utils import NumpyJSONEncoder
 from beat.core.database import Database
 from beat.core.dataformat import DataFormat
-from beat.core.data import load_data_index
+from beat.core.data import load_data_index, RemoteDataSource
 from beat.core import dock
 from beat.core import inputs
 from beat.core import utils
@@ -110,6 +111,10 @@ from beat.core import utils
 from . import common

+CMD_DB_INDEX = 'index'
+CMD_VIEW_OUTPUTS = 'databases_provider'
+

 #----------------------------------------------------------
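These constants name the executables available inside the databases image. Per the commit message's intent, a later command could slot in the same way; a sketch (the third name is made up for illustration):

CMD_DB_INDEX = 'index'
CMD_VIEW_OUTPUTS = 'databases_provider'
CMD_DB_VALIDATE = 'validate'   # hypothetical future command, not in this commit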
@@ -187,15 +192,10 @@ def load_database_sets(configuration, database_name):
 #----------------------------------------------------------

-def start_db_container(host, db_name, protocol_name, set_name, database, db_set,
+def start_db_container(configuration, cmd, host,
+                       db_name, protocol_name, set_name, database, db_set,
                        excluded_outputs=None, uid=None, db_root=None):

-    zmq_context = zmq.Context()
-    db_socket = zmq_context.socket(zmq.PAIR)
-
-    db_address = 'tcp://' + host.ip
-    port = db_socket.bind_to_random_port(db_address)
-    db_address += ':%d' % port
-
     input_list = inputs.InputList()

     input_group = inputs.InputGroup(set_name, restricted_access=False)
@@ -216,15 +216,15 @@ def start_db_container(host, db_name, protocol_name, set_name, database, db_set,
         if (excluded_outputs is not None) and (output_name in excluded_outputs):
             continue

-        input = inputs.RemoteInput(output_name, database.dataformats[dataformat_name], db_socket)
-        input_group.add(input)
+        dataset_hash = hashDataset(db_name, protocol_name, set_name)

         db_configuration['inputs'][output_name] = dict(
             database=db_name,
             protocol=protocol_name,
             set=set_name,
             output=output_name,
-            channel=set_name
+            channel=set_name,
+            hash=dataset_hash,
+            path=toPath(dataset_hash, '.db')
         )

     db_tempdir = utils.temporary_directory()
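For reference, a minimal sketch of what the new hash/path entries resolve to; hashDataset and toPath come from beat.core.hash as imported above, while the database/protocol/set names are made up:

from beat.core.hash import hashDataset, toPath

dataset_hash = hashDataset('mydb/1', 'myprotocol', 'myset')
relative_path = toPath(dataset_hash, '.db')
# relative_path is a nested path derived from the hash; callers join it
# with the cache root, as in os.path.join(configuration.cache, relative_path)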
@@ -236,7 +236,7 @@ def start_db_container(host, db_name, protocol_name, set_name, database, db_set,
     if not os.path.exists(tmp_prefix):
         os.makedirs(tmp_prefix)

-    database.export(tmp_prefix)
+    database.export(utils.Prefix(tmp_prefix))

     if db_root is None:
         json_path = os.path.join(tmp_prefix, 'databases', db_name + '.json')
@@ -260,14 +260,32 @@ def start_db_container(host, db_name, protocol_name, set_name, database, db_set,
     # Creation of the container
     # Note: we only support one databases image loaded at the same time
-    db_cmd = [
-        'databases_provider',
-        db_address,
-        os.path.join('/tmp', os.path.basename(db_tempdir))
-    ]
+    CONTAINER_PREFIX = '/beat/prefix'
+    CONTAINER_CACHE = '/beat/cache'
+
+    database_port = random.randint(51000, 60000)
+
+    if cmd == CMD_VIEW_OUTPUTS:
+        db_cmd = [
+            cmd,
+            '0.0.0.0:{}'.format(database_port),
+            CONTAINER_PREFIX,
+            CONTAINER_CACHE
+        ]
+    else:
+        db_cmd = [
+            cmd,
+            CONTAINER_PREFIX,
+            CONTAINER_CACHE,
+            db_name,
+            protocol_name,
+            set_name
+        ]

     databases_container = host.create_container(db_envkey, db_cmd)
-    databases_container.copy_path(db_tempdir, '/tmp')
+
+    if cmd == CMD_VIEW_OUTPUTS:
+        databases_container.add_port(database_port, database_port, host_address=host.ip)
+
+    databases_container.add_volume(db_tempdir, '/beat/prefix')
+    databases_container.add_volume(configuration.cache, '/beat/cache')

     # Specify the volumes to mount inside the container
     if not db_configuration.has_key('datasets_root_path'):
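For clarity, the argv each container flavour ends up executing looks roughly like this (port and names are made up):

# CMD_VIEW_OUTPUTS: serve one set over ZeroMQ on a randomly chosen port
['databases_provider', '0.0.0.0:51234', '/beat/prefix', '/beat/cache']

# CMD_DB_INDEX: index one set of one protocol of one database
['index', '/beat/prefix', '/beat/cache', 'mydb/1', 'myprotocol', 'myset']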
@@ -279,7 +297,27 @@ def start_db_container(host, db_name, protocol_name, set_name, database, db_set,
     # Start the container
     host.start(databases_container)

-    return (databases_container, db_socket, zmq_context, input_list)
+    if cmd == CMD_VIEW_OUTPUTS:
+        # Communicate with container
+        zmq_context = zmq.Context()
+        db_socket = zmq_context.socket(zmq.PAIR)
+
+        db_address = 'tcp://{}:{}'.format(host.ip, database_port)
+        db_socket.connect(db_address)
+
+        for output_name, dataformat_name in db_set['outputs'].items():
+            if (excluded_outputs is not None) and (output_name in excluded_outputs):
+                continue
+
+            data_source = RemoteDataSource()
+            data_source.setup(db_socket, output_name, dataformat_name, configuration.path)
+
+            input = inputs.Input(output_name, database.dataformats[dataformat_name], data_source)
+            input_group.add(input)
+
+        return (databases_container, db_socket, zmq_context, input_list)
+
+    return databases_container
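Callers that receive the ZeroMQ handles are expected to release them once done. A teardown sketch mirroring what the removed indexing path did (and what view_outputs() still does):

host.kill(databases_container)
host.wait(databases_container)

db_socket.setsockopt(zmq.LINGER, 0)  # drop unsent messages on close
db_socket.close()
zmq_context.term()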
#----------------------------------------------------------
@@ -373,82 +411,15 @@ def index_outputs(configuration, names, uid=None, db_root=None, docker=False):
                 view.index(os.path.join(configuration.cache,
                                         toPath(dataset_hash, '.db')))

-                # TODO: Remove when docker version is fixed:
-                continue
             else:
-                (databases_container, db_socket, zmq_context, input_list) = \
-                    start_db_container(
+                databases_container = \
+                    start_db_container(configuration, CMD_DB_INDEX,
                         host, db_name, protocol_name, set_name, database, db_set,
                         uid=uid, db_root=db_root
                     )

-                input_group = input_list.group(set_name)
-
-                index_filenames = []
-                previous_data_indices = []
-
-                for output_name in db_set['outputs'].keys():
-                    index_hash = database.hash_output(protocol_name, set_name, output_name)
-                    index_filename = os.path.join(configuration.cache,
-                                                  toPath(index_hash, '.index'))
-
-                    index_filenames.append(index_filename)
-                    previous_data_indices.append(None)
-
-                    logger.info("Indexing database `%s', protocol `%s', set `%s', " \
-                                "output `%s' -> `%s'", db_name, protocol_name, set_name,
-                                output_name, index_filename)
-
-                    if os.path.exists(index_filename):
-                        logger.extra("Overwriting existing index file `%s'",
-                                     index_filename)
-                        os.remove(index_filename)
-
-                    index_dir = os.path.dirname(index_filename)
-                    if not os.path.exists(index_dir):
-                        logger.extra("Creating directory `%s'", index_dir)
-                        os.makedirs(index_dir)
-
-                    # Create empty lock file to indicate that this cache file *must*
-                    # not be cleaned
-                    lock_filename = os.path.join(configuration.cache,
-                                                 toPath(index_hash, '.lock'))
-                    with open(lock_filename, 'a'):
-                        os.utime(lock_filename, None)
-
-                try:
-                    while input_group.hasMoreData():
-                        input_group.next()
-
-                        for i, (input, index_filename, previous_data_index) in \
-                                enumerate(zip(input_group, index_filenames,
-                                              previous_data_indices)):
-                            with open(index_filename, 'at') as indexfile:
-                                if input.data_index != previous_data_index:
-                                    indexfile.write("%d %d\n" % (input.data_index,
-                                                                 input.data_index_end))
-                                    previous_data_indices[i] = input.data_index
-
-                    # creates the checksums for all indexes
-                    chksum = hashFileContents(index_filename)
-                    with open(index_filename + '.checksum', 'wt') as f:
-                        f.write(chksum)
-
-                except Exception as e:
-                    logger.critical("Failed to retrieve the next data: %s", e)
-                    retcode += 1
-                    continue
-
-                if docker:
-                    host.kill(databases_container)
-                    host.wait(databases_container)
-
-                db_socket.setsockopt(zmq.LINGER, 0)
-                db_socket.close()
-                zmq_context.term()
+                status = host.wait(databases_container)
+                if status != 0:
+                    retcode += 1

     return retcode
@@ -599,7 +570,7 @@ def view_outputs(configuration, dataset_name, excluded_outputs=None, uid=None,
     host = dock.Host(raise_on_errors=False)

     (databases_container, db_socket, zmq_context, input_list) = \
-        start_db_container(
+        start_db_container(configuration, CMD_VIEW_OUTPUTS,
             host, db_name, protocol_name, set_name, database, db_set,
             excluded_outputs=excluded_outputs, uid=uid, db_root=db_root
         )