experiments.py 31 KB
Newer Older
André Anjos's avatar
André Anjos committed
1
2
#!/usr/bin/env python
# vim: set fileencoding=utf-8 :
André Anjos's avatar
André Anjos committed
3

Samuel GAIST's avatar
Samuel GAIST committed
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
###################################################################################
#                                                                                 #
# Copyright (c) 2019 Idiap Research Institute, http://www.idiap.ch/               #
# Contact: beat.support@idiap.ch                                                  #
#                                                                                 #
# Redistribution and use in source and binary forms, with or without              #
# modification, are permitted provided that the following conditions are met:     #
#                                                                                 #
# 1. Redistributions of source code must retain the above copyright notice, this  #
# list of conditions and the following disclaimer.                                #
#                                                                                 #
# 2. Redistributions in binary form must reproduce the above copyright notice,    #
# this list of conditions and the following disclaimer in the documentation       #
# and/or other materials provided with the distribution.                          #
#                                                                                 #
# 3. Neither the name of the copyright holder nor the names of its contributors   #
# may be used to endorse or promote products derived from this software without   #
# specific prior written permission.                                              #
#                                                                                 #
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND #
# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED   #
# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE          #
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE    #
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL      #
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR      #
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER      #
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,   #
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE   #
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.            #
#                                                                                 #
###################################################################################
André Anjos's avatar
André Anjos committed
35
36
37
38
39


import os
import logging
import glob
40
import click
André Anjos's avatar
André Anjos committed
41
42
43
44
import oset
import simplejson

from beat.core.experiment import Experiment
45
46
from beat.core.execution import DockerExecutor
from beat.core.execution import LocalExecutor
André Anjos's avatar
André Anjos committed
47
48
from beat.core.utils import NumpyJSONEncoder
from beat.core.data import CachedDataSource, load_data_index
49
from beat.core.dock import Host
50
51
52
from beat.core.hash import toPath
from beat.core.hash import hashDataset

53
from . import common
54
55
from . import commands

56
57
from .plotters import plot_impl as plotters_plot
from .plotters import pull_impl as plotters_pull
58
from .decorators import raise_on_error
59
60
from .click_helper import AliasedGroup

61
62

logger = logging.getLogger(__name__)
André Anjos's avatar
André Anjos committed
63
64


