resourcemodels.py 11.3 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
# vim: set fileencoding=utf-8 :
###############################################################################
#                                                                             #
# Copyright (c) 2019 Idiap Research Institute, http://www.idiap.ch/           #
# Contact: beat.support@idiap.ch                                              #
#                                                                             #
# This file is part of the beat.editor module of the BEAT platform.           #
#                                                                             #
# Commercial License Usage                                                    #
# Licensees holding valid commercial BEAT licenses may use this file in       #
# accordance with the terms contained in a written agreement between you      #
# and Idiap. For further information contact tto@idiap.ch                     #
#                                                                             #
# Alternatively, this file may be used under the terms of the GNU Affero      #
# Public License version 3 as published by the Free Software and appearing    #
# in the file LICENSE.AGPL included in the packaging of this file.            #
# The BEAT platform is distributed in the hope that it will be useful, but    #
# WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY  #
# or FITNESS FOR A PARTICULAR PURPOSE.                                        #
#                                                                             #
# You should have received a copy of the GNU Affero Public License along      #
# with the BEAT platform. If not, see http://www.gnu.org/licenses/.           #
#                                                                             #
###############################################################################

26
import logging
27

28
29
30
31
32
33
import simplejson as json

from PyQt5.QtSql import QSqlDatabase
from PyQt5.QtSql import QSqlQuery
from PyQt5.QtSql import QSqlTableModel

34
35
from beat.core.database import Database

36
37
38
39
40
41
42
43
from .asset import Asset
from .asset import AssetType
from .assetmodel import AssetModel

PARAMETER_TYPE_KEY = "parameter_type"
DEFAULT_VALUE_KEY = "default_value"
EDITED_KEY = "edited"

44
45
logger = logging.getLogger(__name__)

46
47

# ------------------------------------------------------------------------------
48

49
50
51
52
# Prefix modelization


