#!/usr/bin/env python
# vim: set fileencoding=utf-8 :

###############################################################################
#                                                                             #
# Copyright (c) 2017 Idiap Research Institute, http://www.idiap.ch/           #
# Contact: beat.support@idiap.ch                                              #
#                                                                             #
# This file is part of the beat.web module of the BEAT platform.              #
#                                                                             #
# Commercial License Usage                                                    #
# Licensees holding valid commercial BEAT licenses may use this file in       #
# accordance with the terms contained in a written agreement between you      #
# and Idiap. For further information contact tto@idiap.ch                     #
#                                                                             #
# Alternatively, this file may be used under the terms of the GNU Affero      #
# Public License version 3 as published by the Free Software and appearing    #
# in the file LICENSE.AGPL included in the packaging of this file.            #
# The BEAT platform is distributed in the hope that it will be useful, but    #
# WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY  #
# or FITNESS FOR A PARTICULAR PURPOSE.                                        #
#                                                                             #
# You should have received a copy of the GNU Affero Public License along      #
# with the BEAT platform. If not, see http://www.gnu.org/licenses/.           #
#                                                                             #
###############################################################################

import json
import os

from django.conf import settings
from django.test import TestCase

import beat.core.hash
from beat.core.data import CachedDataSink
from beat.core.database import Database
from beat.core.dataformat import DataFormat

from ...algorithms.models import Algorithm
from ...common.testutils import tearDownModule  # noqa test runner will call it
from ...experiments.models import Block
from ...utils.management.commands import install
from ..management.commands import qsetup
from ..models import Environment
from ..models import Queue
from ..models import Worker
from ..utils import setup_backend

# ----------------------------------------------------------


# Backend configuration: a single queue ("queue") whose slots are spread over
# two single-core workers (node1 and node2, one slot each), restricted to the
# "Python for tests (1.3.0)" environment and the "Default" group.
# NOTE(review): not referenced in this module; presumably imported by other
# backend test modules and passed to setup_backend — confirm against callers.
ONE_QUEUE_TWO_WORKERS = {
    "queues": {
        "queue": {
            "memory-limit": 4 * 1024,
            "time-limit": 1440,  # 1 day
            "cores-per-slot": 1,
            "max-slots-per-user": 2,
            "environments": ["Python for tests (1.3.0)"],
            "slots": {
                "node1": {"quantity": 1, "priority": 0},
                "node2": {"quantity": 1, "priority": 0},
            },
            "groups": ["Default"],
        }
    },
    "workers": {
        "node1": {"cores": 1, "memory": 16 * 1024},
        "node2": {"cores": 1, "memory": 16 * 1024},
    },
    "environments": {
        "Python for tests (1.3.0)": {
            "name": "Python for tests",
            "version": "1.3.0",
            "short_description": "Test",
            "description": "Test environment",
            "languages": ["python"],
        }
    },
}


# ----------------------------------------------------------