65
def run_experiment(configuration, name, force, use_docker, use_local, quiet):
    """Runs an experiment locally, one scheduled block at a time.

    Parameters:

      configuration (object): The configuration object. This function reads
        its ``path``, ``cache`` and ``database_paths`` attributes.

      name (str): The name of the experiment to load and execute

      force (bool): If set to ``True``, blocks are executed even if their
        outputs already exist. Otherwise such blocks are skipped.

      use_docker (bool): If set to ``True``, blocks are executed through the
        ``DockerExecutor`` and the required environments are looked up on the
        docker host. Otherwise the ``LocalExecutor`` is used.

      use_local (bool): Not read by this function; the local executor is
        used whenever ``use_docker`` is not set.

      quiet (bool): If set to ``True``, results and output paths are not
        reported after a block was processed (``Done`` is logged instead).


    Returns:

      int: ``1`` if the experiment or one of its executors cannot be loaded,
        if a required docker environment is missing or if a block finishes
        with a non-zero status. ``0`` otherwise.

    """

    def reindent(s, n):
        """Re-indents output so it is more visible"""
        margin = n * " "
        return margin + ("\n" + margin).join(s.split("\n"))

    def load_result(executor):
        """Loads the result of an experiment, in a single go"""

        f = CachedDataSource()
        success = f.setup(
            os.path.join(executor.cache, executor.data["result"]["path"] + ".data"),
            executor.prefix,
        )

        if not success:
            raise RuntimeError("Failed to setup cached data source")

        # only the first entry of the cached source is read
        data, _, _ = f[0]
        return data

    def print_results(executor):
        """Logs the analyzer result of ``executor`` as indented JSON"""
        data = load_result(executor)
        r = reindent(
            simplejson.dumps(data.as_dict(), indent=2, cls=NumpyJSONEncoder), 2
        )
        logger.info("  Results:\n%s", r)

    def simplify_time(s):
        """Re-writes the time so it is easier to understand it"""

        minute = 60.0
        hour = 60 * minute
        day = 24 * hour

        if s <= minute:
            return "%.2f s" % s
        elif s <= hour:
            minutes = s // minute
            seconds = s - (minute * minutes)
            return "%d m %.2f s" % (minutes, seconds)
        elif s <= day:
            hours = s // hour
            minutes = (s - (hour * hours)) // minute
            seconds = s - (hour * hours + minute * minutes)
            return "%d h %d m %.2f s" % (hours, minutes, seconds)
        else:
            days = s // day
            hours = (s - (day * days)) // hour
            minutes = (s - (day * days + hour * hours)) // minute
            seconds = s - (day * days + hour * hours + minute * minutes)
            return "%d days %d h %d m %.2f s" % (days, hours, minutes, seconds)

    def simplify_size(s):
        """Re-writes the size so it is easier to understand it"""

        kb = 1024.0
        mb = kb * kb
        gb = kb * mb
        tb = kb * gb

        if s <= kb:
            return "%d bytes" % s
        elif s <= mb:
            return "%.2f kilobytes" % (s / kb)
        elif s <= gb:
            return "%.2f megabytes" % (s / mb)
        elif s <= tb:
            return "%.2f gigabytes" % (s / gb)
        return "%.2f terabytes" % (s / tb)

    def index_experiment_databases(cache_path, experiment):
        """Builds the index of each dataset used by the experiment if its
        index file is not yet present under ``cache_path``"""
        for infos in experiment.datasets.values():
            view = infos["database"].view(infos["protocol"], infos["set"])
            filename = toPath(
                hashDataset(infos["database"].name, infos["protocol"], infos["set"]),
                suffix=".db",
            )
            database_index_path = os.path.join(cache_path, filename)
            if not os.path.exists(database_index_path):
                logger.info(
                    "Index for database %s not found, building it",
                    infos["database"].name,
                )
                view.index(database_index_path)

    # these caches are shared between the experiment and all its executors
    dataformat_cache = {}
    database_cache = {}
    algorithm_cache = {}
    library_cache = {}

    experiment = Experiment(
        configuration.path,
        name,
        dataformat_cache,
        database_cache,
        algorithm_cache,
        library_cache,
    )

    if not experiment.valid:
        logger.error("Failed to load the experiment `%s':", name)
        for e in experiment.errors:
            logger.error("  * %s", e)
        return 1

    if not os.path.exists(configuration.cache):
        os.makedirs(configuration.cache)
        logger.info("Created cache path `%s'", configuration.cache)

    index_experiment_databases(configuration.cache, experiment)

    scheduled = experiment.setup()

    if use_docker:
        # load existing environments
        host = Host(raise_on_errors=False)

    # can we execute it?
    for key, value in scheduled.items():

        if use_docker:
            # the environment requested by the block must be known to the host
            env = value["configuration"]["environment"]
            search_key = "%s (%s)" % (env["name"], env["version"])
            if search_key not in host:
                logger.error(
                    "Cannot execute block `%s' on environment `%s': "
                    "environment was not found - please install it",
                    key,
                    search_key,
                )
                return 1

            executor = DockerExecutor(
                host,
                configuration.path,
                value["configuration"],
                configuration.cache,
                dataformat_cache,
                database_cache,
                algorithm_cache,
                library_cache,
            )
        else:
            executor = LocalExecutor(
                configuration.path,
                value["configuration"],
                configuration.cache,
                dataformat_cache,
                database_cache,
                algorithm_cache,
                library_cache,
                configuration.database_paths,
            )

        if not executor.valid:
            logger.error("Failed to load the execution information for `%s':", key)
            for e in executor.errors:
                logger.error("  * %s", e)
            return 1

        if executor.outputs_exist and not force:
            logger.info(
                "Skipping execution of `%s' for block `%s' - outputs exist",
                executor.algorithm.name,
                key,
            )
            if executor.analysis and not quiet:
                logger.extra("  Outputs produced:")
                print_results(executor)
            continue

        logger.info("Running `%s' for block `%s'", executor.algorithm.name, key)
        # no specific executable is ever selected here, so this message is
        # always the one reported
        logger.extra("  -> using fallback (default) environment")

        with executor:
            result = executor.process()

        if result["status"] != 0:
            logger.error("Block did not execute properly - outputs were reset")
            logger.error("  Standard output:\n%s", reindent(result["stdout"], 4))
            logger.error("  Standard error:\n%s", reindent(result["stderr"], 4))
            logger.error(
                "  Captured user error:\n%s", reindent(result["user_error"], 4)
            )
            logger.error(
                "  Captured system error:\n%s", reindent(result["system_error"], 4)
            )
            logger.extra("  Environment: %s", "default environment")
            return 1
        elif use_docker:
            stats = result["statistics"]
            cpu_stats = stats["cpu"]
            data_stats = stats["data"]

            cpu_total = cpu_stats["total"]
            # Likely means that GPU was used
            # (the replacement value also avoids dividing by zero below)
            if not cpu_total:
                cpu_total = 1.0

            logger.extra(
                "  CPU time (user, system, total, percent): %s, %s, %s, %d%%",
                simplify_time(cpu_stats["user"]),
                simplify_time(cpu_stats["system"]),
                simplify_time(cpu_total),
                100.0 * (cpu_stats["user"] + cpu_stats["system"]) / cpu_total,
            )
            logger.extra("  Memory usage: %s", simplify_size(stats["memory"]["rss"]))
            logger.extra(
                "  Cached input read: %s, %s",
                simplify_time(data_stats["time"]["read"]),
                simplify_size(data_stats["volume"]["read"]),
            )
            logger.extra(
                "  Cached output write: %s, %s",
                simplify_time(data_stats["time"]["write"]),
                simplify_size(data_stats["volume"]["write"]),
            )
            logger.extra(
                "  Communication time: %s (%d%%)",
                simplify_time(data_stats["network"]["wait_time"]),
                100.0 * data_stats["network"]["wait_time"] / cpu_total,
            )
        else:
            logger.extra("  Environment: %s", "local environment")

        if not quiet:
            if executor.analysis:
                print_results(executor)

            logger.extra("  Outputs produced:")
            if executor.analysis:
                logger.extra("    * %s", executor.data["result"]["path"])
            else:
                for details in executor.data["outputs"].values():
                    logger.extra("    * %s", details["path"])
        else:
            logger.info("Done")

    return 0
