experiments.py 37.8 KB
Newer Older
André Anjos's avatar
André Anjos committed
1
2
#!/usr/bin/env python
# vim: set fileencoding=utf-8 :
André Anjos's avatar
André Anjos committed
3

Samuel GAIST's avatar
Samuel GAIST committed
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
###################################################################################
#                                                                                 #
# Copyright (c) 2019 Idiap Research Institute, http://www.idiap.ch/               #
# Contact: beat.support@idiap.ch                                                  #
#                                                                                 #
# Redistribution and use in source and binary forms, with or without              #
# modification, are permitted provided that the following conditions are met:     #
#                                                                                 #
# 1. Redistributions of source code must retain the above copyright notice, this  #
# list of conditions and the following disclaimer.                                #
#                                                                                 #
# 2. Redistributions in binary form must reproduce the above copyright notice,    #
# this list of conditions and the following disclaimer in the documentation       #
# and/or other materials provided with the distribution.                          #
#                                                                                 #
# 3. Neither the name of the copyright holder nor the names of its contributors   #
# may be used to endorse or promote products derived from this software without   #
# specific prior written permission.                                              #
#                                                                                 #
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND #
# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED   #
# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE          #
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE    #
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL      #
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR      #
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER      #
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,   #
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE   #
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.            #
#                                                                                 #
###################################################################################
André Anjos's avatar
André Anjos committed
35
36
37
38
39


import os
import logging
import glob
40
import click
André Anjos's avatar
André Anjos committed
41
import simplejson
42
import six
43
44
45
46
47
48
49
import curses
import textwrap
import threading
import queue
import signal

from datetime import timedelta
André Anjos's avatar
André Anjos committed
50
51

from beat.core.experiment import Experiment
52
53
from beat.core.execution import DockerExecutor
from beat.core.execution import LocalExecutor
André Anjos's avatar
André Anjos committed
54
55
from beat.core.utils import NumpyJSONEncoder
from beat.core.data import CachedDataSource, load_data_index
56
from beat.core.dock import Host
57
58
59
from beat.core.hash import toPath
from beat.core.hash import hashDataset

60
from . import common
61
62
from . import commands

63
64
from .plotters import plot_impl as plotters_plot
from .plotters import pull_impl as plotters_pull
65
from .decorators import raise_on_error
66
67
from .click_helper import AliasedGroup

68
69

# Module-level logger used by the command implementations in this file
logger = logging.getLogger(__name__)
André Anjos's avatar
André Anjos committed
70
71


