core.py 16.6 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
#!/usr/bin/env python
# vim: set fileencoding=utf-8 :

###################################################################################
#                                                                                 #
# Copyright (c) 2019 Idiap Research Institute, http://www.idiap.ch/               #
# Contact: beat.support@idiap.ch                                                  #
#                                                                                 #
# Redistribution and use in source and binary forms, with or without              #
# modification, are permitted provided that the following conditions are met:     #
#                                                                                 #
# 1. Redistributions of source code must retain the above copyright notice, this  #
# list of conditions and the following disclaimer.                                #
#                                                                                 #
# 2. Redistributions in binary form must reproduce the above copyright notice,    #
# this list of conditions and the following disclaimer in the documentation       #
# and/or other materials provided with the distribution.                          #
#                                                                                 #
# 3. Neither the name of the copyright holder nor the names of its contributors   #
# may be used to endorse or promote products derived from this software without   #
# specific prior written permission.                                              #
#                                                                                 #
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND #
# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED   #
# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE          #
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE    #
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL      #
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR      #
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER      #
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,   #
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE   #
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.            #
#                                                                                 #
###################################################################################


"""
Base classes and helpers for command-line asset testing
"""

41
import os
42
43
import nose.tools
import click
44
import shutil
45

46
from collections import namedtuple
47
from functools import wraps
48

49
50
51
from click.testing import CliRunner

from beat.core.test.utils import cleanup
52
from beat.core.test.utils import skipif
53
from beat.core.test.utils import slow
54
55
56
57
from beat.cmdline.scripts import main_cli

from .. import common

58
59
60
61
62
63
64
65
66
67
68
69
70
71
from . import platform, disconnected, prefix, tmp_prefix, user, token

if disconnected:

    class LiveServerTestCase:
        """Stand-in used when no test platform is reachable.

        It only offers ``live_server_url`` (always ``None``) and no-op
        ``setUpClass``/``setUp`` hooks, so that classes deriving from it can
        still be defined without Django.
        """

        live_server_url = None

        @classmethod
        def setUpClass(cls):
            pass

        def setUp(self):
            pass

else:
    from django.contrib.staticfiles.testing import LiveServerTestCase

76

77
78
79
# ----------------------------------------------------------
# decorators

80
81
82
# Decorator that skips a test when the remote test platform is unavailable
# (i.e. ``disconnected`` is true); keeps test definitions short and readable.
skip_disconnected = skipif(
    disconnected,
    "missing test platform (%s)" % platform,
)

83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114

def skip_no_version(method):
    """Skip test if asset does not support versioning

    Decorator for test methods of classes defining ``asset_type``. Before
    running the wrapped test it asks ``common.Selector`` (opened on
    ``tmp_prefix``) whether ``self.asset_type`` supports versions, and raises
    ``nose.SkipTest`` when it does not.

    :param method: the test method to wrap
    :returns: the wrapped method (metadata preserved via ``functools.wraps``)
    """

    @wraps(method)
    def _impl(self, *args, **kwargs):
        with common.Selector(tmp_prefix) as selector:
            if not selector.has_versions(self.asset_type):
                raise nose.SkipTest(
                    "{} does not support versions".format(self.asset_type)
                )
        return method(self, *args, **kwargs)

    return _impl


def skip_no_fork(method):
    """Decorator: skip the wrapped test when its ``asset_type`` cannot be forked.

    Forkability is queried through ``common.Selector`` opened on
    ``tmp_prefix``; a ``nose.SkipTest`` is raised for non-forkable types.
    """

    @wraps(method)
    def _impl(self, *args, **kwargs):
        kind = self.asset_type
        with common.Selector(tmp_prefix) as selector:
            forkable = selector.can_fork(kind)
            if not forkable:
                reason = "{} does not support forks".format(kind)
                raise nose.SkipTest(reason)
        return method(self, *args, **kwargs)

    return _impl


# ----------------------------------------------------------
# helper

115
116
117
# Minimal stand-in for the command-line configuration object, used when
# calling helpers such as ``common.make_webapi`` directly (outside click).
MockConfig = namedtuple("MockConfig", "platform user token")