André Anjos's avatar
André Anjos committed
314
315


316
def caches_impl(configuration, name, ls, delete, checksum):
    """List all cache files involved in this experiment

    Parameters:

      configuration (object): The configuration object. This function reads
        its ``path`` and ``cache`` attributes.

      name (str): The name of the experiment to inspect

      ls (bool): If set to ``True``, logs every existing file matching an
        output path (``<path>.*``)

      delete (bool): If set to ``True``, removes every existing file matching
        an output path and then calls ``common.recursive_rmdir_if_empty`` on
        the directory that contained them

      checksum (bool): If set to ``True``, tries to load the data index of
        each output and logs whether this worked


    Returns:

      int: ``1`` if the experiment cannot be loaded, ``0`` otherwise

    """

    dataformat_cache = {}
    database_cache = {}
    algorithm_cache = {}
    library_cache = {}

    experiment = Experiment(
        configuration.path,
        name,
        dataformat_cache,
        database_cache,
        algorithm_cache,
        library_cache,
    )

    if not experiment.valid:
        logger.error("Failed to load the experiment `%s':", name)
        for e in experiment.errors:
            logger.error("  * %s", e)
        return 1

    scheduled = experiment.setup()

    block_list = []
    for key, value in scheduled.items():
        block_configuration = value["configuration"]
        block = {
            "name": key,
            "algorithm": block_configuration["algorithm"],
            "is_analyser": False,
            "paths": [],
        }

        if "outputs" in block_configuration:  # normal block
            for data in block_configuration["outputs"].values():
                block["paths"].append(data["path"])
        else:  # analyzer
            block["is_analyser"] = True
            block["paths"].append(block_configuration["result"]["path"])

        block_list.append(block)

    for block in block_list:
        block_type = "analyzer" if block["is_analyser"] else "algorithm"
        logger.info("block: `%s'", block["name"])
        logger.info("  %s: `%s'", block_type, block["algorithm"])

        for path in block["paths"]:
            # prefix cache path
            path = os.path.join(configuration.cache, path)
            logger.info("  output: `%s'", path)

            if ls:
                for filename in glob.glob(path + ".*"):
                    logger.info("    %s", filename)

            if delete:
                for filename in glob.glob(path + ".*"):
                    logger.info("removing `%s'...", filename)
                    os.unlink(filename)

                common.recursive_rmdir_if_empty(
                    os.path.dirname(path), configuration.cache
                )

            if checksum:
                # report success only if the index could actually be loaded
                if not load_data_index(configuration.cache, path + ".data"):
                    logger.error("Failed to load data index for {}".format(path))
                else:
                    logger.info("index for `%s' can be loaded and checksums", path)

    return 0
