toolchaineditor.py 7.77 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
# vim: set fileencoding=utf-8 :
###############################################################################
#                                                                             #
# Copyright (c) 2019 Idiap Research Institute, http://www.idiap.ch/           #
# Contact: beat.support@idiap.ch                                              #
#                                                                             #
# This file is part of the beat.editor module of the BEAT platform.           #
#                                                                             #
# Commercial License Usage                                                    #
# Licensees holding valid commercial BEAT licenses may use this file in       #
# accordance with the terms contained in a written agreement between you      #
# and Idiap. For further information contact tto@idiap.ch                     #
#                                                                             #
# Alternatively, this file may be used under the terms of the GNU Affero      #
# Public License version 3 as published by the Free Software and appearing    #
# in the file LICENSE.AGPL included in the packaging of this file.            #
# The BEAT platform is distributed in the hope that it will be useful, but    #
# WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY  #
# or FITNESS FOR A PARTICULAR PURPOSE.                                        #
#                                                                             #
# You should have received a copy of the GNU Affero Public License along      #
# with the BEAT platform. If not, see http://www.gnu.org/licenses/.           #
#                                                                             #
###############################################################################

26
27
28
29
30
31
32
33
34
35
36
37
38
39
from html import escape
from itertools import zip_longest

from graphviz import Digraph

from PyQt5.QtCore import Qt

from PyQt5.QtGui import QImage
from PyQt5.QtGui import QPixmap
from PyQt5.QtGui import QPalette

from PyQt5.QtWidgets import QLabel
from PyQt5.QtWidgets import QScrollArea
from PyQt5.QtWidgets import QVBoxLayout
from PyQt5.QtWidgets import QWidget

40
from ..backend.asset import AssetType
41
from ..decorators import frozen
Samuel GAIST's avatar
Samuel GAIST committed
42

43
44
45
from .editor import AbstractAssetEditor