class BackendUtilitiesMixin(object):
    """Helpers shared by the backend test cases.

    Provides fixture creation (sites, users, backend configuration,
    contributions), cache cleanup, and helpers to force experiment/block/cache
    states, write cache files and index datasets.
    """

    @classmethod
    def setup_test_data(cls):
        """Populate the database with the objects the backend tests rely on.

        Creates the sites, the users returned by ``install.create_users`` and
        the ``Default`` group, configures the backend with
        ``qsetup.DEFAULT_CONFIGURATION``, marks every worker active and
        installs the ``system`` and ``test`` contributions found under
        ``<BASE_DIR>/src/beat.examples``.

        A temporary ``db_root.json`` mapping ``simple_rawdata_access/1`` to a
        data directory under ``settings.PREFIX`` is written for the installer
        and removed afterwards, even if the installation raises.
        """
        install.create_sites()

        system_user, plot_user, user = install.create_users("user", "user")
        install.add_group("Default")

        setup_backend(qsetup.DEFAULT_CONFIGURATION)

        Worker.objects.update(active=True)
        env = Environment.objects.get(name="Python for tests")
        queue = Queue.objects.first()

        template_data = dict(
            system_user=system_user,
            plot_user=plot_user,
            user=user,
            private=False,
            queue=queue.name,
            environment=dict(name=env.name, version=env.version),
        )

        raw_access_db_name = "simple_rawdata_access/1"
        source_prefix = os.path.join(settings.BASE_DIR, "src", "beat.examples")

        db_root_file_path = os.path.join(settings.PREFIX, "db_root.json")
        db_path = os.path.join(
            settings.PREFIX, "data", raw_access_db_name.replace("/", "_")
        )
        db_root_data = {raw_access_db_name: db_path}

        os.makedirs(db_path, exist_ok=True)

        # Put a one-character data file inside the raw-data-access database root
        with open(os.path.join(db_path, "datafile.txt"), "wt") as datafile:
            datafile.write("1")

        with open(db_root_file_path, "wt") as db_root_file:
            db_root_file.write(json.dumps(db_root_data))

        try:
            for contribution in ["system", "test"]:
                install.install_contributions(
                    source_prefix, contribution, template_data, db_root_file_path
                )
        finally:
            # The mapping file is only handed to the installer; never leave it
            # behind, even when an installation step fails.
            os.remove(db_root_file_path)

        os.makedirs(settings.CACHE_ROOT, exist_ok=True)

    def clean_cache(self):
        """Empty ``settings.CACHE_ROOT`` while preserving hidden entries.

        Every file and directory below the cache root is removed, except
        entries whose name starts with ``.`` and everything located beneath
        such a hidden directory. The cache root itself is kept.
        """
        cache_root = settings.CACHE_ROOT

        def _inside_hidden_dir(path):
            # True when any path component below the cache root starts with "."
            relative = os.path.relpath(path, cache_root)
            if relative == os.curdir:
                return False
            return any(part.startswith(".") for part in relative.split(os.sep))

        # Walk bottom-up so each directory is already emptied when it is
        # removed. In bottom-up mode os.walk ignores edits to ``dirs``, so
        # hidden subtrees must be skipped explicitly rather than pruned.
        for path, dirs, files in os.walk(cache_root, topdown=False):
            if _inside_hidden_dir(path):
                continue

            for name in files:
                if not name.startswith("."):
                    os.remove(os.path.join(path, name))

            for name in dirs:
                if not name.startswith("."):
                    os.rmdir(os.path.join(path, name))

    def set_experiment_state(
        self, experiment, experiment_status=None, block_status=None, cache_status=None
    ):
        """Force an experiment, some of its blocks and their outputs into a state.

        Updates are applied in this order: block statuses, cached-file
        statuses, then the experiment status. Falsy arguments are skipped.

        Args:
            experiment: experiment whose ``blocks`` are looked up by name
            experiment_status: value stored in ``experiment.status``
            block_status: mapping of block name to the status stored on that
                block
            cache_status: mapping of block name to the status stored on every
                object in that block's ``outputs``
        """
        if block_status:
            for name, status in block_status.items():
                block = experiment.blocks.get(name=name)
                block.status = status
                block.save()

        if cache_status:
            for name, status in cache_status.items():
                block = experiment.blocks.get(name=name)
                for cached_file in block.outputs.all():
                    cached_file.status = status
                    cached_file.save()

        if experiment_status:
            experiment.status = experiment_status
            experiment.save()

    def generate_cached_files(self, hash, splits):
        """Write cache files of data format ``system/integer/1`` for ``hash``.

        Args:
            hash: cache hash, converted to a location under
                ``settings.CACHE_ROOT`` with ``beat.core.hash.toPath``
            splits: one entry per cache split; each entry is a sequence whose
                items are either a single data index or a ``(start, end)``
                tuple of data indices

        The written ``value`` fields are consecutive integers starting at 0,
        continuing across splits.
        """
        dataformat = DataFormat(settings.PREFIX, "system/integer/1")

        path = os.path.join(settings.CACHE_ROOT, beat.core.hash.toPath(hash))
        # exist_ok: different cache entries may share parent directories
        os.makedirs(os.path.dirname(path), exist_ok=True)

        value = 0

        for split in splits:
            start_data_index = split[0][0] if isinstance(split[0], tuple) else split[0]
            end_data_index = split[-1][1] if isinstance(split[-1], tuple) else split[-1]

            sink = CachedDataSink()
            sink.setup(path, dataformat, start_data_index, end_data_index)

            try:
                for indices in split:
                    if isinstance(indices, tuple):
                        start = indices[0]
                        end = indices[1]
                    else:
                        start = indices
                        end = indices

                    sink.write(
                        {"value": value}, start_data_index=start, end_data_index=end
                    )
                    value += 1
            finally:
                # Release the sink even if a write fails
                sink.close()

    def prepare_databases(self, configuration):
        """Index every dataset referenced by an experiment configuration.

        For each entry of ``configuration["datasets"]`` the view of
        (``database``, ``protocol``, ``set``) is indexed into
        ``settings.CACHE_ROOT``, unless that index file already exists.
        """
        for cfg in configuration["datasets"].values():
            index_path = os.path.join(
                settings.CACHE_ROOT,
                beat.core.hash.toPath(
                    beat.core.hash.hashDataset(
                        cfg["database"], cfg["protocol"], cfg["set"]
                    ),
                    suffix=".db",
                ),
            )

            if os.path.exists(index_path):
                continue

            database = Database(settings.PREFIX, cfg["database"])
            view = database.view(cfg["protocol"], cfg["set"])
            view.index(index_path)