André Anjos's avatar
André Anjos committed
388
389


390
def pull_impl(webapi, prefix, names, force, indentation, format_cache):
    """Copies experiments (and required toolchains/algorithms) from the server.

    Parameters:

      webapi (object): An instance of our WebAPI class, prepared to access the
        BEAT server of interest

      prefix (str): A string representing the root of the path in which the
        user objects are stored

      names (:py:class:`list`): A list of strings, each representing the unique
        relative path of the objects to retrieve or a list of usernames from
        which to retrieve objects. If the list is empty, then we pull all
        available objects of a given type. If no user is set, then pull all
        public objects of a given type.

      force (bool): If set to ``True``, then overwrites local changes with the
        remotely retrieved copies.

      indentation (int): The indentation level, useful if this function is
        called recursively while downloading different object types. This is
        normally set to ``0`` (zero).

      format_cache (dict): A dictionary handed over to the database and
        algorithm pull functions. If ``None``, a new dictionary is created.


    Returns:

      int: Indicating the exit status of the command, to be reported back to
        the calling process. This value should be zero if everything works OK,
        otherwise, different than zero (POSIX compliance).

    """

    # imported here rather than at module level (presumably to avoid circular
    # imports between the command modules - TODO confirm)
    from .algorithms import pull_impl as algorithms_pull
    from .databases import pull_impl as databases_pull

    if indentation == 0:
        indentation = 4

    status, names = common.pull(
        webapi,
        prefix,
        "experiment",
        names,
        ["declaration", "description"],
        force,
        indentation,
    )

    if status != 0:
        logger.error("could not find any matching experiments - widen your search")
        return status

    # see what databases, toolchains and algorithms one needs to pull
    databases = oset.oset()
    toolchains = oset.oset()
    algorithms = oset.oset()
    for name in names:
        try:
            obj = Experiment(prefix, name)
            if obj.toolchain:
                toolchains.add(obj.toolchain.name)
            databases |= obj.databases.keys()
            algorithms |= obj.algorithms.keys()

        except Exception as e:
            # best effort: an experiment that fails to load is reported and
            # the remaining ones are still processed
            logger.error("loading `%s': %s...", name, str(e))

    # downloads any objects on which we depend; the cache provided by the
    # caller is re-used instead of being thrown away
    if format_cache is None:
        format_cache = {}
    library_cache = {}

    tc_status, _ = common.pull(
        webapi,
        prefix,
        "toolchain",
        toolchains,
        ["declaration", "description"],
        force,
        indentation,
    )
    db_status = databases_pull(
        webapi, prefix, databases, force, indentation, format_cache
    )
    algo_status = algorithms_pull(
        webapi, prefix, algorithms, force, indentation, format_cache, library_cache
    )

    return status + tc_status + db_status + algo_status
