resourcemodels.py 6.06 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170
# vim: set fileencoding=utf-8 :
###############################################################################
#                                                                             #
# Copyright (c) 2019 Idiap Research Institute, http://www.idiap.ch/           #
# Contact: beat.support@idiap.ch                                              #
#                                                                             #
# This file is part of the beat.editor module of the BEAT platform.           #
#                                                                             #
# Commercial License Usage                                                    #
# Licensees holding valid commercial BEAT licenses may use this file in       #
# accordance with the terms contained in a written agreement between you      #
# and Idiap. For further information contact tto@idiap.ch                     #
#                                                                             #
# Alternatively, this file may be used under the terms of the GNU Affero      #
# Public License version 3 as published by the Free Software and appearing    #
# in the file LICENSE.AGPL included in the packaging of this file.            #
# The BEAT platform is distributed in the hope that it will be useful, but    #
# WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY  #
# or FITNESS FOR A PARTICULAR PURPOSE.                                        #
#                                                                             #
# You should have received a copy of the GNU Affero Public License along      #
# with the BEAT platform. If not, see http://www.gnu.org/licenses/.           #
#                                                                             #
###############################################################################

import simplejson as json

from PyQt5.QtSql import QSqlDatabase
from PyQt5.QtSql import QSqlQuery
from PyQt5.QtSql import QSqlTableModel

from .asset import Asset
from .asset import AssetType
from .assetmodel import AssetModel


# String keys exported by this module. They are not referenced by the code
# visible here — NOTE(review): presumably dictionary keys consumed by other
# editor modules; confirm where they are used.
PARAMETER_TYPE_KEY: str = "parameter_type"
DEFAULT_VALUE_KEY: str = "default_value"
EDITED_KEY: str = "edited"


# ------------------------------------------------------------------------------
# Prefix modelization


class ExperimentResources:
    """Build an in-memory SQLite ``algorithms`` table from a prefix.

    The table is created on Qt's default SQL connection with the columns
    ``name``, ``type``, ``inputs``, ``outputs`` and ``is_analyzer``. Table
    models such as ``AlgorithmResourceModel`` then read it through that same
    connection.
    """

    def __init__(self, context=None):
        """Open (or reuse) the default in-memory SQLite connection.

        :param context: object whose ``meta["config"].path`` is the prefix
            path to scan. When ``None``, no table is built until
            :meth:`setContext` is given a context.
        :raises RuntimeError: if the database cannot be opened, or if
            building the table fails (see :meth:`refresh`).
        """
        self.context = context

        # addDatabase() removes any existing connection with the same name
        # (here the default one), which would throw away its in-memory data
        # and invalidate models still bound to it. Reuse it instead.
        if QSqlDatabase.contains():
            database = QSqlDatabase.database()
        else:
            database = QSqlDatabase.addDatabase("QSQLITE")
            database.setDatabaseName(":memory:")

        # QSqlDatabase.database() already tries to open the connection, so
        # only open explicitly when it is not open yet.
        if not database.isOpen() and not database.open():
            raise RuntimeError(
                f"Failed to open database: {database.lastError().text()}"
            )

        self.refresh()

    def setContext(self, context):
        """Switch to ``context`` and rebuild the table if it changed.

        Setting ``None`` leaves the previously built table in place, because
        :meth:`refresh` returns early without a context.
        """
        if self.context == context:
            return

        self.context = context
        self.refresh()

    def refresh(self):
        """Drop and recreate the ``algorithms`` table for the current prefix.

        One row is inserted per algorithm listed by an ``AssetModel`` (all
        versions, not only the latest). Each row holds the algorithm's name,
        its ``type`` (``"legacy"`` when absent), the number of distinct
        input and output names across its groups, and whether its
        declaration has a ``results`` entry. Algorithms whose declaration is
        not valid JSON, or which lack a ``groups`` entry, are skipped.

        Does nothing when no context is set.

        :raises RuntimeError: if any SQL statement fails.
        """
        if self.context is None:
            return

        ALGORITHM_TABLE_CLEANUP = "DROP TABLE IF EXISTS algorithms"
        ALGORITHM_TABLE = "CREATE TABLE algorithms(name varchar, type varchar, inputs integer, outputs integer, is_analyzer boolean)"
        INSERT_ALGORITHM = "INSERT INTO algorithms(name, type, inputs, outputs, is_analyzer) VALUES(?, ?, ?, ?, ?)"

        query = QSqlQuery()

        if not query.exec_(ALGORITHM_TABLE_CLEANUP):
            raise RuntimeError(f"Failed to drop table: {query.lastError().text()}")

        if not query.exec_(ALGORITHM_TABLE):
            raise RuntimeError(f"Failed to create table: {query.lastError().text()}")

        prefix_path = self.context.meta["config"].path
        model = AssetModel()
        model.asset_type = AssetType.ALGORITHM
        model.prefix_path = prefix_path
        # Include every version of each algorithm, not only the latest one.
        model.setLatestOnlyEnabled(False)

        if not query.prepare(INSERT_ALGORITHM):
            raise RuntimeError(f"Failed to prepare query: {query.lastError().text()}")

        for algorithm in model.stringList():
            asset = Asset(prefix_path, AssetType.ALGORITHM, algorithm)
            try:
                declaration = asset.declaration
            except json.JSONDecodeError:
                continue

            groups = declaration.get("groups")
            if groups is None:
                # Malformed declaration: skip it like an undecodable one
                # instead of aborting the whole refresh with a KeyError.
                continue

            # Count distinct endpoint names across all groups.
            input_names = set()
            output_names = set()
            for group in groups:
                input_names.update(group.get("inputs", {}))
                output_names.update(group.get("outputs", {}))

            # Positional binds must follow the column order of INSERT_ALGORITHM.
            query.addBindValue(algorithm)
            query.addBindValue(declaration.get("type", "legacy"))
            query.addBindValue(len(input_names))
            query.addBindValue(len(output_names))
            query.addBindValue("results" in declaration)

            if not query.exec_():
                raise RuntimeError(
                    f"Failed to insert algorithm: {query.lastError().text()}"
                )