72
def run_experiment(configuration, name, force, use_docker, use_local, quiet):
    """Run experiments locally

    Parameters:

      configuration (object): The configuration in use; this function reads
        its ``path``, ``cache`` and ``database_paths`` attributes

      name (str): The name of the experiment to load and execute

      force (bool): If set to ``True``, blocks are executed even if their
        outputs already exist

      use_docker (bool): If set to ``True``, blocks are executed with the
        ``DockerExecutor``, otherwise with the ``LocalExecutor``

      use_local (bool): Not read by this function; the local executor is
        selected whenever ``use_docker`` is ``False``

      quiet (bool): If set to ``True``, results and output paths are not
        reported after a block is executed


    Returns:

      int: ``0`` if every scheduled block was executed (or skipped because
        its outputs exist), ``1`` as soon as loading or executing fails

    """

    def load_result(executor):
        """Loads the result of an experiment, in a single go"""

        f = CachedDataSource()
        success = f.setup(
            os.path.join(executor.cache, executor.data["result"]["path"] + ".data"),
            executor.prefix,
        )

        if not success:
            raise RuntimeError("Failed to setup cached data source")

        # only the first entry is read; its start/end values are not needed
        data, _, _ = f[0]
        return data

    def print_results(executor):
        """Logs the analysis results of a block as indented JSON"""

        data = load_result(executor)
        r = reindent(
            simplejson.dumps(data.as_dict(), indent=2, cls=NumpyJSONEncoder), 2
        )
        logger.info("  Results:\n%s", r)

    def reindent(s, n):
        """Re-indents output so it is more visible"""
        margin = n * " "
        return margin + ("\n" + margin).join(s.split("\n"))

    def simplify_time(s):
        """Re-writes the time so it is easier to understand it"""

        minute = 60.0
        hour = 60 * minute
        day = 24 * hour

        if s <= minute:
            return "%.2f s" % s
        elif s <= hour:
            minutes = s // minute
            seconds = s - (minute * minutes)
            return "%d m %.2f s" % (minutes, seconds)
        elif s <= day:
            hours = s // hour
            minutes = (s - (hour * hours)) // minute
            seconds = s - (hour * hours + minute * minutes)
            return "%d h %d m %.2f s" % (hours, minutes, seconds)
        else:
            days = s // day
            hours = (s - (day * days)) // hour
            minutes = (s - (day * days + hour * hours)) // minute
            seconds = s - (day * days + hour * hours + minute * minutes)
            return "%d days %d h %d m %.2f s" % (days, hours, minutes, seconds)

    def simplify_size(s):
        """Re-writes the size so it is easier to understand it"""

        kb = 1024.0
        mb = kb * kb
        gb = kb * mb
        tb = kb * gb

        if s <= kb:
            return "%d bytes" % s
        elif s <= mb:
            return "%.2f kilobytes" % (s / kb)
        elif s <= gb:
            return "%.2f megabytes" % (s / mb)
        elif s <= tb:
            return "%.2f gigabytes" % (s / gb)
        return "%.2f terabytes" % (s / tb)

    def index_experiment_databases(cache_path, experiment):
        """Builds the index file of every dataset that has none in the cache"""

        for infos in experiment.datasets.values():
            view = infos["database"].view(infos["protocol"], infos["set"])
            filename = toPath(
                hashDataset(infos["database"].name, infos["protocol"], infos["set"]),
                suffix=".db",
            )
            database_index_path = os.path.join(cache_path, filename)
            if not os.path.exists(database_index_path):
                logger.info(
                    "Index for database %s not found, building it",
                    infos["database"].name,
                )
                view.index(database_index_path)

    dataformat_cache = {}
    database_cache = {}
    algorithm_cache = {}
    library_cache = {}

    experiment = Experiment(
        configuration.path,
        name,
        dataformat_cache,
        database_cache,
        algorithm_cache,
        library_cache,
    )

    if not experiment.valid:
        logger.error("Failed to load the experiment `%s':", name)
        for e in experiment.errors:
            logger.error("  * %s", e)
        return 1

    if not os.path.exists(configuration.cache):
        os.makedirs(configuration.cache)
        logger.info("Created cache path `%s'", configuration.cache)

    index_experiment_databases(configuration.cache, experiment)

    scheduled = experiment.setup()

    if use_docker:
        # load existing environments
        host = Host(raise_on_errors=False)

    # can we execute it?
    for key, value in scheduled.items():

        # checks and sets-up executable
        executable = None  # use the default

        if use_docker:
            # the block can only run if its environment is known to the host
            env = value["configuration"]["environment"]
            search_key = "%s (%s)" % (env["name"], env["version"])
            if search_key not in host:
                logger.error(
                    "Cannot execute block `%s' on environment `%s': "
                    "environment was not found' - please install it",
                    key,
                    search_key,
                )
                return 1

            executor = DockerExecutor(
                host,
                configuration.path,
                value["configuration"],
                configuration.cache,
                dataformat_cache,
                database_cache,
                algorithm_cache,
                library_cache,
            )
        else:
            executor = LocalExecutor(
                configuration.path,
                value["configuration"],
                configuration.cache,
                dataformat_cache,
                database_cache,
                algorithm_cache,
                library_cache,
                configuration.database_paths,
            )

        if not executor.valid:
            logger.error("Failed to load the execution information for `%s':", key)
            for e in executor.errors:
                logger.error("  * %s", e)
            return 1

        if executor.outputs_exist and not force:
            logger.info(
                "Skipping execution of `%s' for block `%s' - outputs exist",
                executor.algorithm.name,
                key,
            )
            if executor.analysis and not quiet:
                logger.extra("  Outputs produced:")
                print_results(executor)
            continue

        logger.info("Running `%s' for block `%s'", executor.algorithm.name, key)
        if executable is not None:
            logger.extra("  -> using executable at `%s'", executable)
        else:
            logger.extra("  -> using fallback (default) environment")

        with executor:
            result = executor.process()

        if result["status"] != 0:
            logger.error("Block did not execute properly - outputs were reset")
            logger.error("  Standard output:\n%s", reindent(result["stdout"], 4))
            logger.error("  Standard error:\n%s", reindent(result["stderr"], 4))
            logger.error(
                "  Captured user error:\n%s", reindent(result["user_error"], 4)
            )
            logger.error(
                "  Captured system error:\n%s", reindent(result["system_error"], 4)
            )
            logger.extra("  Environment: %s", "default environment")
            return 1
        elif use_docker:
            stats = result["statistics"]
            cpu_stats = stats["cpu"]
            data_stats = stats["data"]

            cpu_total = cpu_stats["total"]
            # Likely means that GPU was used
            if not cpu_total:
                cpu_total = 1.0

            logger.extra(
                "  CPU time (user, system, total, percent): %s, %s, %s, %d%%",
                simplify_time(cpu_stats["user"]),
                simplify_time(cpu_stats["system"]),
                simplify_time(cpu_total),
                100.0 * (cpu_stats["user"] + cpu_stats["system"]) / cpu_total,
            )
            logger.extra("  Memory usage: %s", simplify_size(stats["memory"]["rss"]))
            logger.extra(
                "  Cached input read: %s, %s",
                simplify_time(data_stats["time"]["read"]),
                simplify_size(data_stats["volume"]["read"]),
            )
            logger.extra(
                "  Cached output write: %s, %s",
                simplify_time(data_stats["time"]["write"]),
                simplify_size(data_stats["volume"]["write"]),
            )
            logger.extra(
                "  Communication time: %s (%d%%)",
                simplify_time(data_stats["network"]["wait_time"]),
                100.0 * data_stats["network"]["wait_time"] / cpu_total,
            )
        else:
            logger.extra("  Environment: %s", "local environment")

        if not quiet:
            if executor.analysis:
                print_results(executor)

            logger.extra("  Outputs produced:")
            if executor.analysis:
                logger.extra("    * %s", executor.data["result"]["path"])
            else:
                # only the output details are needed, not the output names
                for details in executor.data["outputs"].values():
                    logger.extra("    * %s", details["path"])
        else:
            logger.info("Done")

    return 0