André Anjos's avatar
André Anjos committed
478
479


480
481
482
483
484
485
486
487
488
489
490
491
def _ensure_plot_folder(configuration, outputfolder):
    """Returns ``outputfolder`` resolved against ``configuration.path``,
    creating the directory if it does not exist yet"""

    folder = os.path.join(configuration.path, outputfolder)
    # the existence test must look at the same path that gets created
    if not os.path.isdir(folder):
        os.makedirs(folder)
    return folder


def _pull_and_plot(
    webapi, configuration, datatype, show, force, indentation, format_cache,
    data, output_name,
):
    """Pulls the plotter for ``datatype`` and plots ``data`` into
    ``output_name``, returning the sum of both exit statuses"""

    pl_status = plotters_pull(
        webapi, configuration.path, [datatype], force, indentation, {}
    )
    plot_status = plotters_plot(
        webapi,
        configuration.path,
        [datatype],
        show,
        False,
        False,
        data,
        output_name,
        None,
        indentation,
        format_cache,
    )
    return pl_status + plot_status


def plot_impl(
    webapi,
    configuration,
    prefix,
    names,
    remote_results,
    show,
    force,
    indentation,
    format_cache,
    outputfolder=None,
):
    """Plots experiments from the server.

    Parameters:

      webapi (object): An instance of our WebAPI class, prepared to access the
        BEAT server of interest

      configuration (object): An instance of the configuration, to access the
        BEAT server and current configuration for information

      prefix (str): A string representing the root of the path in which the
        user objects are stored

      names (:py:class:`list`): A list of strings, each representing the unique
        relative path of the objects to retrieve or a list of usernames from
        which to retrieve objects. If the list is empty, then we pull all
        available objects of a given type. If no user is set, then pull all
        public objects of a given type.

      remote_results (bool): If set to ``True``, then fetch results data
        for the experiments from the server.

      show (bool): Forwarded as is to the plotters' plot function

      force (bool): If set to ``True``, then overwrites local changes with the
        remotely retrieved copies.

      indentation (int): The indentation level, useful if this function is
        called recursively while downloading different object types. This is
        normally set to ``0`` (zero).

      format_cache (dict): Forwarded as is to the experiment pull function and
        to the plotters' plot function

      outputfolder (str): A string representing the path in which the
        experiments plot will be stored. It is resolved against
        ``configuration.path`` and created if missing.


    Returns:

      int: Indicating the exit status of the command, to be reported back to
        the calling process. This value should be zero if everything works OK,
        otherwise, different than zero (POSIX compliance).

    """

    status = 0
    RESULTS_SIMPLE_TYPE_NAMES = ("int32", "float32", "bool", "string")

    if indentation == 0:
        indentation = 4

    if remote_results:
        if outputfolder is None:
            output_folder = configuration.path
        else:
            output_folder = _ensure_plot_folder(configuration, outputfolder)

    for name in names:
        if not remote_results:
            if outputfolder is None:
                output_folder = os.path.join(
                    configuration.path,
                    common.TYPE_PLURAL["experiment"],
                    name.rsplit("/", 1)[0],
                )
            else:
                output_folder = _ensure_plot_folder(configuration, outputfolder)

        check_plottable = False
        if not os.path.exists(configuration.cache) or remote_results:
            # no local cache to read from (or remote results were explicitly
            # requested): use the results stored on the server
            experiment = simplejson.loads(
                simplejson.dumps(
                    common.fetch_object(webapi, "experiment", name, ["results"])
                )
            )
            results = experiment["results"]["analysis"]

            for key, value in results.items():
                # remove non plottable results
                if value["type"] not in RESULTS_SIMPLE_TYPE_NAMES:
                    output_name = name.rsplit("/", 1)[1] + "_" + key + ".png"
                    output_name = os.path.join(output_folder, output_name)
                    status += _pull_and_plot(
                        webapi,
                        configuration,
                        value["type"],
                        show,
                        force,
                        indentation,
                        format_cache,
                        value["value"],
                        output_name,
                    )
                    check_plottable = True
        else:
            # make sure experiment exists locally or pull it
            pull_impl(
                webapi, configuration.path, [name], force, indentation, format_cache
            )

            # get information from cache
            dataformat_cache = {}
            database_cache = {}
            algorithm_cache = {}
            library_cache = {}

            experiment = Experiment(
                configuration.path,
                name,
                dataformat_cache,
                database_cache,
                algorithm_cache,
                library_cache,
            )

            scheduled = experiment.setup()
            for key, value in scheduled.items():
                executor = LocalExecutor(
                    configuration.path,
                    value["configuration"],
                    configuration.cache,
                    dataformat_cache,
                    database_cache,
                    algorithm_cache,
                    library_cache,
                    configuration.database_paths,
                )

                # only blocks holding a ``result`` entry are considered
                if "result" not in executor.data:
                    continue

                f = CachedDataSource()
                success = f.setup(
                    os.path.join(
                        executor.cache, executor.data["result"]["path"] + ".data"
                    ),
                    executor.prefix,
                )
                if not success:
                    raise RuntimeError("Failed to setup cached data source")

                data, _, _ = f[0]
                data_dict = data.as_dict()

                for the_data in data_dict:
                    attr = getattr(data, the_data)
                    type_name = attr.__class__.__name__
                    if not type_name.startswith("plot"):
                        continue

                    datatype = type_name.replace("_", "/")
                    # remove non plottable results
                    if datatype not in RESULTS_SIMPLE_TYPE_NAMES:
                        output_name = (
                            name.rsplit("/", 1)[1] + "_" + the_data + ".png"
                        )
                        output_name = os.path.join(output_folder, output_name)
                        status += _pull_and_plot(
                            webapi,
                            configuration,
                            datatype,
                            show,
                            force,
                            indentation,
                            format_cache,
                            data_dict[the_data],
                            output_name,
                        )
                        check_plottable = True

        if not check_plottable:
            print("Experiments results are not plottable")

    return status


