#!/usr/bin/env python
# vim: set fileencoding=utf-8 :

###################################################################################
#                                                                                 #
# Copyright (c) 2019 Idiap Research Institute, http://www.idiap.ch/               #
# Contact: beat.support@idiap.ch                                                  #
#                                                                                 #
# Redistribution and use in source and binary forms, with or without              #
# modification, are permitted provided that the following conditions are met:     #
#                                                                                 #
# 1. Redistributions of source code must retain the above copyright notice, this  #
# list of conditions and the following disclaimer.                                #
#                                                                                 #
# 2. Redistributions in binary form must reproduce the above copyright notice,    #
# this list of conditions and the following disclaimer in the documentation       #
# and/or other materials provided with the distribution.                          #
#                                                                                 #
# 3. Neither the name of the copyright holder nor the names of its contributors   #
# may be used to endorse or promote products derived from this software without   #
# specific prior written permission.                                              #
#                                                                                 #
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND #
# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED   #
# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE          #
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE    #
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL      #
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR      #
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER      #
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,   #
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE   #
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.            #
#                                                                                 #
###################################################################################
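"""Database commands for the BEAT command-line client.

This module implements the ``beat databases`` command group: pulling databases
(and the dataformats they use) from a BEAT server, indexing database outputs,
listing and deleting cached index files, and viewing the data of a database
set, either locally or through a Docker container.
"""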


import glob
import logging
import os
import random

import click
import simplejson
import zmq

from beat.core import dock
from beat.core import inputs
from beat.core import utils
from beat.core.data import RemoteDataSource
from beat.core.database import Database
from beat.core.hash import hashDataset
from beat.core.hash import toPath
from beat.core.utils import NumpyJSONEncoder

from . import commands
from . import common
from .click_helper import AliasedGroup
from .click_helper import AssetCommand
from .click_helper import AssetInfo
from .decorators import raise_on_error

logger = logging.getLogger(__name__)


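# Names of the commands executed inside the database containers
# (passed as ``cmd`` to start_db_container() below)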
CMD_DB_INDEX = "index"
CMD_VIEW_OUTPUTS = "databases_provider"


# ----------------------------------------------------------


def load_database_sets(configuration, database_name):
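    """Load a database and return its filtered list of (protocol, set) pairs.

    ``database_name`` is ``<database>/<version>``, optionally followed by a
    protocol name and a set name.  Returns a tuple ``(db_name, database,
    loaded_sets)``, or a tuple of ``None`` values on error.
    """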
    # Process the name of the database
    parts = database_name.split("/")

    if len(parts) == 2:
        db_name = os.path.join(*parts[:2])
        protocol_filter = None
        set_filter = None

    elif len(parts) == 3:
        db_name = os.path.join(*parts[:2])
        protocol_filter = parts[2]
        set_filter = None

    elif len(parts) == 4:
        db_name = os.path.join(*parts[:2])
        protocol_filter = parts[2]
        set_filter = parts[3]

    else:
        logger.error(
            "Database specification should have the format "
            "`<database>/<version>/[<protocol>/[<set>]]', the value "
            "you passed (%s) is not valid",
            database_name,
        )
        return (None, None, None)

    # Load the dataformat
    dataformat_cache = {}
    database = Database(configuration.path, db_name, dataformat_cache)
    if not database.valid:
        logger.error("Failed to load the database `%s':", db_name)
        for e in database.errors:
            logger.error("  * %s", e)
        return (None, None, None)

    # Filter the protocols
    protocols = database.protocol_names

    if protocol_filter is not None:
        if protocol_filter not in protocols:
            logger.error(
                "The database `%s' does not have the protocol `%s' - "
                "choose one of `%s'",
                db_name,
                protocol_filter,
                ", ".join(protocols),
            )

            return (None, None, None)

        protocols = [protocol_filter]

    # Filter the sets
    loaded_sets = []

    for protocol_name in protocols:
        sets = database.set_names(protocol_name)

        if set_filter is not None:
            if set_filter not in sets:
                logger.error(
                    "The database/protocol `%s/%s' does not have the "
                    "set `%s' - choose one of `%s'",
                    db_name,
                    protocol_name,
                    set_filter,
                    ", ".join(sets),
                )
                return (None, None, None)

            sets = [z for z in sets if z == set_filter]

        loaded_sets.extend(
            [
                (protocol_name, set_name, database.set(protocol_name, set_name))
                for set_name in sets
            ]
        )

    return (db_name, database, loaded_sets)


# ----------------------------------------------------------


def start_db_container(
    configuration,
    cmd,
    host,
    db_name,
    protocol_name,
    set_name,
    database,
    db_set,
    excluded_outputs=None,
    uid=None,
    db_root=None,
):
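    """Start a Docker container running the given database command.

    The container runs either the indexing command (``CMD_DB_INDEX``) or the
    output provider (``CMD_VIEW_OUTPUTS``).  For ``CMD_VIEW_OUTPUTS`` a tuple
    ``(container, socket, zmq_context, input_list)`` is returned so the caller
    can read data remotely; otherwise only the container is returned.
    """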

    input_list = inputs.InputList()

    input_group = inputs.InputGroup(set_name, restricted_access=False)
    input_list.add(input_group)

    db_configuration = {"inputs": {}, "channel": set_name}

    if uid is None:
        uid = os.getuid()

    db_configuration["datasets_uid"] = uid

    if db_root is not None:
        db_configuration["datasets_root_path"] = db_root

    for output_name, dataformat_name in db_set["outputs"].items():
        if excluded_outputs is not None and output_name in excluded_outputs:
            continue

        dataset_hash = hashDataset(db_name, protocol_name, set_name)
        db_configuration["inputs"][output_name] = dict(
            database=db_name,
            protocol=protocol_name,
            set=set_name,
            output=output_name,
            channel=set_name,
            hash=dataset_hash,
            path=toPath(dataset_hash, ".db"),
        )

    db_tempdir = utils.temporary_directory()

    with open(os.path.join(db_tempdir, "configuration.json"), "wt") as f:
        simplejson.dump(db_configuration, f, indent=4)

    tmp_prefix = os.path.join(db_tempdir, "prefix")
    if not os.path.exists(tmp_prefix):
        os.makedirs(tmp_prefix)

    database.export(tmp_prefix)

    if db_root is None:
        json_path = os.path.join(tmp_prefix, "databases", db_name + ".json")

        with open(json_path, "r") as f:
            db_data = simplejson.load(f)

        database_path = db_data["root_folder"]
        db_data["root_folder"] = os.path.join("/databases", db_name)

        with open(json_path, "w") as f:
            simplejson.dump(db_data, f, indent=4)

    environment = database.environment
    if environment:
        environment_name = utils.build_env_name(environment)
        try:
            db_envkey = host.dbenv2docker(environment_name)
        except KeyError:
            raise RuntimeError(
                "Environment {} not found for the database '{}' "
                "- available environments are {}".format(
                    environment_name, db_name, ", ".join(host.db_environments.keys())
                )
            )
    else:
        try:
            db_envkey = host.db2docker([db_name])
        except Exception:
            raise RuntimeError(
                "No environment found for the database `%s' "
                "- available environments are %s"
                % (db_name, ", ".join(host.db_environments.keys()))
            )

    logger.info("Indexing using {}".format(db_envkey))

    # Creation of the container
    # Note: we only support one databases image loaded at the same time
    CONTAINER_PREFIX = "/beat/prefix"
    CONTAINER_CACHE = "/beat/cache"

    database_port = random.randint(51000, 60000)  # nosec just getting a free port
    if cmd == CMD_VIEW_OUTPUTS:
        db_cmd = [
            cmd,
            "0.0.0.0:{}".format(database_port),
            CONTAINER_PREFIX,
            CONTAINER_CACHE,
        ]
    else:
        db_cmd = [
            cmd,
            CONTAINER_PREFIX,
            CONTAINER_CACHE,
            db_name,
            protocol_name,
            set_name,
        ]

    databases_container = host.create_container(db_envkey, db_cmd)
    databases_container.uid = uid

    if cmd == CMD_VIEW_OUTPUTS:
        databases_container.add_port(database_port, database_port, host_address=host.ip)
        databases_container.add_volume(db_tempdir, "/beat/prefix")
        databases_container.add_volume(configuration.cache, "/beat/cache")
    else:
        databases_container.add_volume(tmp_prefix, "/beat/prefix")
        databases_container.add_volume(
            configuration.cache, "/beat/cache", read_only=False
        )

    # Specify the volumes to mount inside the container
    if "datasets_root_path" not in db_configuration:
        databases_container.add_volume(
            database_path, os.path.join("/databases", db_name)
        )
    else:
        databases_container.add_volume(
            db_configuration["datasets_root_path"],
            db_configuration["datasets_root_path"],
        )

    # Start the container
    host.start(databases_container)

    if cmd == CMD_VIEW_OUTPUTS:
        # Communicate with container
        zmq_context = zmq.Context()
        db_socket = zmq_context.socket(zmq.PAIR)
        db_address = "tcp://{}:{}".format(host.ip, database_port)
        db_socket.connect(db_address)

        for output_name, dataformat_name in db_set["outputs"].items():
            if excluded_outputs is not None and output_name in excluded_outputs:
                continue

            data_source = RemoteDataSource()
            data_source.setup(
                db_socket, output_name, dataformat_name, configuration.path
            )

            input_ = inputs.Input(
                output_name, database.dataformats[dataformat_name], data_source
            )
            input_group.add(input_)

        return (databases_container, db_socket, zmq_context, input_list)

    return databases_container


# ----------------------------------------------------------


def pull_impl(webapi, prefix, names, force, indentation, format_cache):
    """Copies databases (and required dataformats) from the server.

    Parameters:

      webapi (object): An instance of our WebAPI class, prepared to access the
        BEAT server of interest

      prefix (str): A string representing the root of the path in which the
        user objects are stored

      names (:py:class:`list`): A list of strings, each representing the unique
        relative path of the objects to retrieve or a list of usernames from
        which to retrieve objects. If the list is empty, then we pull all
        available objects of a given type. If no user is set, then pull all
        public objects of a given type.

      force (bool): If set to ``True``, then overwrites local changes with the
        remotely retrieved copies.

      indentation (int): The indentation level, useful if this function is
        called recursively while downloading different object types. This is
        normally set to ``0`` (zero).

      format_cache (dict): A dictionary containing all dataformats already
        downloaded.


    Returns:

      int: Indicating the exit status of the command, to be reported back to
        the calling process. This value should be zero if everything works OK,
        otherwise, different than zero (POSIX compliance).
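
    Example:

      A minimal usage sketch, mirroring the call made by the ``pull`` command
      defined later in this module (the database name is illustrative)::

        with common.make_webapi(configuration) as webapi:
            status = pull_impl(webapi, configuration.path, ["simple/1"], False, 0, {})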

    """

    from .dataformats import pull_impl as dataformats_pull

    status, names = common.pull(
        webapi,
        prefix,
        "database",
        names,
        ["declaration", "code", "description"],
        force,
        indentation,
    )

    # see what dataformats one needs to pull
    dataformats = []
    for name in names:
        obj = Database(prefix, name)
        dataformats.extend(obj.dataformats.keys())

    # download any dataformats the pulled databases depend on
    df_status = dataformats_pull(
        webapi, prefix, dataformats, force, indentation + 2, format_cache
    )

    return status + df_status


# ----------------------------------------------------------


def index_outputs(configuration, names, uid=None, db_root=None, docker=False):
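    """Index all outputs of all protocol/set pairs of the given databases.

    Indexing happens either locally or, when ``docker`` is set, inside the
    matching databases container.  Returns ``0`` on success or the number of
    errors encountered.
    """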

    names = common.make_up_local_list(configuration.path, "database", names)
    retcode = 0

    if docker:
        host = dock.Host(raise_on_errors=False)

    for database_name in names:
        logger.info("Indexing database %s...", database_name)

        (db_name, database, sets) = load_database_sets(configuration, database_name)
        if database is None:
            retcode += 1
            continue

        for protocol_name, set_name, db_set in sets:
            if not docker:
                try:
                    view = database.view(protocol_name, set_name)
                except SyntaxError as error:
                    logger.error("Failed to load the database `%s':", database_name)
                    logger.error("  * Syntax error: %s", error)
                    view = None

                if view is None:
                    retcode += 1
                    continue

                dataset_hash = hashDataset(db_name, protocol_name, set_name)
                try:
                    view.index(
                        os.path.join(configuration.cache, toPath(dataset_hash, ".db"))
                    )
                except RuntimeError as error:
                    logger.error("Failed to load the database `%s':", database_name)
                    logger.error("  * Runtime error %s", error)
                    retcode += 1
                    continue

            else:
                databases_container = start_db_container(
                    configuration,
                    CMD_DB_INDEX,
                    host,
                    db_name,
                    protocol_name,
                    set_name,
                    database,
                    db_set,
                    uid=uid,
                    db_root=db_root,
                )
                status = host.wait(databases_container)
                logs = host.logs(databases_container)
                host.rm(databases_container)

                if status != 0:
                    logger.error("Error occurred: %s", logs)
                    retcode += 1

    return retcode


# ----------------------------------------------------------


def list_index_files(configuration, names):
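    """List the index files produced for the given databases."""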

    names = common.make_up_local_list(configuration.path, "database", names)

    retcode = 0

    for database_name in names:
        logger.info("Listing database %s indexes...", database_name)

        (db_name, database, sets) = load_database_sets(configuration, database_name)
        if database is None:
            retcode += 1
            continue
André Anjos's avatar
André Anjos committed
474

475
        for protocol_name, set_name, db_set in sets:
476 477 478
            dataset_hash = hashDataset(db_name, protocol_name, set_name)
            index_filename = toPath(dataset_hash)
            basename = os.path.join(
                configuration.cache, os.path.splitext(index_filename)[0]
            )
            for g in glob.glob(basename + ".*"):
                logger.info(g)

    return retcode


# ----------------------------------------------------------


def delete_index_files(configuration, names):
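    """Delete the index files produced for the given databases, removing
    cache directories that become empty."""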

    names = common.make_up_local_list(configuration.path, "database", names)

    retcode = 0

    for database_name in names:
        logger.info("Deleting database %s indexes...", database_name)

        (db_name, database, sets) = load_database_sets(configuration, database_name)
        if database is None:
            retcode += 1
            continue

        for protocol_name, set_name, db_set in sets:
            for output_name in db_set["outputs"].keys():
                dataset_hash = hashDataset(db_name, protocol_name, set_name)
                index_filename = toPath(dataset_hash)
                basename = os.path.join(
                    configuration.cache, os.path.splitext(index_filename)[0]
                )

                for g in glob.glob(basename + ".*"):
                    logger.info("removing `%s'...", g)
                    os.unlink(g)

                common.recursive_rmdir_if_empty(
                    os.path.dirname(basename), configuration.cache
                )

    return retcode


# ----------------------------------------------------------


def view_outputs(
    configuration,
    dataset_name,
    excluded_outputs=None,
    uid=None,
    db_root=None,
    docker=False,
):
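    """Display the data of every output of the given database set.

    The set is read either through a local view or, when ``docker`` is set,
    through a databases container.  Returns 0 on success and 1 on error.
    """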
    def data_to_json(data, indent):
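        """Render one data unit as indented JSON, stripping the BEAT list markers."""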
        value = common.stringify(data.as_dict())

        value = (
            simplejson.dumps(value, indent=4, cls=NumpyJSONEncoder)
            .replace('"BEAT_LIST_DELIMITER[', "[")
            .replace(']BEAT_LIST_DELIMITER"', "]")
            .replace('"...",', "...")
            .replace('"BEAT_LIST_SIZE(', "(")
            .replace(')BEAT_LIST_SIZE"', ")")
        )

        return ("\n" + " " * indent).join(value.split("\n"))

    # Load the information about the database set
    (db_name, database, sets) = load_database_sets(configuration, dataset_name)
    if (database is None) or (len(sets) != 1):
        return 1

    (protocol_name, set_name, db_set) = sets[0]

    if excluded_outputs is not None:
        excluded_outputs = [x.strip() for x in excluded_outputs.split(",")]

    # Setup the view so the outputs can be used
    if not docker:
        view = database.view(protocol_name, set_name)

        if view is None:
            return 1

        dataset_hash = hashDataset(db_name, protocol_name, set_name)
        view.setup(
            os.path.join(configuration.cache, toPath(dataset_hash, ".db")), pack=False
        )
        input_group = inputs.InputGroup(set_name, restricted_access=False)

        for output_name, dataformat_name in db_set["outputs"].items():
            if excluded_outputs is not None and output_name in excluded_outputs:
                continue

            input = inputs.Input(
                output_name,
                database.dataformats[dataformat_name],
                view.data_sources[output_name],
            )
            input_group.add(input)

    else:
        host = dock.Host(raise_on_errors=False)

        (databases_container, db_socket, zmq_context, input_list) = start_db_container(
            configuration,
            CMD_VIEW_OUTPUTS,
            host,
            db_name,
            protocol_name,
            set_name,
            database,
            db_set,
            excluded_outputs=excluded_outputs,
            uid=uid,
            db_root=db_root,
        )

        input_group = input_list.group(set_name)

    retvalue = 0

    # Display the data
    try:
        previous_start = -1

        while input_group.hasMoreData():
            input_group.next()

            start = input_group.data_index
            end = input_group.data_index_end

            if start != previous_start:
                print(80 * "-")

                print("FROM %d TO %d" % (start, end))

                whole_inputs = [
                    input_
                    for input_ in input_group
                    if input_.data_index == start and input_.data_index_end == end
                ]

                for input in whole_inputs:
                    label = " - " + str(input.name) + ": "
                    print(label + data_to_json(input.data, len(label)))

                previous_start = start

            selected_inputs = [
                input_
                for input_ in input_group
                if input_.data_index == input_group.first_data_index
                and (input_.data_index != start or input_.data_index_end != end)
            ]

            grouped_inputs = {}
            for input_ in selected_inputs:
                key = (input_.data_index, input_.data_index_end)
                if key not in grouped_inputs:
                    grouped_inputs[key] = []
                grouped_inputs[key].append(input_)

            sorted_keys = sorted(grouped_inputs.keys())

            for key in sorted_keys:
                print()
                print("  FROM %d TO %d" % key)

                for input in grouped_inputs[key]:
                    label = "   - " + str(input.name) + ": "
                    print(label + data_to_json(input.data, len(label)))

    except Exception as e:
        logger.error("Failed to retrieve the next data: %s", e)
        retvalue = 1

    if docker:
        host.kill(databases_container)
        status = host.wait(databases_container)
        logs = host.logs(databases_container)
        host.rm(databases_container)
        if status != 0:
            logger.error("Docker error: %s", logs)

    return retvalue


# ----------------------------------------------------------


class DatabaseCommand(AssetCommand):
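    """Click command configuration for database assets."""
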
    asset_info = AssetInfo(
        asset_type="database",
        diff_fields=["declaration", "code", "description"],
        push_fields=["name", "declaration", "code", "description"],
    )


@click.group(cls=AliasedGroup)
@click.pass_context
def databases(ctx):
    """Database commands"""


CMD_LIST = [
    "list",
    "path",
    "edit",
    "check",
    "status",
    "create",
    "version",
    ("rm", "rm_local"),
    "diff",
    "push",
]

commands.initialise_asset_commands(databases, CMD_LIST, DatabaseCommand)


@databases.command()
@click.argument("db_names", nargs=-1)
@click.option(
    "--force", help="Performs operation regardless of conflicts", is_flag=True
)
@click.pass_context
@raise_on_error
def pull(ctx, db_names, force):
    """Downloads the specified databases from the server.

       $ beat databases pull [<name>]...

    <name>:
        Database name formatted as "<database>/<version>"
    """
    configuration = ctx.meta["config"]
    with common.make_webapi(configuration) as webapi:
        return pull_impl(webapi, configuration.path, db_names, force, 0, {})


@databases.command()
@click.argument("db_names", nargs=-1)
@click.option(
    "--list", help="List index files matching output if they exist", is_flag=True
)
@click.option(
    "--delete",
    help="Delete index files matching output if they "
    "exist (also, recursively deletes empty directories)",
    is_flag=True,
)
@click.option("--checksum", help="Checksums index files", is_flag=True, default=True)
@click.option("--uid", type=click.INT, default=None)
@click.option("--db-root", help="Database root")
@click.option("--docker", is_flag=True)
@click.pass_context
@raise_on_error
def index(ctx, db_names, list, delete, checksum, uid, db_root, docker):
    """Indexes all outputs (of all sets) of a database.

    To index the contents of a database

        $ beat databases index simple/1

    To index the contents of a protocol on a database

        $ beat databases index simple/1/double

    To index the contents of a set in a protocol on a database

        $ beat databases index simple/1/double/double
    """
    configuration = ctx.meta["config"]
    code = 1
    if list:
        code = list_index_files(configuration, db_names)
    elif delete:
        code = delete_index_files(configuration, db_names)
    elif checksum:
        code = index_outputs(
            configuration, db_names, uid=uid, db_root=db_root, docker=docker
        )
    return code


@databases.command()
@click.argument("set_name", nargs=1)
@click.option("--exclude", help="When viewing, excludes this output", default=None)
@click.option("--uid", type=click.INT, default=None)
@click.option("--db-root", help="Database root")
@click.option("--docker", is_flag=True)
@click.pass_context
@raise_on_error
def view(ctx, set_name, exclude, uid, db_root, docker):
    """View the data of the specified dataset.

    To view the contents of a specific set

    $ beat databases view simple/1/protocol/set
    """
    configuration = ctx.meta["config"]
    if exclude is not None:
        return view_outputs(
            configuration, set_name, exclude, uid=uid, db_root=db_root, docker=docker
        )
    return view_outputs(
        configuration, set_name, uid=uid, db_root=db_root, docker=docker
    )