André Anjos's avatar
André Anjos committed
321
322


323
def caches_impl(configuration, name, ls, delete, checksum):
    """List all cache files involved in this experiment

    Parameters:

      configuration (object): The configuration in use; this function reads
        its ``path`` and ``cache`` attributes

      name (str): The name of the experiment whose cache files are inspected

      ls (bool): If set to ``True``, lists the existing cache files matching
        each output

      delete (bool): If set to ``True``, removes the cache files matching
        each output, then removes the directories left empty

      checksum (bool): If set to ``True``, tries to load the data index of
        each output


    Returns:

      int: ``0`` on success, ``1`` if the experiment could not be loaded or
        if a data index could not be loaded while checksumming

    """

    dataformat_cache = {}
    database_cache = {}
    algorithm_cache = {}
    library_cache = {}

    experiment = Experiment(
        configuration.path,
        name,
        dataformat_cache,
        database_cache,
        algorithm_cache,
        library_cache,
    )

    if not experiment.valid:
        logger.error("Failed to load the experiment `%s':", name)
        for e in experiment.errors:
            logger.error("  * %s", e)
        return 1

    scheduled = experiment.setup()

    block_list = []
    for key, value in scheduled.items():
        block_configuration = value["configuration"]
        block = {
            "name": key,
            "algorithm": block_configuration["algorithm"],
            "is_analyser": False,
            "paths": [],
        }

        if "outputs" in block_configuration:  # normal block
            for output in block_configuration["outputs"].values():
                block["paths"].append(output["path"])
        else:  # analyzer
            block["is_analyser"] = True
            block["paths"].append(block_configuration["result"]["path"])

        block_list.append(block)

    status = 0

    for block in block_list:
        block_type = "analyzer" if block["is_analyser"] else "algorithm"
        logger.info("block: `%s'", block["name"])
        logger.info("  %s: `%s'", block_type, block["algorithm"])

        for path in block["paths"]:
            # prefix cache path
            path = os.path.join(configuration.cache, path)
            logger.info("  output: `%s'", path)

            if ls:
                for filename in glob.glob(path + ".*"):
                    logger.info("    %s" % filename)

            if delete:
                for filename in glob.glob(path + ".*"):
                    logger.info("removing `%s'...", filename)
                    os.unlink(filename)

                common.recursive_rmdir_if_empty(
                    os.path.dirname(path), configuration.cache
                )

            if checksum:
                # report success only if the index was really loaded, and
                # make the failure visible in the exit status
                if load_data_index(configuration.cache, path + ".data"):
                    logger.info("index for `%s' can be loaded and checksums", path)
                else:
                    logger.error("Failed to load data index for {}".format(path))
                    status = 1

    return status
André Anjos's avatar
André Anjos committed
395
396