685
@click.group(cls=AliasedGroup)
@click.pass_context
def experiments(ctx):
    """experiments commands"""

    # Records the asset type handled by this group in the context metadata
    # (presumably read back by the generic commands - TODO confirm)
    ctx.meta.update(asset_type="experiment")


# Registers the generic ``list`` command under this group
experiments.command(name="list")(commands.command("list"))
694
695
696


@experiments.command()
@click.argument("name", nargs=1)
@click.option(
    "--force", is_flag=True, help="Performs operation regardless of conflicts"
)
@click.option(
    "--docker",
    is_flag=True,
    help="Uses the docker executor to execute the experiment using docker containers",
)
@click.option(
    "--local",
    is_flag=True,
    default=True,
    help="Uses the local executor to execute the experiment on the local machine (default)",
)
@click.option("--quiet", is_flag=True, help="Be less verbose")
@click.pass_context
@raise_on_error
def run(ctx, name, force, docker, local, quiet):
    """ Runs an experiment locally"""
    # The command line flags map one to one onto ``run_experiment`` arguments
    return run_experiment(
        configuration=ctx.meta.get("config"),
        name=name,
        force=force,
        use_docker=docker,
        use_local=local,
        quiet=quiet,
    )
721
722
723


@experiments.command()
@click.argument("name", nargs=1)
@click.option(
    "--list", is_flag=True, help="List cache files matching output if they exist"
)
@click.option(
    "--delete",
    is_flag=True,
    help="Delete cache files matching output if they exist (also, recursively deletes empty directories)",
)
@click.option("--checksum", is_flag=True, help="Checksums indexes for cache files")
@click.pass_context
@raise_on_error
def caches(ctx, name, list, delete, checksum):
    """Lists all cache files used by this experiment"""
    # ``list`` is named after the ``--list`` option and becomes the ``ls``
    # argument of the implementation
    return caches_impl(
        configuration=ctx.meta.get("config"),
        name=name,
        ls=list,
        delete=delete,
        checksum=checksum,
    )