class ExperimentResources:
    """Modelization of the experiments resources.

    Mirrors the algorithms, processing queues and datasets of the current
    context into the ``algorithms``, ``queues`` and ``datasets`` tables of an
    in-memory SQLite database opened on the default Qt SQL connection.
    """

    # SQL statements used by refresh() and its helpers.
    _DROP_TABLES = (
        "DROP TABLE IF EXISTS algorithms",
        "DROP TABLE IF EXISTS queues",
        "DROP TABLE IF EXISTS datasets",
    )
    _CREATE_TABLES = (
        "CREATE TABLE algorithms(name varchar, type varchar, inputs integer, outputs integer, is_analyzer boolean)",
        "CREATE TABLE queues(name varchar, env_name varchar, env_version varchar, env_type varchar)",
        "CREATE TABLE datasets(name varchar, outputs integer)",
    )
    _INSERT_ALGORITHM = "INSERT INTO algorithms(name, type, inputs, outputs, is_analyzer) VALUES(?, ?, ?, ?, ?)"
    _INSERT_QUEUE = "INSERT INTO queues(name, env_name, env_version, env_type) VALUES (?, ?, ?, ?)"
    _INSERT_DATASET = "INSERT INTO datasets(name, outputs) VALUES(?, ?)"

    def __init__(self, context=None):
        """Open the in-memory database and build the tables for ``context``.

        Args:
            context: object exposing a ``meta`` mapping with ``"config"`` and
                ``"environments"`` entries, or ``None`` for empty tables.

        Raises:
            RuntimeError: if the database cannot be opened or populated.
        """
        self.context = context

        database = QSqlDatabase.addDatabase("QSQLITE")
        database.setDatabaseName(":memory:")

        if not database.open():
            raise RuntimeError(
                f"Failed to open database: {database.lastError().text()}"
            )

        self.refresh()

    def setContext(self, context):
        """Switch to ``context``, rebuilding the tables only if it changed."""
        if self.context == context:
            return

        self.context = context
        self.refresh()

    def refresh(self):
        """Recreate the resource tables and fill them from the current context.

        The tables are always dropped and recreated; they stay empty when no
        context is set. Errors while reading the environments file propagate
        unchanged.

        Raises:
            RuntimeError: if any SQL statement fails.
        """
        query = QSqlQuery()
        self._reset_tables(query)

        if self.context is None:
            return

        prefix_path = self.context.meta["config"].path
        model = AssetModel()
        model.asset_type = AssetType.ALGORITHM
        model.prefix_path = prefix_path
        model.setLatestOnlyEnabled(False)

        self._load_algorithms(query, model, prefix_path)
        self._load_queues(query)

        # The same asset model is reused to list the databases of the prefix.
        model.asset_type = AssetType.DATABASE
        self._load_datasets(query, model, prefix_path)

    def _reset_tables(self, query):
        """Drop then recreate the three resource tables, leaving them empty."""
        for statement in self._DROP_TABLES:
            if not query.exec_(statement):
                raise RuntimeError(f"Failed to drop table: {query.lastError().text()}")

        for statement in self._CREATE_TABLES:
            if not query.exec_(statement):
                raise RuntimeError(
                    f"Failed to create table: {query.lastError().text()}"
                )

    @staticmethod
    def _prepare(query, statement):
        """Prepare ``statement`` on ``query`` for the positional inserts that follow."""
        if not query.prepare(statement):
            raise RuntimeError(f"Failed to prepare query: {query.lastError().text()}")

    @staticmethod
    def _insert(query, values, kind):
        """Bind ``values`` in order to the prepared ``query`` and execute it.

        ``kind`` only names the resource in the error message.
        """
        for value in values:
            query.addBindValue(value)

        if not query.exec_():
            raise RuntimeError(f"Failed to insert {kind}: {query.lastError().text()}")

    def _load_algorithms(self, query, model, prefix_path):
        """Insert one row per valid algorithm listed by ``model``."""
        self._prepare(query, self._INSERT_ALGORITHM)

        for algorithm in model.stringList():
            asset = Asset(prefix_path, AssetType.ALGORITHM, algorithm)
            is_valid, _ = asset.is_valid()
            if not is_valid:
                logger.debug("Skipping invalid algorithm %s", algorithm)
                continue

            declaration = asset.declaration

            # Inputs/outputs are declared per group; merging the dicts counts
            # each distinct name once across all groups.
            inputs = {}
            outputs = {}
            for group in declaration["groups"]:
                inputs.update(group.get("inputs", {}))
                outputs.update(group.get("outputs", {}))

            self._insert(
                query,
                (
                    algorithm,
                    declaration.get("type", "legacy"),
                    len(inputs),
                    len(outputs),
                    # Declaring "results" is what marks an algorithm as analyzer.
                    "results" in declaration,
                ),
                "algorithm",
            )

    def _load_queues(self, query):
        """Insert the remote queues and the local docker environments."""
        self._prepare(query, self._INSERT_QUEUE)

        environments_path = self.context.meta["environments"]
        # JSON is UTF-8; don't depend on the platform's locale encoding.
        with open(environments_path, "rt", encoding="utf-8") as file:
            environment_data = json.load(file)

        for item in environment_data.get("remote", []):
            env_name = item["name"]
            env_version = item["version"]
            for name in item["queues"]:
                self._insert(query, (name, env_name, env_version, "remote"), "queue")

        # Every docker image is registered under the queue name "Local".
        for image_info in environment_data.get("docker", {}).values():
            self._insert(
                query,
                ("Local", image_info["name"], image_info["version"], "docker"),
                "queue",
            )

    def _load_datasets(self, query, model, prefix_path):
        """Insert one ``database/protocol/set`` row per set of each valid database."""
        self._prepare(query, self._INSERT_DATASET)

        for database_name in model.stringList():
            database = Database(prefix_path, database_name)
            if not database.valid:
                logger.debug("Skipping invalid database: %s", database_name)
                continue

            for protocol_name in database.protocol_names:
                for set_name in database.set_names(protocol_name):
                    set_data = database.set(protocol_name, set_name)
                    self._insert(
                        query,
                        (
                            f"{database_name}/{protocol_name}/{set_name}",
                            len(set_data["outputs"]),
                        ),
                        "dataset",
                    )

198

199
200
201
# Module-wide instance created at import time. It opens the default QSQLITE
# connection and creates the resource tables, which stay empty until
# setContext() is given a context. The *ResourceModel classes below read these
# tables through that default connection.
experiment_resources = ExperimentResources()