397
def pull_impl(webapi, prefix, names, force, indentation, format_cache):
    """Copies experiments (and required toolchains/algorithms) from the server.

    Parameters:

      webapi (object): An instance of our WebAPI class, prepared to access the
        BEAT server of interest

      prefix (str): A string representing the root of the path in which the
        user objects are stored

      names (:py:class:`list`): A list of strings, each representing the unique
        relative path of the objects to retrieve or a list of usernames from
        which to retrieve objects. If the list is empty, then we pull all
        available objects of a given type. If no user is set, then pull all
        public objects of a given type.

      force (bool): If set to ``True``, then overwrites local changes with the
        remotely retrieved copies.

      indentation (int): The indentation level, useful if this function is
        called recursively while downloading different object types. This is
        normally set to ``0`` (zero).

      format_cache (dict): A dictionary handed over to the database and
        algorithm pull implementations (presumably caching the data formats
        already retrieved). If ``None``, a new empty dictionary is used.


    Returns:

      int: Indicating the exit status of the command, to be reported back to
        the calling process. This value should be zero if everything works OK,
        otherwise, different than zero (POSIX compliance).

    """

    from .algorithms import pull_impl as algorithms_pull
    from .databases import pull_impl as databases_pull

    if indentation == 0:
        indentation = 4

    status, names = common.pull(
        webapi,
        prefix,
        "experiment",
        names,
        ["declaration", "description"],
        force,
        indentation,
    )

    if status != 0:
        logger.error("could not find any matching experiments - widen your search")
        return status

    # see what toolchains, databases and algorithms one needs to pull
    databases = set()
    toolchains = set()
    algorithms = set()
    for name in names:
        try:
            obj = Experiment(prefix, name)
            if obj.toolchain:
                toolchains.add(obj.toolchain.name)
            databases.update(obj.databases.keys())
            algorithms.update(obj.algorithms.keys())

        except Exception as e:
            # best effort: an experiment that fails to load is reported and
            # its dependencies are left out
            logger.error("loading `%s': %s...", name, str(e))

    # downloads any formats to which we depend on; the cache given by the
    # caller is reused instead of being discarded
    if format_cache is None:
        format_cache = {}
    library_cache = {}

    tc_status, _ = common.pull(
        webapi,
        prefix,
        "toolchain",
        toolchains,
        ["declaration", "description"],
        force,
        indentation,
    )
    db_status = databases_pull(
        webapi, prefix, databases, force, indentation, format_cache
    )
    algo_status = algorithms_pull(
        webapi, prefix, algorithms, force, indentation, format_cache, library_cache
    )

    return status + tc_status + db_status + algo_status
André Anjos's avatar
André Anjos committed
485
486


487
488
489
490
491
492
493
494
495
496
497
498
def plot_impl(
    webapi,
    configuration,
    prefix,
    names,
    remote_results,
    show,
    force,
    indentation,
    format_cache,
    outputfolder=None,
):
    """Plots experiments from the server.

    Parameters:

      webapi (object): An instance of our WebAPI class, prepared to access the
        BEAT server of interest

      configuration (object): An instance of the configuration, to access the
        BEAT server and current configuration for information

      prefix (str): A string representing the root of the path in which the
        user objects are stored

      names (:py:class:`list`): A list of strings, each representing the unique relative
        path of the objects to retrieve or a list of usernames from which to
        retrieve objects. If the list is empty, then we pull all available
        objects of a given type. If no user is set, then pull all public
        objects of a given type.

      remote_results(bool): If set to ``True``, then fetch results data
        for the experiments from the server.

      show (bool): Forwarded as is to the plotter implementation

      force (bool): If set to ``True``, then overwrites local changes with the
        remotely retrieved copies.

      indentation (int): The indentation level, useful if this function is
        called recursively while downloading different object types. This is
        normally set to ``0`` (zero).

      format_cache (dict): A dictionary forwarded to the experiment pull and
        to the plotter implementation

      outputfolder (str): A string representing the path in which the
        experiments plot will be stored

    Returns:

      int: Indicating the exit status of the command, to be reported back to
        the calling process. This value should be zero if everything works OK,
        otherwise, different than zero (POSIX compliance).

    """

    def custom_output_folder():
        """Returns the user selected output folder, creating it if missing

        The existence test and the creation both use the path joined to the
        configuration path, which is also the path the plots are written to.
        """
        folder = os.path.join(configuration.path, outputfolder)
        if not os.path.isdir(folder):
            os.mkdir(folder)
        return folder

    status = 0
    RESULTS_SIMPLE_TYPE_NAMES = ("int32", "float32", "bool", "string")

    if indentation == 0:
        indentation = 4

    if remote_results:
        if outputfolder is None:
            output_folder = configuration.path
        else:
            output_folder = custom_output_folder()

    for name in names:
        if not remote_results:
            if outputfolder is None:
                output_folder = os.path.join(
                    configuration.path,
                    common.TYPE_PLURAL["experiment"],
                    name.rsplit("/", 1)[0],
                )
            else:
                output_folder = custom_output_folder()

        check_plottable = False
        if not os.path.exists(configuration.cache) or remote_results:
            experiment = simplejson.loads(
                simplejson.dumps(
                    common.fetch_object(webapi, "experiment", name, ["results"])
                )
            )
            results = experiment["results"]["analysis"]
            # ``dict.iteritems`` does not exist on Python 3
            for key, value in results.items():
                # remove non plottable results
                if value["type"] not in RESULTS_SIMPLE_TYPE_NAMES:
                    output_name = name.rsplit("/", 1)[1] + "_" + key + ".png"
                    output_name = os.path.join(output_folder, output_name)
                    pl_status = plotters_pull(
                        webapi,
                        configuration.path,
                        [value["type"]],
                        force,
                        indentation,
                        {},
                    )
                    plot_status = plotters_plot(
                        webapi,
                        configuration.path,
                        [value["type"]],
                        show,
                        False,
                        False,
                        value["value"],
                        output_name,
                        None,
                        indentation,
                        format_cache,
                    )
                    status += pl_status
                    status += plot_status
                    check_plottable = True
        else:
            # make sure experiment exists locally or pull it
            pull_impl(
                webapi, configuration.path, [name], force, indentation, format_cache
            )

            # get information from cache
            dataformat_cache = {}
            database_cache = {}
            algorithm_cache = {}
            library_cache = {}

            experiment = Experiment(
                configuration.path,
                name,
                dataformat_cache,
                database_cache,
                algorithm_cache,
                library_cache,
            )

            scheduled = experiment.setup()
            for key, value in scheduled.items():
                executor = LocalExecutor(
                    configuration.path,
                    value["configuration"],
                    configuration.cache,
                    dataformat_cache,
                    database_cache,
                    algorithm_cache,
                    library_cache,
                    configuration.database_paths,
                )

                if "result" in executor.data:
                    f = CachedDataSource()
                    success = f.setup(
                        os.path.join(
                            executor.cache, executor.data["result"]["path"] + ".data"
                        ),
                        executor.prefix,
                    )
                    if not success:
                        raise RuntimeError("Failed to setup cached data source")

                    data, _, _ = f[0]

                    # convert the results once instead of once per field
                    results = data.as_dict()
                    for the_data in results:
                        attr = getattr(data, the_data)
                        if attr.__class__.__name__.startswith("plot"):
                            datatype = attr.__class__.__name__.replace("_", "/")
                            # remove non plottable results
                            if datatype not in RESULTS_SIMPLE_TYPE_NAMES:
                                output_name = (
                                    name.rsplit("/", 1)[1] + "_" + the_data + ".png"
                                )
                                output_name = os.path.join(output_folder, output_name)
                                pl_status = plotters_pull(
                                    webapi,
                                    configuration.path,
                                    [datatype],
                                    force,
                                    indentation,
                                    {},
                                )
                                plot_status = plotters_plot(
                                    webapi,
                                    configuration.path,
                                    [datatype],
                                    show,
                                    False,
                                    False,
                                    results[the_data],
                                    output_name,
                                    None,
                                    indentation,
                                    format_cache,
                                )
                                status += pl_status
                                status += plot_status
                                check_plottable = True
        if not check_plottable:
            print("Experiments results are not plottable")

    return status