# ----------------------------------------------------------


class BaseBackendTestCase(TestCase, BackendUtilitiesMixin):
    """Django test case whose fixture is built by ``setup_test_data``.

    ``clean_cache`` runs before and after every test.
    """

    @classmethod
    def setUpTestData(cls):
        cls.setup_test_data()

    def setUp(self):
        self.clean_cache()

    def tearDown(self):
        self.clean_cache()

    def _check_unscheduled_block(
        self, block, name, algorithm_name, dependencies, dependents
    ):
        """Assert ``block`` is a pending, not yet scheduled block.

        Args:
            block: the block under test
            name: expected block name
            algorithm_name: name of the expected ``Algorithm``
            dependencies: expected number of blocks this one depends on
            dependents: expected number of blocks depending on this one
        """
        self.assertEqual(block.name, name)
        self.assertEqual(block.status, Block.PENDING)
        self.assertEqual(block.algorithm, Algorithm.objects.get(name=algorithm_name))
        self.assertEqual(block.dependencies.count(), dependencies)
        self.assertEqual(block.dependents.count(), dependents)
        self.assertEqual(block.queue.name, "queue")
        self.assertEqual(block.environment.name, "Python for tests")
        self.assertEqual(block.required_slots, 1)
        self.assertEqual(block.inputs.count(), 1)
        self.assertEqual(block.outputs.count(), 1)
        self.assertEqual(block.job.splits.count(), 0)  # not scheduled yet
        self.assertFalse(block.done())

    def check_single(self, xp):
        """Checks user/user/single/1/single"""

        self.assertEqual(xp.blocks.count(), 2)

        # Evaluate the queryset once so both blocks come from the same query.
        # NOTE(review): relies on the default ordering of ``xp.blocks``
        # putting "echo" before "analysis" — confirm the model's Meta.ordering.
        blocks = list(xp.blocks.all())

        self._check_unscheduled_block(
            blocks[0], "echo", "integers_echo", dependencies=0, dependents=1
        )
        self._check_unscheduled_block(
            blocks[1],
            "analysis",
            "integers_echo_analyzer",
            dependencies=1,
            dependents=0,
        )