202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
class AlgorithmResourceModel(QSqlTableModel):
    """Table model over the ``algorithms`` table with combinable filters.

    Rows can be restricted by analyzer flag, input count, output count and
    algorithm type. Every setter re-applies the combined filter.
    """

    def __init__(self, parent=None):
        super().__init__(parent=parent)

        self._analyzer_enabled = False
        self._input_count = None
        self._output_count = None
        self._types = []

        self.setTable("algorithms")
        self.select()
        self.update()

    @staticmethod
    def _quote(value):
        """Return ``value`` as an SQL string literal, doubling embedded quotes."""
        text = str(value).replace("'", "''")
        return f"'{text}'"

    def update(self):
        """Rebuild the SQL filter from the current settings and apply it."""
        # Emit 1/0 rather than the Python text True/False: those keywords are
        # only recognised by SQLite 3.23 and later.
        clauses = [f"is_analyzer={int(bool(self._analyzer_enabled))}"]
        if self._input_count is not None:
            clauses.append(f"inputs={int(self._input_count)}")
        if self._output_count is not None:
            clauses.append(f"outputs={int(self._output_count)}")

        if self._types:
            clauses.append(
                "type in ({})".format(",".join(self._quote(t) for t in self._types))
            )

        self.setFilter(" AND ".join(clauses))

    def setAnalyzerEnabled(self, enabled):
        """Show analyzers when ``enabled`` is true, regular algorithms otherwise."""
        if self._analyzer_enabled == enabled:
            return

        self._analyzer_enabled = enabled
        self.update()

    def setInputCount(self, count):
        """Only show algorithms with ``count`` inputs; ``None`` disables the check."""
        if self._input_count == count:
            return

        self._input_count = count
        self.update()

    def setOutputCount(self, count):
        """Only show algorithms with ``count`` outputs; ``None`` disables the check."""
        if self._output_count == count:
            return

        self._output_count = count
        self.update()

    def setTypes(self, type_list):
        """Only show algorithms whose type is in ``type_list`` (empty/None: any)."""
        # Keep a private copy: storing the caller's list would make a later
        # call with the same, since-mutated list compare equal and skip update.
        types = list(type_list or [])
        if self._types == types:
            return

        self._types = types
        self.update()
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307


class QueueResourceModel(QSqlTableModel):
    """Table model over the ``queues`` table, filterable by environment and type."""

    def __init__(self, parent=None):
        super().__init__(parent=parent)

        self._environment = None
        self._version = None
        self._type = None

        self.setTable("queues")
        self.select()
        self.update()

    @staticmethod
    def _quote(value):
        """Return ``value`` as an SQL string literal, doubling embedded quotes."""
        text = str(value).replace("'", "''")
        return f"'{text}'"

    def update(self):
        """Rebuild the SQL filter from the current settings and apply it.

        Unset (``None``) criteria are omitted. With no criteria the filter is
        empty and all rows are shown.
        """
        # Values are arbitrary text, so quote them instead of pasting them raw
        # into the filter, where a single quote would break the SQL.
        criteria = (
            ("env_name", self._environment),
            ("env_version", self._version),
            ("env_type", self._type),
        )
        clauses = [
            f"{column}={self._quote(value)}"
            for column, value in criteria
            if value is not None
        ]
        self.setFilter(" AND ".join(clauses))

    def setEnvironment(self, name, version):
        """Restrict to queues of environment ``name``/``version`` (``None``: any)."""
        if self._environment == name and self._version == version:
            return

        self._environment = name
        self._version = version
        self.update()

    def setType(self, type_):
        """Restrict to queues whose ``env_type`` is ``type_`` (``None``: any)."""
        if self._type == type_:
            return

        self._type = type_
        self.update()

    def dump(self):
        """Print the current filter and every loaded row (debugging aid)."""
        print(self.filter())
        for row in range(self.rowCount()):
            print(
                [self.index(row, column).data() for column in range(self.columnCount())]
            )


308
309
310
311
312
313
314
315
316
317
318
class DatasetResourceModel(QSqlTableModel):
    """Table model over the ``datasets`` table, optionally filtered by output count."""

    def __init__(self, parent=None):
        super().__init__(parent=parent)

        # Not read anywhere in this class.
        self._analyzer_enabled = False
        # None means "any number of outputs".
        self._output_count = None

        self.setTable("datasets")
        self.select()

    def update(self):
        """Apply the output-count restriction, or clear the filter when unset."""
        wanted = self._output_count
        self.setFilter("" if wanted is None else f"outputs={wanted}")

    def setOutputCount(self, count):
        """Only show datasets with ``count`` outputs; ``None`` shows all of them."""
        if self._output_count != count:
            self._output_count = count
            self.update()