118

119
120
121
# ----------------------------------------------------------


122
class BaseTest:
    """Common base for command-line tests.

    Provides :meth:`call`, which runs the ``beat`` command-line entry point
    through click's test runner with the test configuration (prefix, token,
    user, cache and platform) placed before the command arguments.
    """

    # Asset type under test. When ``None`` (and no ``asset_type`` keyword is
    # given to ``call``), ``get_cmd_group`` presumably yields no group so the
    # arguments are passed to the top-level command.
    asset_type = None

    def setUp(self):
        pass

    def tearDown(self):
        cleanup()

    @classmethod
    def get_cmd_group(cls, asset_type):
        """Return the click command group name for ``asset_type``.

        The name is looked up in ``common.TYPE_PLURAL`` and only the last
        ``/``-separated component is kept. Types missing from the mapping are
        returned unchanged.
        """
        try:
            cmd_group = common.TYPE_PLURAL[asset_type]
        except KeyError:
            return asset_type

        if "/" in cmd_group:
            cmd_group = cmd_group.split("/")[-1]
        return cmd_group

    @classmethod
    def call(cls, *args, **kwargs):
        """A central mechanism to call the main routine with the right parameters

        Positional arguments are appended after the command group.

        Optional keyword overrides:

        - ``prefix``: defaults to the module-level ``prefix``.
        - ``platform``: defaults to the module-level ``platform``.
        - ``cache``: defaults to ``"cache"``.
        - ``asset_type``: defaults to ``cls.asset_type``.
        - ``user``: defaults to the module-level ``user``.
        - ``token``: defaults to the module-level ``token``.

        :returns: ``(exit_code, output)`` of the invocation
        :raises: any exception raised by the command itself, because it is
            invoked with ``catch_exceptions=False``
        """

        use_prefix = kwargs.get("prefix", prefix)
        use_platform = kwargs.get("platform", platform)
        use_cache = kwargs.get("cache", "cache")
        use_token = kwargs.get("token", token)
        asset_type = kwargs.get("asset_type", cls.asset_type)
        remote_user = kwargs.get("user", user)

        cmd_group = cls.get_cmd_group(asset_type)

        parameters = [
            "--test-mode",
            "--prefix",
            use_prefix,
            "--token",
            use_token,
            "--user",
            remote_user,
            "--cache",
            use_cache,
            "--platform",
            use_platform,
        ]

        if cmd_group:
            parameters.append(cmd_group)

        parameters += list(args)

        # Run inside a scratch directory so the command cannot pollute the
        # current working directory.
        runner = CliRunner()
        with runner.isolated_filesystem():
            result = runner.invoke(main_cli.main, parameters, catch_exceptions=False)

        # Show the command output on failure to ease debugging.
        if result.exit_code != 0:
            click.echo(result.output)
        return result.exit_code, result.output

181

182
183
184
# ----------------------------------------------------------


185
186
187
class AssetBaseTest(BaseTest):
    """Asset-aware test base: refuses to invoke click without an ``asset_type``.

    Concrete subclasses configure:

    * ``object_map`` -- names of the assets used by the tests, keyed by role;
    * ``storage_cls`` -- storage class built as ``storage_cls(prefix, name)``
      and exposing ``exists()``.
    """

    object_map = {}
    storage_cls = None

    @classmethod
    def create(cls, obj=None):
        """Create ``obj`` (default: ``object_map["create"]``) in ``tmp_prefix``.

        Asserts the command succeeded and the asset now exists, then returns
        its storage object.
        """
        name = obj if obj else cls.object_map["create"]
        code, out = cls.call("create", name, prefix=tmp_prefix)
        nose.tools.eq_(code, 0, out)
        created = cls.storage_cls(tmp_prefix, name)
        nose.tools.assert_true(created.exists())
        return created

    @classmethod
    def delete(cls, obj):
        """Remove ``obj`` from ``tmp_prefix`` and assert it is gone."""
        code, out = cls.call("rm", obj, prefix=tmp_prefix)
        nose.tools.eq_(code, 0, out)
        removed = cls.storage_cls(tmp_prefix, obj)
        nose.tools.assert_false(removed.exists())

    @classmethod
    def call(cls, *args, **kwargs):
        """Check that ``asset_type`` is set, then defer to the parent ``call``."""
        nose.tools.assert_is_not_none(cls.asset_type, "Missing value for asset_type")
        return super().call(*args, **kwargs)