692
@click.group(cls=AliasedGroup)
@click.pass_context
def experiments(ctx):
    """experiments commands"""

    # Record the asset type and the fields used for diffs in the context
    # metadata shared with the sub-commands of this group
    ctx.meta.update(
        asset_type="experiment",
        diff_fields=["declaration", "description"],
    )
699
700


701
702
703
# Command names handed to ``commands.initialise_asset_commands`` together
# with the ``experiments`` group
CMD_LIST = [
    "list",
    "path",
    "edit",
    "check",
    "status",
    "fork",
    "rm",
    "diff",
]

commands.initialise_asset_commands(experiments, CMD_LIST)
704
705
706


@experiments.command()
@click.argument("name", nargs=1)
@click.option(
    "--force",
    is_flag=True,
    help="Performs operation regardless of conflicts",
)
@click.option(
    "--docker",
    is_flag=True,
    help="Uses the docker executor to execute the experiment using docker containers",
)
@click.option(
    "--local",
    is_flag=True,
    default=True,
    help="Uses the local executor to execute the experiment on the local machine (default)",
)
@click.option("--quiet", is_flag=True, help="Be less verbose")
@click.pass_context
@raise_on_error
def run(ctx, name, force, docker, local, quiet):
    """ Runs an experiment locally"""
    # the configuration is read from the click context metadata
    return run_experiment(
        configuration=ctx.meta.get("config"),
        name=name,
        force=force,
        use_docker=docker,
        use_local=local,
        quiet=quiet,
    )
731
732
733


@experiments.command()
@click.argument("name", nargs=1)
@click.option(
    "--list",
    is_flag=True,
    help="List cache files matching output if they exist",
)
@click.option(
    "--delete",
    is_flag=True,
    help="Delete cache files matching output if they exist (also, recursively deletes empty directories)",
)
@click.option("--checksum", is_flag=True, help="Checksums indexes for cache files")
@click.pass_context
@raise_on_error
def caches(ctx, name, list, delete, checksum):
    """Lists all cache files used by this experiment"""
    # ``list`` mirrors the ``--list`` option name and is passed on as ``ls``
    return caches_impl(
        configuration=ctx.meta.get("config"),
        name=name,
        ls=list,
        delete=delete,
        checksum=checksum,
    )