class AlgorithmResourceModel(QSqlTableModel):
    """Table model over ``algorithms``, filtered by analyzer flag, I/O counts and type.

    The ``algorithms`` table is the one built by ``ExperimentResources`` on
    the default SQL connection. Each ``set*`` method updates one criterion
    and re-applies the SQL filter only when the value actually changes.
    """

    def __init__(self, parent=None):
        super().__init__(parent=parent)

        self._analyzer_enabled = False
        self._input_count = None  # None means "any number of inputs"
        self._output_count = None  # None means "any number of outputs"
        self._types = []  # empty means "any type"

        self.setTable("algorithms")
        # Install the filter before the first select(). setFilter() re-runs
        # the query on an already populated model, so filtering after
        # select() queried the table twice.
        self.update()
        self.select()

    def update(self):
        """Apply the current criteria as the model's SQL ``WHERE`` filter."""
        # Emit an integer literal rather than Python's True/False spelling:
        # SQLite only accepts TRUE/FALSE keywords since version 3.23.
        clauses = [f"is_analyzer={1 if self._analyzer_enabled else 0}"]
        if self._input_count is not None:
            clauses.append(f"inputs={int(self._input_count)}")
        if self._output_count is not None:
            clauses.append(f"outputs={int(self._output_count)}")

        if self._types:
            # Double embedded single quotes so a type name cannot break out
            # of its SQL string literal.
            quoted = ",".join(
                "'{}'".format(str(type_).replace("'", "''")) for type_ in self._types
            )
            clauses.append(f"type in ({quoted})")

        self.setFilter(" AND ".join(clauses))

    def setAnalyzerEnabled(self, enabled):
        """Show analyzer algorithms when ``enabled`` is true, others otherwise."""
        if self._analyzer_enabled == enabled:
            return

        self._analyzer_enabled = enabled
        self.update()

    def setInputCount(self, count):
        """Only show algorithms with exactly ``count`` inputs (``None``: any)."""
        if self._input_count == count:
            return

        self._input_count = count
        self.update()

    def setOutputCount(self, count):
        """Only show algorithms with exactly ``count`` outputs (``None``: any)."""
        if self._output_count == count:
            return

        self._output_count = count
        self.update()

    def setTypes(self, type_list):
        """Only show algorithms whose type is in ``type_list`` (empty/``None``: any)."""
        # Keep a private copy: storing the caller's list would let later
        # mutations of it desynchronize the criteria from the active filter
        # and defeat the equality short-circuit below.
        new_types = list(type_list) if type_list else []
        if self._types == new_types:
            return

        self._types = new_types
        self.update()