213
class AssetLocalTest(AssetBaseTest):
    """Base class for local tests

    Subclasses must set ``asset_type``, ``storage_cls`` and ``object_map``.
    The ``object_map`` keys used here are ``"valid"``, ``"invalid"``,
    ``"create"``, ``"new"`` and ``"fork"``.
    """

    def __init__(self):
        super().__init__()
        # Fail early when the concrete test class is not fully configured.
        nose.tools.assert_true(self.object_map)
        nose.tools.assert_is_not_none(self.storage_cls)

    def test_local_list(self):
        exit_code, outputs = self.call("list")
        nose.tools.eq_(exit_code, 0, outputs)

    def test_check_valid(self):
        exit_code, outputs = self.call("check", self.object_map["valid"])
        nose.tools.eq_(exit_code, 0, outputs)

    def test_check_invalid(self):
        exit_code, outputs = self.call("check", self.object_map["invalid"])
        nose.tools.eq_(exit_code, 1, outputs)

    def test_create(self, obj=None):
        """Create ``obj`` locally (defaults to ``object_map["create"]``)."""
        # Honour the ``obj`` argument instead of silently ignoring it.
        self.create(obj or self.object_map["create"])

    @skip_no_version
    def test_new_version(self):
        obj = self.object_map["create"]
        obj2 = self.object_map["new"]
        self.create(obj)

        exit_code, outputs = self.call("version", obj, prefix=tmp_prefix)
        nose.tools.eq_(exit_code, 0, outputs)
        s = self.storage_cls(tmp_prefix, obj2)
        nose.tools.assert_true(s.exists())

        # check version status
        with common.Selector(tmp_prefix) as selector:
            nose.tools.eq_(selector.version_of(self.asset_type, obj2), obj)

    def test_fork(self):
        """Fork succeeds for forkable types and fails for the others."""
        obj = self.object_map["create"]
        obj2 = self.object_map["fork"]
        self.create(obj)
        with common.Selector(tmp_prefix) as selector:
            if selector.can_fork(self.asset_type):
                exit_code, outputs = self.call("fork", obj, obj2, prefix=tmp_prefix)
                nose.tools.eq_(exit_code, 0, outputs)
                # Reload the selector so it sees the fork recorded by the command.
                selector.load()
                s = self.storage_cls(tmp_prefix, obj2)
                nose.tools.assert_true(s.exists())
                nose.tools.eq_(selector.forked_from(self.asset_type, obj2), obj)
            else:
                exit_code, outputs = self.call("fork", obj, obj2, prefix=tmp_prefix)
                nose.tools.assert_not_equal(exit_code, 0, outputs)

    def test_delete_local(self):
        obj = self.object_map["create"]
        self.create(obj)
        self.delete(obj)

    def test_delete_local_unexisting(self):
        """Removing an asset that does not exist locally must fail."""
        obj = self.object_map["create"]
        storage = self.storage_cls(tmp_prefix, obj)
        nose.tools.assert_false(storage.exists())

        exit_code, outputs = self.call("rm", obj, prefix=tmp_prefix)
        nose.tools.eq_(exit_code, 1, outputs)
        nose.tools.assert_false(storage.exists())
279
280


281
282
283
# ----------------------------------------------------------