741
742
743


@experiments.command()
@click.argument("names", nargs=-1)
@click.pass_context
@raise_on_error
def path(ctx, names):
    """Displays local path of experiment files

  Example:
    $ beat experiments path xxx
  """
    # Paths are resolved relative to the configured prefix
    prefix = ctx.meta["config"].path
    return common.display_local_path(prefix, "experiment", names)
754
755
756


@experiments.command()
@click.argument("name", nargs=1)
@click.pass_context
@raise_on_error
def edit(ctx, name):
    """Edit local experiment file

  Example:
    $ beat experiments edit xxx
  """
    # Both the prefix and the editor come from the active configuration
    configuration = ctx.meta["config"]
    prefix = configuration.path
    editor = configuration.editor
    return common.edit_local_file(prefix, editor, "experiment", name)
769
770
771


@experiments.command()
@click.argument("names", nargs=-1)
@click.pass_context
@raise_on_error
def check(ctx, names):
    """Checks a local experiment for validity.

    $ beat experiments check xxx
    """
    # ``names`` is a tuple (``nargs=-1``) and may be empty; it is forwarded
    # unchanged to the shared helper in ``common``.
    config = ctx.meta.get("config")
    return common.check(config.path, "experiment", names)


@experiments.command()
@click.argument("names", nargs=-1)
@click.option(
    "--force", help="Performs operation regardless of conflicts", is_flag=True
)
@click.pass_context
@raise_on_error
def pull(ctx, names, force):
    """Downloads the specified experiments from the server.

       $ beat experiments pull xxx.
    """
    config = ctx.meta.get("config")

    # The web API handle is used as a context manager so it is released when
    # the block exits, including on error.
    with common.make_webapi(config) as webapi:
        # The trailing ``0`` and ``{}`` are fixed starting values expected by
        # ``pull_impl``; their meaning is defined there, not in this wrapper.
        return pull_impl(webapi, config.path, names, force, 0, {})


@experiments.command()
@click.argument("names", nargs=-1)
@click.option(
    "--force", help="Performs operation regardless of conflicts", is_flag=True
)
@click.option(
    "--dry-run",
    help="Doesn't really perform the task, just " "comments what would do",
    is_flag=True,
)
@click.pass_context
@raise_on_error
def push(ctx, names, force, dry_run):
    """Uploads experiments to the server.

    Example:
      $ beat experiments push --dry-run yyy
    """
    config = ctx.meta.get("config")

    # The web API handle is used as a context manager so it is released when
    # the block exits, including on error.
    with common.make_webapi(config) as webapi:
        # Arguments are positional and their order matters. The field list,
        # the empty dict and the trailing ``0`` are fixed values whose meaning
        # is defined by ``common.push``, not by this wrapper.
        return common.push(
            webapi,
            config.path,
            "experiment",
            names,
            ["name", "declaration", "toolchain", "description"],
            {},
            force,
            dry_run,
            0,
        )