751
752
753


@experiments.command()
@click.argument("names", nargs=-1)
@click.option(
    "--force",
    is_flag=True,
    help="Performs operation regardless of conflicts",
)
@click.pass_context
@raise_on_error
def pull(ctx, names, force):
    """Downloads the specified experiments from the server.

       $ beat experiments pull xxx.
    """
    configuration = ctx.meta.get("config")
    with common.make_webapi(configuration) as api:
        # start at indentation level zero with an empty format cache
        exit_status = pull_impl(api, configuration.path, names, force, 0, {})
        return exit_status
768
769
770


@experiments.command()
@click.argument("names", nargs=-1)
@click.option(
    "--force", help="Performs operation regardless of conflicts", is_flag=True
)
@click.option(
    "--dry-run",
    help="Doesn't really perform the task, just " "comments what would do",
    is_flag=True,
)
@click.pass_context
@raise_on_error
def push(ctx, names, force, dry_run):
    """Uploads experiments to the server.

    Example:
      $ beat experiments push --dry-run yyy

    """
    config = ctx.meta.get("config")
    with common.make_webapi(config) as webapi:
        # The generic push helper is told which asset type is handled
        # ("experiment") and which fields of it are sent to the server.
        return common.push(
            webapi,
            config.path,
            "experiment",
            names,
            ["name", "declaration", "toolchain", "description"],
            {},
            force,
            dry_run,
            0,
        )


@experiments.command()
@click.argument("names", nargs=-1)
@click.option(
    "--path",
    help="Use path to write files to disk (instead of the " "current directory)",
    type=click.Path(),
)
@click.pass_context
@raise_on_error
def draw(ctx, names, path):
    """Creates a visual representation of the experiment."""
    config = ctx.meta.get("config")
    # Drawing is fully delegated to the generic dot diagram helper; its
    # result is returned unchanged to the ``raise_on_error`` wrapper.
    return common.dot_diagram(config.path, "experiment", names, path, [])
816
817
818


@experiments.command()
@click.argument("names", nargs=-1)
@click.option(
    "--force", help="Performs operation regardless of conflicts", is_flag=True
)
@click.option(
    "--remote", help="Only acts on the remote copy of the experiment", is_flag=True
)
@click.option("--show", help="Show...", is_flag=True)
@click.option("--output-folder", help="<folder>", type=click.Path(exists=True))
@click.pass_context
@raise_on_error
def plot(ctx, names, force, remote, show, output_folder):
    """Plots output images of the experiment."""
    config = ctx.meta.get("config")
    with common.make_webapi(config) as webapi:
        # Beware of the argument order expected by plot_impl: ``remote`` and
        # ``show`` come before ``force``, unlike in this command's signature.
        return plot_impl(
            webapi,
            config,
            "experiment",
            names,
            remote,
            show,
            force,
            0,
            {},
            output_folder,
        )
846
847
848
849


@experiments.command()
@click.argument("name", nargs=1)
@click.option("--watch", help="Start monitoring the execution", is_flag=True)
@click.pass_context
def start(ctx, name, watch):
    """Start an experiment on the platform"""

    config = ctx.meta.get("config")
    with common.make_webapi(config) as webapi:
        status, _ = webapi.post("/api/v1/experiments/{}/start/".format(name))

    if status != six.moves.http_client.OK:
        logger.error(
            "failed to start {} on `{}', reason: {}".format(
                name, webapi.platform, six.moves.http_client.responses[status]
            )
        )
    elif watch:
        # The experiment was accepted by the platform: hand over to the
        # ``monitor`` command so the user can follow its execution.
        ctx.invoke(monitor, name=name)
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890


@experiments.command()
@click.argument("name", nargs=1)
@click.pass_context
def cancel(ctx, name):
    """Cancel an experiment on the platform"""

    endpoint = "/api/v1/experiments/{}/cancel/".format(name)

    with common.make_webapi(ctx.meta.get("config")) as webapi:
        status, _ = webapi.post(endpoint)

    # Nothing to report when the platform accepted the request
    if status == six.moves.http_client.OK:
        return

    message = "failed to cancel {} on `{}', reason: {}".format(
        name, webapi.platform, six.moves.http_client.responses[status]
    )
    logger.error(message)


@experiments.command()
@click.argument("name", nargs=1)
@click.pass_context
def reset(ctx, name):
    """Reset an experiment on the platform"""

    config = ctx.meta.get("config")
    with common.make_webapi(config) as webapi:
        status, _ = webapi.post("/api/v1/experiments/{}/reset/".format(name))

    # A failure is only logged; nothing is raised or returned to the caller
    if status != six.moves.http_client.OK:
        logger.error(
            "failed to reset {} on `{}', reason: {}".format(
                name, webapi.platform, six.moves.http_client.responses[status]
            )
        )