284
class AssetRemoteTest(AssetBaseTest):
    """Base class for remote tests

    Subclasses must set ``asset_type``, ``storage_cls`` and ``object_map``,
    and implement :meth:`_modify_asset`. ``live_server_url`` is not defined
    here; it must be provided by another base class (e.g. a live server test
    case) combined with this one.
    """

    def __init__(self):
        super().__init__()
        # Fail early when the concrete test class is not fully configured.
        nose.tools.assert_true(self.object_map)
        nose.tools.assert_is_not_none(self.storage_cls)

    def _modify_asset(self, asset_name):
        """Modify an asset

        Must be implemented by subclasses so that :meth:`test_diff` has a
        local change to report.
        """

        raise NotImplementedError

    def _prepare_fork_dependencies(self, asset_name):
        """Prepare prefix content with fork dependencies

        Copies ``asset_name`` from the reference ``prefix`` into
        ``tmp_prefix`` so that it can be forked there.
        """

        src_storage = self.storage_cls(prefix, asset_name)
        dst_storage = self.storage_cls(tmp_prefix, asset_name)
        dst_storage.save(*src_storage.load())

    @slow
    @skip_disconnected
    def test_remote_list(self):
        exit_code, output = self.call("list", "--remote")
        nose.tools.eq_(exit_code, 0, output)

    @slow
    @skip_disconnected
    def test_pull_one(self, obj=None):
        """Pull ``obj`` (default: ``object_map["pull"]``) and return its storage."""
        obj = obj or self.object_map["pull"]
        exit_code, output = self.call("pull", obj, prefix=tmp_prefix)
        nose.tools.eq_(exit_code, 0, output)
        storage = self.storage_cls(tmp_prefix, obj)
        nose.tools.assert_true(storage.exists())
        return storage

    @slow
    @skip_disconnected
    def test_pull_all(self):
        exit_code, output = self.call("pull", prefix=tmp_prefix)
        nose.tools.eq_(exit_code, 0, output)

    @slow
    @skip_disconnected
    def test_diff(self):
        obj = self.object_map["diff"]
        exit_code, output = self.call("pull", obj, prefix=tmp_prefix)
        nose.tools.eq_(exit_code, 0, output)

        # Alter the local copy (subclass-specific) so there is a difference
        # with the remote version.
        self._modify_asset(obj)

        exit_code, output = self.call("diff", obj, prefix=tmp_prefix)
        nose.tools.eq_(exit_code, 0, output)

    @slow
    @skip_disconnected
    def test_status(self):
        # Populate tmp_prefix with a modified and an unmodified pulled asset.
        self.test_diff()
        self.test_pull_one()
        exit_code, output = self.call("status", prefix=tmp_prefix)
        nose.tools.eq_(exit_code, 0, output)

    @slow
    @skip_disconnected
    @skip_no_version
    def test_push_different_versions(self):
        asset_name = self.object_map["create"]
        self.create(asset_name)

        # Split "<name>/<version>" into the "<name>/" prefix and the version.
        number_of_versions = 5
        version_pos = asset_name.rindex("/") + 1
        original_name = asset_name[:version_pos]
        original_version = int(asset_name[version_pos:])

        # Create successive versions, each from the latest one.
        for i in range(number_of_versions):
            asset_name = original_name + str(original_version + i)
            exit_code, outputs = self.call("version", asset_name, prefix=tmp_prefix)
            nose.tools.eq_(exit_code, 0, outputs)

        # Push only the latest version...
        asset_name = original_name + str(original_version + number_of_versions)

        exit_code, output = self.call("push", asset_name, prefix=tmp_prefix)

        nose.tools.eq_(exit_code, 0, output)

        # ...and expect the original plus all new versions on the remote.
        config = MockConfig(self.live_server_url, user, token)
        with common.make_webapi(config) as webapi:
            asset_list = common.retrieve_remote_list(webapi, self.asset_type, ["name"])
            aoi_list = [
                asset for asset in asset_list if asset["name"].startswith(original_name)
            ]
            nose.tools.assert_equal(len(aoi_list), number_of_versions + 1)

    @slow
    @skip_disconnected
    @skip_no_fork
    def test_push_different_forks(self):
        asset_name = self.object_map["fork"]

        if "fork_from" in self.object_map:
            original = self.object_map["fork_from"]
        else:
            original = self.object_map["pull"]

        self._prepare_fork_dependencies(original)

        number_of_forks = 1

        # Names with exactly four "/" separators are treated as unversioned;
        # all others end with "/<version>". Decide once so that the names used
        # for forking and for pushing are always built the same way.
        versioned = asset_name.count("/") != 4
        if versioned:
            original_name = asset_name[: asset_name.rindex("/")]
        else:
            original_name = asset_name

        def fork_name_for(index):
            """Name of the ``index``-th fork of ``original``."""
            name = "{}_{}".format(original_name, index)
            if versioned:
                name += "/1"
            return name

        for i in range(number_of_forks):
            exit_code, outputs = self.call(
                "fork", original, fork_name_for(i), prefix=tmp_prefix
            )
            nose.tools.eq_(exit_code, 0, outputs)

        # Push the last fork only.
        asset_name = fork_name_for(number_of_forks - 1)

        exit_code, output = self.call("push", asset_name, prefix=tmp_prefix)

        nose.tools.eq_(exit_code, 0, output)

        config = MockConfig(self.live_server_url, user, token)
        with common.make_webapi(config) as webapi:
            asset_list = common.retrieve_remote_list(webapi, self.asset_type, ["name"])
            aoi_list = [
                asset for asset in asset_list if asset["name"].startswith(original_name)
            ]
            nose.tools.assert_equal(len(aoi_list), number_of_forks)

    @slow
    @skip_disconnected
    def test_push_and_delete(self):
        asset_name = self.object_map["push"]

        # now push the new object and then delete it remotely
        exit_code, output = self.call("push", asset_name)
        nose.tools.eq_(exit_code, 0, output)
        exit_code, output = self.call("rm", "--remote", asset_name)
        nose.tools.eq_(exit_code, 0, output)

    @slow
    @skip_disconnected
    def test_fail_push_invalid(self):
        """Pushing an invalid asset (as user "errors") raises RuntimeError."""
        asset_name = self.object_map["push_invalid"]

        with nose.tools.assert_raises(RuntimeError) as assertion:
            self.call("push", asset_name, user="errors")
        exc = assertion.exception
        text = exc.args[0]
        nose.tools.assert_true(text.startswith("Invalid "))

    @slow
    @skip_disconnected
    def test_fail_not_owner_push(self):
        asset_name = self.object_map["not_owner_push"]

        exit_code, output = self.call("push", asset_name)
        nose.tools.eq_(exit_code, 1, output)