@experiments.command()
@click.argument("name", nargs=1)
@click.pass_context
@raise_on_error
def diff(ctx, name):
    """Shows changes between the local experiment and the remote version.

    Example:
      $ beat experiments diff xxx
    """
    config = ctx.meta.get("config")

    # The web API handle is used as a context manager so it is released when
    # the block exits, including on error.
    with common.make_webapi(config) as webapi:
        # Only the "declaration" and "description" fields are handed to the
        # shared diff helper for comparison.
        return common.diff(
            webapi, config.path, "experiment", name, ["declaration", "description"]
        )


@experiments.command()
@click.pass_context
@raise_on_error
def status(ctx):
    """Shows (editing) status for all available experiments.

    Example:
      $ beat experiments status
    """
    config = ctx.meta.get("config")

    # The web API handle is used as a context manager so it is released when
    # the block exits, including on error.
    with common.make_webapi(config) as webapi:
        # ``common.status`` returns an indexable result; only its first
        # element is returned as this command's result.
        return common.status(webapi, config.path, "experiment")[0]


@experiments.command()
@click.argument("src", nargs=1)
@click.argument("dst", nargs=1)
@click.pass_context
@raise_on_error
def fork(ctx, src, dst):
    """Forks a local experiment.

    $ beat experiments fork xxx yyy
    """
    # ``src`` and ``dst`` are forwarded in that order to the shared helper in
    # ``common``, together with the configured path.
    config = ctx.meta.get("config")
    return common.fork(config.path, "experiment", src, dst)


@experiments.command()
@click.argument("names", nargs=-1)
@click.option(
    "--remote", help="Only acts on the remote copy of the experiment", is_flag=True
)
@click.pass_context
@raise_on_error
def rm(ctx, names, remote):
    """Deletes a local experiment (unless --remote is specified).

    $ beat experiments rm xxx
    """
    config = ctx.meta.get("config")
    if remote:
        # Remote deletion needs a web API handle, which is used as a context
        # manager so it is released when the block exits; the local path is
        # not passed in this branch.
        with common.make_webapi(config) as webapi:
            return common.delete_remote(webapi, "experiment", names)
    else:
        # Local deletion works from the configured path and opens no web API
        # handle.
        return common.delete_local(config.path, "experiment", names)


@experiments.command()
@click.argument("names", nargs=-1)
@click.option(
    "--path",
    help="Use path to write files to disk (instead of the " "current directory)",
    type=click.Path(),
)
@click.pass_context
@raise_on_error
def draw(ctx, names, path):
    """Creates a visual representation of the experiment."""
    # ``path`` is ``None`` when ``--path`` is not given (the option declares
    # no default). The trailing empty list is a fixed argument whose meaning
    # is defined by ``common.dot_diagram``.
    config = ctx.meta.get("config")
    return common.dot_diagram(config.path, "experiment", names, path, [])


@experiments.command()
@click.argument("names", nargs=-1)
@click.option(
    "--force", help="Performs operation regardless of conflicts", is_flag=True
)
@click.option(
    "--remote", help="Only acts on the remote copy of the experiment", is_flag=True
)
@click.option("--show", help="Show...", is_flag=True)
@click.option("--output-folder", help="<folder>", type=click.Path(exists=True))
@click.pass_context
@raise_on_error
def plot(ctx, names, force, remote, show, output_folder):
    """Plots output images of the experiment."""
    config = ctx.meta.get("config")

    # The web API handle is used as a context manager so it is released when
    # the block exits, including on error.
    with common.make_webapi(config) as webapi:
        # Arguments are positional. ``plot_impl`` receives the whole
        # configuration object (not just its path), and the order here is
        # (remote, show, force), which differs from this function's signature.
        # The ``0`` and ``{}`` are fixed values whose meaning is defined by
        # ``plot_impl``.
        return plot_impl(
            webapi,
            config,
            "experiment",
            names,
            remote,
            show,
            force,
            0,
            {},
            output_folder,
        )