#!/usr/bin/env python # vim: set fileencoding=utf-8 : ############################################################################### # # # Copyright (c) 2016 Idiap Research Institute, http://www.idiap.ch/ # # Contact: beat.support@idiap.ch # # # # This file is part of the beat.cmdline module of the BEAT platform. # # # # Commercial License Usage # # Licensees holding valid commercial BEAT licenses may use this file in # # accordance with the terms contained in a written agreement between you # # and Idiap. For further information contact tto@idiap.ch # # # # Alternatively, this file may be used under the terms of the GNU Affero # # Public License version 3 as published by the Free Software and appearing # # in the file LICENSE.AGPL included in the packaging of this file. # # The BEAT platform is distributed in the hope that it will be useful, but # # WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY # # or FITNESS FOR A PARTICULAR PURPOSE. # # # # You should have received a copy of the GNU Affero Public License along # # with the BEAT platform. If not, see http://www.gnu.org/licenses/. # # # ############################################################################### """Usage: %(prog)s databases list [--remote] %(prog)s databases check []... %(prog)s databases pull [--force] []... %(prog)s databases push [--force] [--dry-run] []... %(prog)s databases diff %(prog)s databases status %(prog)s databases version %(prog)s databases index [--list | --delete | --checksum] [--uid=] [--db-root=] [--docker] []... %(prog)s databases view [--exclude=] [--uid=] [--db-root=] [--docker] %(prog)s databases --help Arguments: Database name formated as "/" Set formatted as "///" Commands: list Lists all the databases available on the platform check Checks a local database for validity pull Downloads the specified databases from the server push Uploads databases to the server (must provide a valid admin token) diff Shows changes between the local database and the remote version status Shows (editing) status for all available databases version Creates a new version of an existing database index Indexes all outputs (of all sets) of a database. view View the data of the specified dataset. Options: --help Display this screen --remote Only acts on the remote copy of the database --exclude= When viewing, excludes this output --list List index files matching output if they exist --delete Delete index files matching output if they exist (also, recursively deletes empty directories) --checksum Checksums index files Examples: To list all existing databases on your local prefix: $ %(prog)s db list To view the contents of a specific set $ %(prog)s db view simple/1/protocol/set To index the contents of a database $ %(prog)s db index simple/1 To index the contents of a protocol on a database $ %(prog)s db index simple/1/double To index the contents of a set in a protocol on a database $ %(prog)s db index simple/1/double/double """ import os import glob import random import zmq import logging logger = logging.getLogger(__name__) import simplejson from beat.core.hash import hash, toPath, hashFileContents, hashDataset from beat.core.utils import NumpyJSONEncoder from beat.core.database import Database from beat.core.dataformat import DataFormat from beat.core.data import load_data_index, RemoteDataSource from beat.core import dock from beat.core import inputs from beat.core import utils from . import common CMD_DB_INDEX = 'index' CMD_VIEW_OUTPUTS = 'databases_provider' #---------------------------------------------------------- def load_database_sets(configuration, database_name): # Process the name of the database parts = database_name.split('/') if len(parts) == 2: db_name = os.path.join(*parts[:2]) protocol_filter = None set_filter = None elif len(parts) == 3: db_name = os.path.join(*parts[:2]) protocol_filter = parts[2] set_filter = None elif len(parts) == 4: db_name = os.path.join(*parts[:2]) protocol_filter = parts[2] set_filter = parts[3] else: logger.error("Database specification should have the format " \ "`//[/[]]', the value you " \ "passed (%s) is not valid", (dataset_name)) return (None, None) # Load the dataformat dataformat_cache = {} database = Database(utils.Prefix(configuration.path), db_name, dataformat_cache) if not database.valid: logger.error("Failed to load the database `%s':", db_name) for e in database.errors: logger.error(' * %s', e) return (None, None, None) # Filter the protocols protocols = database.protocol_names if protocol_filter is not None: if protocol_filter not in protocols: logger.error("The database `%s' does not have the protocol `%s' - " \ "choose one of `%s'", db_name, protocol_filter, ', '.join(protocols)) return (None, None, None) protocols = [protocol_filter] # Filter the sets loaded_sets = [] for protocol_name in protocols: sets = database.set_names(protocol_name) if set_filter is not None: if set_filter not in sets: logger.error("The database/protocol `%s/%s' does not have the " \ "set `%s' - choose one of `%s'", db_name, protocol_name, set_filter, ', '.join(sets)) return (None, None, None) sets = [z for z in sets if z == set_filter] loaded_sets.extend([ (protocol_name, set_name, database.set(protocol_name, set_name)) for set_name in sets ]) return (db_name, database, loaded_sets) #---------------------------------------------------------- def start_db_container(configuration, cmd, host, db_name, protocol_name, set_name, database, db_set, excluded_outputs=None, uid=None, db_root=None): input_list = inputs.InputList() input_group = inputs.InputGroup(set_name, restricted_access=False) input_list.add(input_group) db_configuration = { 'inputs': {}, 'channel': set_name, } if uid is not None: db_configuration['datasets_uid'] = uid if db_root is not None: db_configuration['datasets_root_path'] = db_root for output_name, dataformat_name in db_set['outputs'].items(): if (excluded_outputs is not None) and (output_name in excluded_outputs): continue dataset_hash = hashDataset(db_name, protocol_name, set_name) db_configuration['inputs'][output_name] = dict( database=db_name, protocol=protocol_name, set=set_name, output=output_name, channel=set_name, hash=dataset_hash, path=toPath(dataset_hash, '.db') ) db_tempdir = utils.temporary_directory() with open(os.path.join(db_tempdir, 'configuration.json'), 'wb') as f: simplejson.dump(db_configuration, f, indent=4) tmp_prefix = os.path.join(db_tempdir, 'prefix') if not os.path.exists(tmp_prefix): os.makedirs(tmp_prefix) database.export(utils.Prefix(tmp_prefix)) if db_root is None: json_path = os.path.join(tmp_prefix, 'databases', db_name + '.json') with open(json_path, 'r') as f: db_data = simplejson.load(f) database_path = db_data['root_folder'] db_data['root_folder'] = os.path.join('/databases', db_name) with open(json_path, 'w') as f: simplejson.dump(db_data, f, indent=4) try: db_envkey = host.db2docker([db_name]) except: raise RuntimeError("No environment found for the database `%s' " \ "- available environments are %s" % ( db_name, ", ".join(host.db_environments.keys()))) # Creation of the container # Note: we only support one databases image loaded at the same time CONTAINER_PREFIX = '/beat/prefix' CONTAINER_CACHE = '/beat/cache' database_port = random.randint(51000, 60000) if cmd == CMD_VIEW_OUTPUTS: db_cmd = [ cmd, '0.0.0.0:{}'.format(database_port), CONTAINER_PREFIX, CONTAINER_CACHE ] else: db_cmd = [ cmd, CONTAINER_PREFIX, CONTAINER_CACHE, db_name, protocol_name, set_name ] databases_container = host.create_container(db_envkey, db_cmd) if cmd == CMD_VIEW_OUTPUTS: databases_container.add_port(database_port, database_port, host_address=host.ip) databases_container.add_volume(db_tempdir, '/beat/prefix') databases_container.add_volume(configuration.cache, '/beat/cache') # Specify the volumes to mount inside the container if not db_configuration.has_key('datasets_root_path'): databases_container.add_volume(database_path, os.path.join('/databases', db_name)) else: databases_container.add_volume(db_configuration['datasets_root_path'], db_configuration['datasets_root_path']) # Start the container host.start(databases_container) if cmd == CMD_VIEW_OUTPUTS: # Communicate with container zmq_context = zmq.Context() db_socket = zmq_context.socket(zmq.PAIR) db_address = 'tcp://{}:{}'.format(host.ip, database_port) db_socket.connect(db_address) for output_name, dataformat_name in db_set['outputs'].items(): if (excluded_outputs is not None) and (output_name in excluded_outputs): continue data_source = RemoteDataSource() data_source.setup(db_socket, output_name, dataformat_name, configuration.path) input = inputs.Input(output_name, database.dataformats[dataformat_name], data_source) input_group.add(input) return (databases_container, db_socket, zmq_context, input_list) return databases_container #---------------------------------------------------------- def pull(webapi, prefix, names, force, indentation, format_cache): """Copies databases (and required dataformats) from the server. Parameters: webapi (object): An instance of our WebAPI class, prepared to access the BEAT server of interest prefix (str): A string representing the root of the path in which the user objects are stored names (list): A list of strings, each representing the unique relative path of the objects to retrieve or a list of usernames from which to retrieve objects. If the list is empty, then we pull all available objects of a given type. If no user is set, then pull all public objects of a given type. force (bool): If set to ``True``, then overwrites local changes with the remotely retrieved copies. indentation (int): The indentation level, useful if this function is called recursively while downloading different object types. This is normally set to ``0`` (zero). format_cache (dict): A dictionary containing all dataformats already downloaded. Returns: int: Indicating the exit status of the command, to be reported back to the calling process. This value should be zero if everything works OK, otherwise, different than zero (POSIX compliance). """ from .dataformats import pull as dataformats_pull status, names = common.pull(webapi, prefix, 'database', names, ['declaration', 'code', 'description'], force, indentation) # see what dataformats one needs to pull indent = indentation * ' ' dataformats = [] for name in names: obj = Database(prefix, name) dataformats.extend(obj.dataformats.keys()) # downloads any formats to which we depend on df_status = dataformats_pull(webapi, prefix, dataformats, force, indentation + 2, format_cache) return status + df_status #---------------------------------------------------------- def index_outputs(configuration, names, uid=None, db_root=None, docker=False): names = common.make_up_local_list(configuration.path, 'database', names) retcode = 0 if docker: host = dock.Host(raise_on_errors=False) for database_name in names: logger.info("Indexing database %s...", database_name) (db_name, database, sets) = load_database_sets(configuration, database_name) if database is None: retcode += 1 continue for protocol_name, set_name, db_set in sets: if not docker: view = database.view(protocol_name, set_name) if view is None: retcode += 1 continue dataset_hash = hashDataset(db_name, protocol_name, set_name) view.index(os.path.join(configuration.cache, toPath(dataset_hash, '.db'))) else: databases_container = \ start_db_container(configuration, CMD_DB_INDEX, host, db_name, protocol_name, set_name, database, db_set, uid=uid, db_root=db_root ) status = host.wait(databases_container) if status != 0: retcode += 1 return retcode #---------------------------------------------------------- def list_index_files(configuration, names): names = common.make_up_local_list(configuration.path, 'database', names) retcode = 0 for database_name in names: logger.info("Listing database %s indexes...", database_name) (db_name, database, sets) = load_database_sets(configuration, database_name) if database is None: retcode += 1 continue for protocol_name, set_name, db_set in sets: for output_name in db_set['outputs'].keys(): index_hash = database.hash_output(protocol_name, set_name, output_name) index_filename = os.path.join(configuration.cache, toPath(index_hash, '.index')) basename = os.path.splitext(index_filename)[0] for g in glob.glob(basename + '.*'): logger.info(g) return retcode #---------------------------------------------------------- def delete_index_files(configuration, names): names = common.make_up_local_list(configuration.path, 'database', names) retcode = 0 for database_name in names: logger.info("Deleting database %s indexes...", database_name) (db_name, database, sets) = load_database_sets(configuration, database_name) if database is None: retcode += 1 continue for protocol_name, set_name, db_set in sets: for output_name in db_set['outputs'].keys(): index_hash = database.hash_output(protocol_name, set_name, output_name) index_filename = os.path.join(configuration.cache, toPath(index_hash, '.index')) basename = os.path.splitext(index_filename)[0] for g in glob.glob(basename + '.*'): logger.info("removing `%s'...", g) os.unlink(g) common.recursive_rmdir_if_empty(os.path.dirname(basename), configuration.cache) return retcode #---------------------------------------------------------- def checksum_index_files(configuration, names): names = common.make_up_local_list(configuration.path, 'database', names) retcode = 0 for database_name in names: logger.info("Checksumming database %s indexes...", database_name) (db_name, database, sets) = load_database_sets(configuration, database_name) if database is None: retcode += 1 continue for protocol_name, set_name, db_set in sets: for output_name in db_set['outputs'].keys(): index_hash = database.hash_output(protocol_name, set_name, output_name) index_filename = os.path.join(configuration.cache, toPath(index_hash, '.index')) assert load_data_index(configuration.cache, toPath(index_hash, '.index')) logger.info("index for `%s' can be loaded and checksumed", index_filename) return retcode #---------------------------------------------------------- def view_outputs(configuration, dataset_name, excluded_outputs=None, uid=None, db_root=None, docker=False): def data_to_json(data, indent): value = common.stringify(data.as_dict()) value = simplejson.dumps(value, indent=4, cls=NumpyJSONEncoder) \ .replace('"BEAT_LIST_DELIMITER[', '[') \ .replace(']BEAT_LIST_DELIMITER"', ']') \ .replace('"...",', '...') \ .replace('"BEAT_LIST_SIZE(', '(') \ .replace(')BEAT_LIST_SIZE"', ')') return ('\n' + ' ' * indent).join(value.split('\n')) # Load the infos about the database set (db_name, database, sets) = load_database_sets(configuration, dataset_name) if (database is None) or (len(sets) != 1): return 1 (protocol_name, set_name, db_set) = sets[0] if excluded_outputs is not None: excluded_outputs = map(lambda x: x.strip(), excluded_outputs.split(',')) # Setup the view so the outputs can be used if not docker: view = database.view(protocol_name, set_name) if view is None: return 1 dataset_hash = hashDataset(db_name, protocol_name, set_name) view.setup(os.path.join(configuration.cache, toPath(dataset_hash, '.db')), pack=False) input_group = inputs.InputGroup(set_name, restricted_access=False) for output_name, dataformat_name in db_set['outputs'].items(): if (excluded_outputs is not None) and (output_name in excluded_outputs): continue input = inputs.Input(output_name, database.dataformats[dataformat_name], view.data_sources[output_name]) input_group.add(input) else: host = dock.Host(raise_on_errors=False) (databases_container, db_socket, zmq_context, input_list) = \ start_db_container(configuration, CMD_VIEW_OUTPUTS, host, db_name, protocol_name, set_name, database, db_set, excluded_outputs=excluded_outputs, uid=uid, db_root=db_root ) input_group = input_list.group(set_name) # Display the data try: previous_start = -1 while input_group.hasMoreData(): input_group.next() start = input_group.data_index end = input_group.data_index_end if start != previous_start: print(80 * '-') print 'FROM %d TO %d' % (start, end) whole_inputs = [ input for input in input_group if (input.data_index == start) and \ (input.data_index_end == end) ] for input in whole_inputs: label = ' - ' + str(input.name) + ': ' print label + data_to_json(input.data, len(label)) previous_start = start selected_inputs = [ input for input in input_group if (input.data_index == input_group.first_data_index) and \ ((input.data_index != start) or \ (input.data_index_end != end)) ] grouped_inputs = {} for input in selected_inputs: key = (input.data_index, input.data_index_end) if not grouped_inputs.has_key(key): grouped_inputs[key] = [] grouped_inputs[key].append(input) sorted_keys = grouped_inputs.keys() sorted_keys.sort() for key in sorted_keys: print print ' FROM %d TO %d' % key for input in grouped_inputs[key]: label = ' - ' + str(input.name) + ': ' print label + data_to_json(input.data, len(label)) except Exception as e: logger.error("Failed to retrieve the next data: %s", e) return 1 return 0 #---------------------------------------------------------- def process(args): if args['list']: if args['--remote']: with common.make_webapi(args['config']) as webapi: return common.display_remote_list(webapi, 'database') else: return common.display_local_list(args['config'].path, 'database') elif args['check']: return common.check(args['config'].path, 'database', args['']) elif args['pull']: with common.make_webapi(args['config']) as webapi: return pull(webapi, args['config'].path, args[''], args['--force'], 0, {}) elif args['push']: with common.make_webapi(args['config']) as webapi: return common.push(webapi, args['config'].path, 'database', args[''], ['name', 'declaration', 'code', 'description'], {}, args['--force'], args['--dry-run'], 0) elif args['diff']: with common.make_webapi(args['config']) as webapi: return common.diff(webapi, args['config'].path, 'database', args[''][0], ['declaration', 'code', 'description']) elif args['status']: with common.make_webapi(args['config']) as webapi: return common.status(webapi, args['config'].path, 'database')[0] elif args['version']: return common.new_version(args['config'].path, 'database', args[''][0]) elif args['view']: if args['--exclude']: return view_outputs(args['config'], args[''], args['--exclude'], uid=int(args['--uid']) if args['--uid'] is not None else None, db_root=args['--db-root'], docker=args['--docker']) else: return view_outputs(args['config'], args[''], uid=int(args['--uid']) if args['--uid'] is not None else None, db_root=args['--db-root'], docker=args['--docker']) elif args['index']: if args['--list']: return list_index_files(args['config'], args['']) elif args['--delete']: return delete_index_files(args['config'], args['']) elif args['--checksum']: return checksum_index_files(args['config'], args['']) else: return index_outputs(args['config'], args[''], uid=int(args['--uid']) if args['--uid'] is not None else None, db_root=args['--db-root'], docker=args['--docker']) # Should not happen logger.error("unrecognized `databases' subcommand") return 1