455

456
457
class OnlineTestMixin:
    """Mixin routing command calls to the Django live test server."""

    def setUp(self):
        """Snapshot or restore the Django test database.

        On the first run the test database is copied into ``prefix``; on later
        runs that copy is restored, so the database setup (``make install``)
        does not have to be repeated for every test. Does nothing when
        disconnected.
        """

        if disconnected:
            return

        from django.conf import settings

        live_db = settings.DATABASES["default"]["TEST"]["NAME"]
        cached_db = os.path.join(prefix, "django_test_database.sqlite3")

        if os.path.exists(cached_db):
            shutil.copyfile(cached_db, live_db)
        else:
            shutil.copyfile(live_db, cached_db)

    @classmethod
    def call(cls, *args, **kwargs):
        """Defer to the next ``call`` in the MRO, forcing the live server URL
        as platform (any caller-supplied ``platform`` is overridden)."""

        kwargs.update(platform=cls.live_server_url)
        return super().call(*args, **kwargs)
482
483
484
485
486
487
488
489
490
491


class OnlineTestCase(LiveServerTestCase, OnlineTestMixin, BaseTest):
    """Live-server backed test case for remote (non asset specific) commands."""

    def setUp(self):
        # Call each direct base's setUp explicitly, in declaration order.
        parents = OnlineTestCase.__bases__
        for parent in parents:
            parent.setUp(self)


492
class OnlineAssetTestCase(LiveServerTestCase, OnlineTestMixin, AssetRemoteTest):
    """Test case using django live server for asset related remote tests"""

    def setUp(self):
        """Run the ``setUp`` of each direct base class, in declaration order.

        The bases are called explicitly because the visible ``setUp``
        implementations do not chain to ``super()``. This class's own bases
        are used (not ``OnlineTestCase``'s), so the asset test base's setUp
        is the one that runs.
        """
        for base in OnlineAssetTestCase.__bases__:
            base.setUp(self)