@experiments.command()
@click.argument("name", nargs=1)
@click.pass_context
def runstatus(ctx, name):
    """Shows the status of an experiment on the platform"""

    config = ctx.meta.get("config")

    with common.make_webapi(config) as webapi:
        # Restrict the answer to the fields describing the execution state
        fields = ",".join(
            [
                "status",
                "blocks_status",
                "done",
                "errors",
                "execution_info",
                "execution_order",
                "results",
                "started",
                "display_start_date",
                "display_end_date",
            ]
        )
        status, answer = webapi.get(
            "/api/v1/experiments/{}/?fields={}".format(name, fields)
        )

        if status != six.moves.http_client.OK:
            logger.error(
                "failed to get current state of {} on `{}', reason: {}".format(
                    name, webapi.platform, six.moves.http_client.responses[status]
                )
            )
        else:
            # Decode then re-encode the answer to pretty-print it
            data = simplejson.loads(answer)
            print(simplejson.dumps(data, indent=4))
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
1050
1051
1052
1053
1054
1055
1056
1057
1058
1059
1060
1061
1062
1063
1064
1065
1066
1067
1068
1069
1070
1071
1072
1073
1074
1075
1076
1077
1078
1079
1080
1081
1082
1083
1084
1085
1086
1087
1088
1089
1090
1091
1092
1093
1094
1095
1096
1097
1098
1099
1100
1101
1102
1103
1104
1105
1106
1107
1108
1109
1110
1111
1112
1113
1114
1115
1116
1117
1118
1119
1120
1121
1122
1123
1124
1125
1126
1127
1128
1129
1130
1131
1132
1133
1134
1135
1136
1137
1138
1139
1140
1141
1142
1143
1144
1145
1146
1147
1148
1149
1150
1151
1152
1153
1154
1155
1156
1157
1158
1159


# The monitoring implementation was inspired by
# https://medium.com/greedygame-engineering/an-elegant-way-to-run-periodic-tasks-in-python-61b7c477b679


class ProgramKilled(Exception):
    """Signals that the program was asked to stop (e.g. CTRL + C was used)"""


def signal_handler(signum, frame):
    """Turns a received signal into a :py:class:`ProgramKilled` exception

    Both arguments are those given to any Python signal handler and are
    ignored here.
    """

    raise ProgramKilled()


class ExperimentMonitor(threading.Thread):
    """Background thread polling the platform for the state of an experiment

    Each decoded answer is pushed on ``queue``. When a request fails, a
    ``{"error": <http status>}`` dictionary is pushed instead and the polling
    ends.
    """

    def __init__(self, interval, config, name):
        super(ExperimentMonitor, self).__init__()

        # ``name`` is also the attribute threading.Thread uses for the thread
        # name, so the thread ends up named after the experiment.
        self.name = name
        self.config = config
        self.interval = interval  # object providing ``total_seconds()``
        self.daemon = False
        self.stopped = False
        self.stop_event = threading.Event()
        self.queue = queue.Queue()

    def stop(self):
        """Asks the polling loop to end and waits for the thread to finish"""

        self.stopped = True
        self.stop_event.set()
        self.join()

    def run(self):
        """Polls the experiment status until stopped or a request fails.

        The first request is sent immediately, the following ones are spaced
        by ``interval``.
        """

        self.stopped = False
        fields = ",".join(("status", "blocks_status", "done"))

        with common.make_webapi(self.config) as webapi:
            delay = 0  # no waiting before the very first request
            while not self.stop_event.wait(delay):
                delay = self.interval.total_seconds()

                status, answer = webapi.get(
                    "/api/v1/experiments/{}/?fields={}".format(self.name, fields)
                )

                if status == six.moves.http_client.OK:
                    self.queue.put(simplejson.loads(answer))
                    continue

                logger.error(
                    "failed to get current state of {} on `{}', reason: {}".format(
                        self.name,
                        webapi.platform,
                        six.moves.http_client.responses[status],
                    )
                )
                # Ends the loop at the next wait() and informs the consumer
                self.stop_event.set()
                self.queue.put({"error": status})


def replace_line(pad, line, text):
    """Overwrites one line of an ncurses pad

    The cursor is moved to the beginning of ``line``, the line is cleared
    and ``text`` is written from its first column.
    """

    first_column = 0
    pad.move(line, first_column)
    pad.clrtoeol()
    pad.addstr(line, first_column, text)