46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
class SimpleToolchainPreview(QWidget):
    """Basic toolchain preview

    Renders the toolchain JSON given to :meth:`load` as a left-to-right
    Graphviz diagram and shows the resulting PNG in a scrollable label.
    Datasets, blocks (including loops) and analyzers are drawn as HTML-like
    table nodes in three separate clusters; connections are drawn as edges
    from the output port cell of one table to the input port cell of another.
    """

    DATASET_COLOR = "#fffd85"
    BLOCK_COLOR = "#d4d4cd"
    ANALYZER_COLOR = "#8cecf5"
    LOOP_SECTION_COLOR = "#def2a7"
    INPUT_COLOR = "#ffff00"
    OUTPUT_COLOR = "#00ffff"

    def __init__(self, parent=None):
        super().__init__(parent=parent)

        self.json_object = {}

        self.label = QLabel()
        self.label.setBackgroundRole(QPalette.Base)
        scrollarea = QScrollArea()
        scrollarea.setAlignment(Qt.AlignHCenter | Qt.AlignVCenter)
        scrollarea.setWidget(self.label)
        layout = QVBoxLayout(self)
        layout.addWidget(scrollarea)

    @staticmethod
    def _port_cell(direction, name, color):
        """Return a table cell for the port ``name``.

        The cell's ``port`` attribute is ``<direction>_<name>`` so that edges
        can attach to it. The name is HTML-escaped because Graphviz parses the
        label as HTML and rejects raw ``<``, ``&`` or quote characters.
        """
        escaped = escape(str(name))
        return f"<td bgcolor='{color}' port='{direction}_{escaped}'>{escaped}</td>"

    @classmethod
    def _build_rows(cls, inputs, outputs):
        """Return the ``<tr>`` rows listing the given input and output ports.

        When both inputs and outputs are present they are laid out side by
        side (inputs left, outputs right) and the shorter column is padded
        with empty borderless cells. Otherwise a single column is produced.
        ``None`` is treated like an empty collection.
        """
        inputs = inputs or []
        outputs = outputs or []
        empty_cell = "<td border='0'></td>"

        def input_cell(name):
            return cls._port_cell("in", name, cls.INPUT_COLOR)

        def output_cell(name):
            return cls._port_cell("out", name, cls.OUTPUT_COLOR)

        if inputs and outputs:
            # zip_longest pads the shorter side with None -> empty cell
            rows = [
                "<tr>{}{}</tr>".format(
                    input_cell(input_) if input_ else empty_cell,
                    output_cell(output) if output else empty_cell,
                )
                for input_, output in zip_longest(inputs, outputs)
            ]
        else:
            # At most one of the two lists is non-empty here
            rows = [f"<tr>{input_cell(name)}</tr>" for name in inputs]
            rows += [f"<tr>{output_cell(name)}</tr>" for name in outputs]
        return "".join(rows)

    @staticmethod
    def _build_header(title, columns):
        """Return a borderless title row spanning ``columns`` columns."""
        return f"<tr><td colspan='{columns}' border='0'>{escape(str(title))}</td></tr>"

    @classmethod
    def _build_block_table(cls, block, background):
        """Return ``(node_name, html_label)`` for a dataset/block/analyzer."""
        block_name = block["name"]
        inputs = block.get("inputs", {})
        outputs = block.get("outputs", {})
        columns = 2 if inputs and outputs else 1

        label = (
            f"<<table border='1' cellborder='1' bgcolor='{background}'>"
            + cls._build_header(block_name, columns)
            + cls._build_rows(inputs, outputs)
            + "</table>>"
        )
        return block_name, label

    @classmethod
    def _build_loop_table(cls, block):
        """Return ``(node_name, html_label)`` for a loop block.

        A loop is drawn as an outer table holding one nested table per
        section (``processor`` then ``evaluator``), each listing that
        section's ``<section>_inputs`` and ``<section>_outputs`` ports.
        """
        block_name = block["name"]
        section_table = (
            "<table border='0' cellborder='1' "
            f"bgcolor='{cls.LOOP_SECTION_COLOR}'>"
        )

        sections = []
        for prefix in ("processor", "evaluator"):
            inputs = block.get(f"{prefix}_inputs", {})
            outputs = block.get(f"{prefix}_outputs", {})
            columns = 2 if inputs and outputs else 1
            sections.append(
                "<tr><td>"
                + section_table
                + cls._build_header(prefix, columns)
                + cls._build_rows(inputs, outputs)
                + "</table></td></tr>"
            )

        # Every body row of the outer table is a single cell (a nested
        # table), so its title row spans a single column.
        label = (
            f"<<table border='1' cellborder='1' bgcolor='{cls.BLOCK_COLOR}'>"
            + cls._build_header(block_name, 1)
            + "".join(sections)
            + "</table>>"
        )
        return block_name, label

    def load(self, json_object):
        """Parse the JSON passed as parameter and render it as a graph.

        :param json_object: toolchain description. The keys read are
            ``datasets``, ``blocks``, ``loops``, ``analyzers``,
            ``connections`` and ``representation.channel_colors``. A falsy
            value renders an empty graph. The object is stored as-is and
            returned by :meth:`dump`.
        """
        self.json_object = json_object
        toolchain = json_object or {}

        graph = Digraph("toolchain", format="png")
        graph.attr(rankdir="LR")
        graph.node_attr["shape"] = "plaintext"

        with graph.subgraph(name="cluster_datasets") as dg:
            dg.attr(color="white")
            for block in toolchain.get("datasets", []):
                name, label = self._build_block_table(block, self.DATASET_COLOR)
                dg.node(name=name, label=label)

        with graph.subgraph(name="cluster_blocks") as bg:
            bg.attr(color="white")
            for block in toolchain.get("blocks", []):
                name, label = self._build_block_table(block, self.BLOCK_COLOR)
                bg.node(name=name, label=label)
            for block in toolchain.get("loops", []):
                name, label = self._build_loop_table(block)
                bg.node(name=name, label=label)

        with graph.subgraph(name="cluster_analyzers") as ag:
            ag.attr(color="white")
            for block in toolchain.get("analyzers", []):
                name, label = self._build_block_table(block, self.ANALYZER_COLOR)
                ag.node(name=name, label=label)

        representation = toolchain.get("representation") or {}
        channel_colors = representation.get("channel_colors", {})

        for connection in toolchain.get("connections", []):
            # Split on the last dot only so that a dot inside the block name
            # does not break the (block, endpoint) unpacking.
            from_block, output = connection["from"].rsplit(".", 1)
            to_block, input_ = connection["to"].rsplit(".", 1)
            # Connections without a (known) channel are drawn in black
            channel_color = channel_colors.get(connection.get("channel"), "black")

            # Ports ":e"/":w" attach edges to the east side of the output
            # cell and the west side of the input cell (rankdir is LR).
            graph.edge(
                f"{from_block}:out_{output}:e",
                f"{to_block}:in_{input_}:w",
                color=channel_color,
            )

        self.label.setPixmap(QPixmap.fromImage(QImage.fromData(graph.pipe())))
        self.label.adjustSize()

    def dump(self):
        """Returns the json used to load the widget"""

        return self.json_object


Samuel GAIST's avatar
Samuel GAIST committed
179
@frozen
class ToolchainEditor(AbstractAssetEditor):
    """Editor widget for toolchain assets.

    The toolchain is only displayed: the JSON handed to the editor is
    rendered by a :class:`SimpleToolchainPreview` and that same JSON is
    returned when the editor is dumped.
    """

    def __init__(self, parent=None):
        super().__init__(AssetType.TOOLCHAIN, parent)
        self.setObjectName(type(self).__name__)
        self.set_title(self.tr("Toolchain"))

        self.preview = SimpleToolchainPreview()

        # The preview is added with a stretch factor of 2, followed by a
        # stretchable spacer.
        main_layout = self.layout()
        main_layout.addWidget(self.preview, 2)
        main_layout.addStretch()

    def _load_json(self, json_object):
        """Show the given toolchain JSON in the preview widget."""
        self.preview.load(json_object)

    def _dump_json(self):
        """Return the toolchain JSON currently held by the preview widget."""
        return self.preview.dump()