def process_input(monitor, pad, delta, pad_height, height, width):
    """Processes the keyboard input of an ncurses pad

    Parameters:

      monitor: object whose ``stop()`` method is called when ``q`` is pressed

      pad: the ncurses pad to read from and refresh, or ``None`` when it has
        not been created yet (nothing is done in that case)

      delta: index of the first pad line currently shown

      pad_height: number of lines of the pad

      height: number of lines of the screen

      width: number of columns of the screen

    Returns:

      The updated ``delta``, always within ``[0, max(pad_height - height, 0)]``
      after an arrow key has been handled.
    """

    if pad is None:
        return delta

    try:
        ch = pad.getch()
    except curses.error:
        pass
    else:
        if ch == curses.KEY_UP:
            delta = max(delta - 1, 0)
        elif ch == curses.KEY_DOWN:
            # pad_height - height is negative when the screen is taller than
            # the pad (e.g. after a resize): never scroll above the first line
            delta = max(min(delta + 1, pad_height - height), 0)
        elif ch == ord("q"):
            monitor.stop()

    pad.refresh(delta, 0, 0, 0, height - 1, width - 1)

    return delta


def _draw_experiment_status(pad, name, data, width):
    """Writes the experiment lines and one line per block on the pad

    Returns the index of the last pad line written.
    """

    line = 0
    replace_line(
        pad, line, textwrap.shorten("Name: {name}".format(name=name), width=width)
    )
    line += 1
    replace_line(
        pad, line, textwrap.shorten("Status: {status}".format(**data), width=width)
    )

    # Block lines are split in two columns: name on the left, status on the
    # right, each of them limited to half of the screen width.
    text_width = width // 2
    for block_name, block_status in data["blocks_status"].items():
        line += 1
        replace_line(
            pad,
            line,
            textwrap.shorten(
                "Name: {block_name} ".format(block_name=block_name),
                width=text_width,
            ),
        )
        pad.addstr(
            line,
            text_width,
            textwrap.shorten(
                "Status: {block_status}".format(block_status=block_status),
                width=text_width,
            ),
        )

    return line


@experiments.command()
@click.argument("name", nargs=1)
@click.pass_context
def monitor(ctx, name):
    """Monitor a running experiment"""

    config = ctx.meta.get("config")

    signal.signal(signal.SIGTERM, signal_handler)
    signal.signal(signal.SIGINT, signal_handler)

    # Not called ``monitor`` so that this command function is not shadowed
    watcher = ExperimentMonitor(interval=timedelta(seconds=5), config=config, name=name)

    stdscr = curses.initscr()
    try:
        curses.noecho()
        curses.cbreak()

        watcher.start()

        killed = False
        pad = None  # created once the first status has been received
        line = 0
        delta = 0
        height, width = stdscr.getmaxyx()
        pad_height = height
        # Lines shown on top of the blocks: experiment name, experiment
        # status and the final "Experiment done" message
        STATIC_LINE_COUNT = 3

        while True:
            try:
                try:
                    data = watcher.queue.get(True, 0.2)
                except queue.Empty:
                    delta = process_input(
                        watcher, pad, delta, pad_height, height, width
                    )
                else:
                    if "error" in data:
                        # The watcher already logged the failure and ended
                        killed = True
                        break

                    height, width = stdscr.getmaxyx()

                    if pad is None:
                        nb_blocks = len(data["blocks_status"]) + STATIC_LINE_COUNT
                        pad_height = max(nb_blocks, height)
                        pad = curses.newpad(pad_height, width)
                        pad.timeout(200)
                        pad.keypad(True)

                    line = _draw_experiment_status(pad, name, data, width)
                    pad.refresh(delta, 0, 0, 0, height - 1, width - 1)

                    if data["done"]:
                        watcher.stop()
                finally:
                    delta = process_input(
                        watcher, pad, delta, pad_height, height, width
                    )

                # Thread.isAlive() was removed in Python 3.9
                if not watcher.is_alive():
                    break

            except ProgramKilled:
                watcher.stop()
                killed = True
                break

        # pad is still None when the watcher ended before delivering any
        # status: there is nothing to write the final message on.
        if not killed and pad is not None:
            line += 1
            pad.timeout(-1)  # block until a key is pressed
            pad.addstr(
                line,
                0,
                textwrap.shorten(
                    "Experiment done, press any key to leave", width=width
                ),
                curses.A_BOLD,
            )
            pad.move(line, 0)
            pad.refresh(pad_height - height, 0, 0, 0, height - 1, width - 1)
            try:
                pad.getkey()
            except ProgramKilled:
                # CTRL + C is as good as any key to leave
                pass
    finally:
        # Whatever happened, do not leave a non-daemon thread polling the
        # platform nor the terminal in cbreak/noecho mode.
        if watcher.is_alive():
            watcher.stop()
        curses.echo()
        curses.nocbreak()
        curses.endwin()