toolchaineditor.py 29.2 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25
# vim: set fileencoding=utf-8 :
###############################################################################
#                                                                             #
# Copyright (c) 2019 Idiap Research Institute, http://www.idiap.ch/           #
# Contact: beat.support@idiap.ch                                              #
#                                                                             #
# This file is part of the beat.editor module of the BEAT platform.           #
#                                                                             #
# Commercial License Usage                                                    #
# Licensees holding valid commercial BEAT licenses may use this file in       #
# accordance with the terms contained in a written agreement between you      #
# and Idiap. For further information contact tto@idiap.ch                     #
#                                                                             #
# Alternatively, this file may be used under the terms of the GNU Affero      #
# Public License version 3 as published by the Free Software and appearing    #
# in the file LICENSE.AGPL included in the packaging of this file.            #
# The BEAT platform is distributed in the hope that it will be useful, but    #
# WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY  #
# or FITNESS FOR A PARTICULAR PURPOSE.                                        #
#                                                                             #
# You should have received a copy of the GNU Affero Public License along      #
# with the BEAT platform. If not, see http://www.gnu.org/licenses/.           #
#                                                                             #
###############################################################################

26 27 28
import simplejson as json

from enum import Enum
29

30 31 32 33 34 35 36 37 38 39 40 41
from PyQt5.QtCore import Qt
from PyQt5.QtCore import QRect
from PyQt5.QtCore import QRectF
from PyQt5.QtCore import QPointF
from PyQt5.QtCore import pyqtSignal

from PyQt5.QtGui import QColor
from PyQt5.QtGui import QBrush
from PyQt5.QtGui import QPen
from PyQt5.QtGui import QFont
from PyQt5.QtGui import QFontMetrics
from PyQt5.QtGui import QPainterPath
42
from PyQt5.QtGui import QTransform
43
from PyQt5.QtGui import QIcon
44

45
from PyQt5.QtWidgets import QHBoxLayout
46
from PyQt5.QtWidgets import QWidget
47 48 49 50
from PyQt5.QtWidgets import QGraphicsView
from PyQt5.QtWidgets import QGraphicsItem
from PyQt5.QtWidgets import QGraphicsPathItem
from PyQt5.QtWidgets import QGraphicsObject
51 52
from PyQt5.QtWidgets import QToolBar
from PyQt5.QtWidgets import QAction
53

54
from ..backend.asset import AssetType
55
from ..backend.assetmodel import AssetModel
56
from ..decorators import frozen
Samuel GAIST's avatar
Samuel GAIST committed
57

58
from .editor import AbstractAssetEditor
59 60 61
from .drawing_space import DrawingSpace


62
class BasePin(QGraphicsObject):
    """Base class for pin graphics

    A pin is drawn as a small circle on the side of its Block. Pressing and
    dragging from a pin draws a temporary Connection that follows the mouse;
    releasing it over a compatible pin creates a real connection in the
    toolchain.
    """

    dataChanged = pyqtSignal()

    def __init__(self, parent, pin, block, pin_brush, pin_pen):
        """
        Parameters:
            parent (Block): block graphics item owning this pin
            pin (str): name of the pin
            block (str): name of the owning block
            pin_brush (QBrush): brush used to fill the pin
            pin_pen (QPen): pen used to outline the pin
        """

        super().__init__(parent=parent)

        # Highlight
        self.setAcceptHoverEvents(True)

        # Storage
        self.pin_type = None
        self.pin = pin
        self.block = block
        self.brush = pin_brush
        self.pen = pin_pen
        self.block_object = parent

        # Temporary connection drawn while the user drags from this pin
        self.new_connection = None

    def shape(self):
        """Use the bounding rectangle of the pin as its interaction shape"""

        path = QPainterPath()
        path.addRect(self.boundingRect())
        return path

    def paint(self, painter, option, widget):
        """Paint the Pin as an ellipse inscribed in its bounding rect"""

        painter.setBrush(self.brush)
        painter.setPen(self.pen)

        painter.drawEllipse(self.boundingRect())

    def mousePressEvent(self, event):
        """Painting connection initiated"""

        self.new_connection = Connection(self.block_object.connection_style)
        self.block_object.scene().addItem(self.new_connection)

    def mouseMoveEvent(self, event):
        """Painting connection in progress"""

        # Only one single connection allowed from input pin
        if isinstance(self, InputPin):
            toolchain = self.block_object.toolchain
            # Check if a connection exists and remove it if it does. Iterate
            # over a copy because matching items are removed from the list.
            for connection in list(toolchain.connections):
                if (
                    connection.end_block == self.block_object
                    and connection.end_pin == self
                ):
                    toolchain.connections.remove(connection)
                    self.block_object.scene().removeItem(connection)
                    self.dataChanged.emit()

        mouse_position = self.mapToScene(event.pos())
        self.new_connection.set_new_connection_pins_coordinates(self, mouse_position)

    def mouseReleaseEvent(self, event):
        """Painting connection ended - validation required

        When the mouse is released over another pin and the link passes
        Connection.check_validity(), a new Connection is created from the
        output pin to the input pin, registered in the toolchain and added to
        the scene.
        """

        self.block_object.scene().removeItem(self.new_connection)
        self.new_connection = None

        target = self.block_object.scene().itemAt(
            event.scenePos().toPoint(), QTransform()
        )
        if not isinstance(target, BasePin):
            return

        # Connections always go from an output pin (start) to an input pin (end)
        if isinstance(self, OutputPin):
            start, end = self, target
        else:
            start, end = target, self

        if not Connection(self.block_object.connection_style).check_validity(
            start, end
        ):
            return

        # Find the corresponding channel: a dataset block's channel is its
        # own name, other blocks use their synchronized channel
        connection_settings = {}
        if start.block_object.type == BlockType.DATASETS.name:
            connection_settings["channel"] = start.block_object.name
        else:
            connection_settings["channel"] = start.block_object.synchronized_channel

        # Create the connection
        connection_settings["from"] = start.block + "." + start.pin
        connection_settings["to"] = end.block + "." + end.pin

        toolchain = self.block_object.toolchain
        channel_colors = toolchain.json_object["representation"]["channel_colors"]

        connection = Connection(self.block_object.connection_style)
        connection.load(toolchain, connection_settings, channel_colors)

        toolchain.connections.append(connection)
        toolchain.scene.addItem(connection)

        # Notify only once the connection is registered, so that listeners
        # dumping the toolchain see the new connection.
        self.dataChanged.emit()

    def get_center_point(self):
        """Get the center coordinates of the Pin(x,y), in scene coordinates"""

        rect = self.boundingRect()
        pin_center_point = QPointF(
            rect.x() + rect.width() / 2.0, rect.y() + rect.height() / 2.0
        )

        return self.mapToScene(pin_center_point)


class InputPin(BasePin):
    """Pin drawn on the left edge of a block, one row per block input."""

    def __init__(self, parent, pin, block, pin_brush, pin_pen):

        super().__init__(parent, pin, block, pin_brush, pin_pen)

    def boundingRect(self):
        """Bounding rect around pin object

        A square of half the block height, centred horizontally on the
        block's left edge (x = 0) and placed on the row matching the pin's
        index in the block inputs.
        """

        height = self.block_object.height / 2.0
        width = height

        x = -(width / 2.0)
        y = (
            self.block_object.height
            + self.block_object.inputs.index(self.pin) * self.block_object.height
        )

        # QRect only takes ints and recent Python/PyQt5 versions reject
        # floats instead of truncating them implicitly: truncate explicitly
        # to keep the same integer geometry.
        rect = QRectF(QRect(int(x), int(y), int(width), int(height)))
        return rect


class OutputPin(BasePin):
    """Pin drawn on the right edge of a block, one row per block output."""

    def __init__(self, parent, pin, block, pin_brush, pin_pen):

        super().__init__(parent, pin, block, pin_brush, pin_pen)

    def boundingRect(self):
        """Bounding rect around pin object

        A square of half the block height, centred horizontally on the
        block's right edge (x = custom_width) and placed on the row matching
        the pin's index in the block outputs.
        """

        height = self.block_object.height / 2.0
        width = height

        x = self.block_object.custom_width - (width / 2.0)
        y = (
            self.block_object.height
            + self.block_object.outputs.index(self.pin) * self.block_object.height
        )

        # QRect only takes ints and recent Python/PyQt5 versions reject
        # floats instead of truncating them implicitly: truncate explicitly
        # to keep the same integer geometry.
        rect = QRectF(QRect(int(x), int(y), int(width), int(height)))
        return rect


class Connection(QGraphicsPathItem):
    """Bezier curve linking an output pin to an input pin.

    A connection is either loaded from the toolchain JSON (see load()) or
    drawn temporarily while the user drags from a pin (see
    set_new_connection_pins_coordinates()).
    """

    def __init__(self, style):
        """
        Parameters:
            style (dict): connection style configuration, set_style() reads
                its "color" and "width" entries
        """

        super().__init__()

        self.start_block_name = None
        self.start_pin_name = None
        self.start_pin_center_point = None
        self.end_block_name = None
        self.end_pin_name = None
        self.end_pin_center_point = None
        self.channel = None

        # Graphical end points: filled by load() or, for the connection being
        # drawn, by set_new_connection_pins_coordinates()
        self.start_block = None
        self.start_pin = None
        self.end_block = None
        self.end_pin = None

        self.connection_color = []

        self.set_style(style)

    def set_style(self, config):
        """Set up the pen from the "color" and "width" style entries"""

        # Highlight
        self.setAcceptHoverEvents(True)

        # Geometry and color settings
        self.connection_color = config["color"]

        self.connection_pen = QPen()
        self.connection_pen.setColor(QColor(*self.connection_color))
        self.connection_pen.setWidth(config["width"])

    def drawCubicBezierCurve(self):
        """Rebuild the path as a cubic Bezier curve between both end points"""

        self.setPen(self.connection_pen)

        path = QPainterPath()
        middle_point_x = (
            self.end_pin_center_point.x() - self.start_pin_center_point.x()
        ) / 2.0
        middle_point_y = (
            self.end_pin_center_point.y() - self.start_pin_center_point.y()
        ) / 2.0
        second_middle_point_y = (
            self.end_pin_center_point.y() - self.start_pin_center_point.y()
        ) / 4.0
        control_point = QPointF(middle_point_x, middle_point_y)
        second_control_point = QPointF(middle_point_x, second_middle_point_y)
        path.moveTo(self.start_pin_center_point)
        path.cubicTo(
            self.start_pin_center_point + control_point,
            self.end_pin_center_point - second_control_point,
            self.end_pin_center_point,
        )

        self.setPath(path)

    def set_moved_block_pins_coordinates(self):
        """Redraw the curve after one of the connected blocks moved"""
        self.start_pin_center_point = self.start_pin.get_center_point()
        self.end_pin_center_point = self.end_pin.get_center_point()

        self.drawCubicBezierCurve()

    def set_new_connection_pins_coordinates(self, selected_pin, mouse_position):
        """Redraw a connection being dragged from *selected_pin*.

        The pin end is attached to *selected_pin* and the other end follows
        *mouse_position* (scene coordinates).
        """

        if isinstance(selected_pin, OutputPin):
            self.start_block_name = selected_pin.block
            self.start_pin_name = selected_pin.pin
            self.start_pin = selected_pin
            self.start_pin_center_point = self.start_pin.get_center_point()
            self.end_pin_center_point = mouse_position
        if isinstance(selected_pin, InputPin):
            self.end_block_name = selected_pin.block
            self.end_pin_name = selected_pin.pin
            self.end_pin = selected_pin
            self.end_pin_center_point = self.end_pin.get_center_point()
            self.start_pin_center_point = mouse_position

        self.drawCubicBezierCurve()

    def check_validity(self, start, end):
        """Tell whether a connection from *start* to *end* may be created.

        Returns:
            bool: False for input-input / output-output links, for links
            within a single block, or when *end* already has an incoming
            connection in the toolchain; True otherwise
        """

        # remove input-input and output-output connection
        if type(start) is type(end):
            return False

        # remove connection to same block object
        if start.block_object == end.block_object:
            return False

        # check if end block pin is free
        toolchain = end.block_object.toolchain
        return not any(
            connection.end_block == end.block_object and connection.end_pin == end
            for connection in toolchain.connections
        )

    def _find_block(self, name):
        """Return the block of self.blocks called *name*.

        Raises:
            ValueError: if no such block exists
        """
        for block in self.blocks:
            if block.name == name:
                return block
        raise ValueError("Connection refers to unknown block '{}'".format(name))

    def load(self, toolchain, connection_details, channel_colors):
        """Set up this connection from its JSON description.

        Parameters:
            toolchain (Toolchain): toolchain whose ``blocks`` hold both ends
            connection_details (dict): with "from" and "to" entries written
                as "block.pin", and a "channel" entry
            channel_colors (dict): channel name -> "#..." hex colour string

        Raises:
            ValueError: if either end block is not part of the toolchain
        """

        from_parts = connection_details["from"].split(".")
        to_parts = connection_details["to"].split(".")
        self.start_block_name = from_parts[0]
        self.start_pin_name = from_parts[1]
        self.end_block_name = to_parts[0]
        self.end_pin_name = to_parts[1]
        self.channel = connection_details["channel"]

        # Split the hex colour into three equal-width R, G, B fields and add
        # an opaque alpha component.
        hexadecimal = channel_colors[self.channel].lstrip("#")
        hlen = len(hexadecimal)
        self.connection_color = list(
            int(hexadecimal[i : i + hlen // 3], 16) for i in range(0, hlen, hlen // 3)
        )
        self.connection_color.append(255)
        self.connection_pen.setColor(QColor(*self.connection_color))

        self.blocks = toolchain.blocks

        self.start_block = self._find_block(self.start_block_name)
        self.end_block = self._find_block(self.end_block_name)

        self.start_pin = self.start_block.pins["outputs"][self.start_pin_name]
        self.end_pin = self.end_block.pins["inputs"][self.end_pin_name]

        self.start_pin_center_point = self.start_pin.get_center_point()
        self.end_pin_center_point = self.end_pin.get_center_point()

        self.drawCubicBezierCurve()

        # Keep the curve attached to the pins when either block is dragged
        self.start_block.blockMoved.connect(self.set_moved_block_pins_coordinates)
        self.end_block.blockMoved.connect(self.set_moved_block_pins_coordinates)

353 354 355 356 357 358 359 360

class BlockType(Enum):
    """All possible block types"""

    BLOCKS = "blocks"
    ANALYZERS = "analyzers"
    DATASETS = "datasets"

    @classmethod
    def from_name(cls, name):
        """Return the member whose name is *name* (e.g. "BLOCKS").

        Raises:
            ValueError: if *name* is not the name of a member
        """
        try:
            return cls[name]
        except KeyError:
            # Enum name lookup raises KeyError, translate it to the
            # ValueError this method advertises.
            raise ValueError("{} is not a valid block type".format(name)) from None

368 369 370 371

class Block(QGraphicsObject):
    """Block item

    Graphical representation of a toolchain block (dataset, processing block
    or analyzer): a rounded rectangle with the block name drawn above it and
    one row per input/output pin.
    """

    dataChanged = pyqtSignal()
    blockMoved = pyqtSignal()

    def __init__(self, block_type, style, connection_style):
        """
        Parameters:
            block_type (str): name of a BlockType member (e.g. "DATASETS")
            style (dict): block style configuration, read by set_style()
            connection_style (dict): style of connections drawn from this
                block's pins
        """
        super().__init__()

        # Block information
        self.type = block_type
        self.name = ""
        # Set by load()
        self.toolchain = None

        # Datasets have no inputs and analyzers have no outputs: None means
        # "this side does not exist", as opposed to an empty list.
        self.inputs = None if self.type == BlockType.DATASETS.name else []
        self.outputs = None if self.type == BlockType.ANALYZERS.name else []

        self.synchronized_channel = None

        self.style = style
        self.connection_style = connection_style

        self.position = QPointF(0, 0)
        self.pins = {"inputs": {}, "outputs": {}}

        self.set_style(style)

    def load(self, toolchain, block_details):
        """Fill the block from its JSON description and create its pins"""
        self.toolchain = toolchain
        self.name = block_details["name"]

        if self.type != BlockType.DATASETS.name:
            self.inputs = block_details["inputs"]

        if self.type != BlockType.ANALYZERS.name:
            self.outputs = block_details["outputs"]

        if "synchronized_channel" in block_details:
            self.synchronized_channel = block_details["synchronized_channel"]

        # Recompute the geometry now that the name and pins are known
        self.set_style(self.style)
        self.create_pins()

    def create_pins(self):
        """Create one pin item per input and output, forwarding their
        dataChanged signal"""

        if self.inputs is not None:
            for pin_name in self.inputs:
                input_pin = InputPin(
                    self, pin_name, self.name, self.pin_brush, self.pin_pen
                )
                self.pins["inputs"][pin_name] = input_pin

                input_pin.dataChanged.connect(self.dataChanged)

        if self.outputs is not None:
            for pin_name in self.outputs:
                output_pin = OutputPin(
                    self, pin_name, self.name, self.pin_brush, self.pin_pen
                )
                self.pins["outputs"][pin_name] = output_pin

                output_pin.dataChanged.connect(self.dataChanged)

    def set_style(self, config):
        """Read geometry, fonts, colours and pens from the style configuration

        Also computes the block width from its name and its longest pin
        names, which is why load() calls it again once those are known.
        """

        self.setAcceptHoverEvents(True)
        self.setFlag(QGraphicsItem.ItemIsSelectable, True)
        self.setFlag(QGraphicsItem.ItemIsMovable)

        # Geometry settings
        self.width = config["width"]
        self.height = config["height"]
        self.border = config["border"]
        self.radius = config["radius"]
        self.pin_height = config["pin_height"]

        self.text_font = QFont(config["font"], config["font_size"], QFont.Bold)
        self.pin_font = QFont(config["pin_font"], config["pin_font_size"], QFont.Normal)

        # NOTE(review): pin name widths are measured with the block name font
        # (text_font), not pin_font.
        metrics = QFontMetrics(self.text_font)
        text_width = metrics.boundingRect(self.name).width() + 14

        if self.inputs:
            self.max_inputs_width = (
                metrics.boundingRect(max(self.inputs, key=len)).width() + 14
            )
        else:
            self.max_inputs_width = 14

        if self.outputs:
            self.max_outputs_width = (
                metrics.boundingRect(max(self.outputs, key=len)).width() + 14
            )
        else:
            self.max_outputs_width = 14

        self.custom_width = max(
            self.max_outputs_width + self.max_inputs_width, text_width
        )

        self.center = QPointF()
        self.center.setX(self.custom_width / 2.0)
        self.center.setY(self.height / 2.0)

        self.background_brush = QBrush()
        self.background_brush.setStyle(Qt.SolidPattern)

        self.background_color_datasets = QColor(*config["background_color_datasets"])
        self.background_color_analyzers = QColor(*config["background_color_analyzers"])
        self.background_color_blocks = QColor(*config["background_color_blocks"])

        self.background_brush.setColor(self.background_color_blocks)

        self.background_pen = QPen()
        self.background_pen.setStyle(Qt.SolidLine)
        self.background_pen.setWidth(0)
        self.background_pen.setColor(QColor(*config["background_color"]))

        self.border_pen = QPen()
        self.border_pen.setStyle(Qt.SolidLine)
        self.border_pen.setWidth(self.border)
        self.border_pen.setColor(QColor(*config["border_color"]))

        self.selection_border_pen = QPen()
        self.selection_border_pen.setStyle(Qt.SolidLine)
        self.selection_border_pen.setWidth(self.border)
        self.selection_border_pen.setColor(QColor(*config["selection_border_color"]))

        self.text_pen = QPen()
        self.text_pen.setStyle(Qt.SolidLine)
        self.text_pen.setColor(QColor(*config["text_color"]))

        self._pin_brush = QBrush()
        self._pin_brush.setStyle(Qt.SolidPattern)

        self.pin_pen = QPen()
        self.pin_pen.setStyle(Qt.SolidLine)

        self.pin_brush = QBrush()
        self.pin_brush.setStyle(Qt.SolidPattern)
        self.pin_brush.setColor(QColor(*config["pin_color"]))

    def _max_pin_count(self):
        """Return the number of pin rows: the longer of inputs and outputs
        (sides that do not exist count as zero)"""
        counts = [len(pins) for pins in (self.inputs, self.outputs) if pins is not None]
        return max(counts, default=0)

    def boundingRect(self):
        """Bounding rect of the block object width by height

        Includes the name drawn above the block (negative y) and one row of
        pin_height per pin row.
        """
        metrics = QFontMetrics(self.text_font)
        text_height = metrics.boundingRect(self.name).height() + 14

        rect = QRect(
            0,
            -text_height,
            self.custom_width,
            text_height + self.height + self._max_pin_count() * self.pin_height,
        )
        return QRectF(rect)

    def draw_pins_name(self, painter, _type, data):
        """Paint pin with name

        Parameters:
            painter (QPainter): painter to draw with
            _type (str): "output" draws right-aligned names on the right
                side, anything else left-aligned names on the left side
            data (list): pin names, drawn one per row from top to bottom
        """
        if _type == "output":
            coord_x = self.custom_width - self.max_outputs_width - self.border / 2
            max_width = self.max_outputs_width
            alignement = Qt.AlignRight
        else:
            coord_x = self.border / 2
            max_width = self.max_inputs_width
            alignement = Qt.AlignLeft

        painter.setBrush(self.background_brush)
        painter.setFont(self.pin_font)
        painter.setPen(self.pin_pen)

        offset = 0
        for pin_name in data:
            # QRect only takes ints and recent Python/PyQt5 versions reject
            # floats (coord_x is one) instead of truncating them implicitly.
            rect = QRect(
                int(coord_x),
                int(self.height - self.radius + offset),
                int(max_width),
                int(self.pin_height),
            )

            textRect = QRect(
                int(rect.left() + self.radius),
                int(rect.top() + self.radius),
                int(rect.width() - 2 * self.radius),
                rect.height(),
            )

            painter.drawText(textRect, alignement, pin_name)

            offset += self.pin_height

    def mouseMoveEvent(self, event):
        """Update connections due to new block position"""

        super().mouseMoveEvent(event)

        self.position = self.scenePos()

        self.blockMoved.emit()
        self.dataChanged.emit()

    def paint(self, painter, option, widget):
        """Paint the block"""

        # Background colour depends on the block type
        if self.type == BlockType.DATASETS.name:
            self.background_brush.setColor(self.background_color_datasets)
        elif self.type == BlockType.ANALYZERS.name:
            self.background_brush.setColor(self.background_color_analyzers)

        painter.setBrush(self.background_brush)

        if self.isSelected():
            self.selection_border_pen.setWidth(3)
            painter.setPen(self.selection_border_pen)
        else:
            self.border_pen.setWidth(0)
            painter.setPen(self.border_pen)

        painter.drawRoundedRect(
            0,
            0,
            self.custom_width,
            self.height + self._max_pin_count() * self.pin_height,
            self.radius,
            self.radius,
        )

        # Block name, centred above the block (it may be wider than the block)
        painter.setPen(self.text_pen)
        painter.setFont(self.text_font)

        metrics = QFontMetrics(painter.font())
        name_rect = metrics.boundingRect(self.name)
        text_width = name_rect.width() + 14
        text_height = name_rect.height() + 14
        margin = (text_width - self.custom_width) * 0.5
        # margin is a float: QRect needs ints (explicit truncation)
        text_rect = QRect(int(-margin), -text_height, text_width, text_height)

        painter.drawText(text_rect, Qt.AlignCenter, self.name)

        # Pin
        if self.inputs is not None:
            self.draw_pins_name(painter, "input", self.inputs)
        if self.outputs is not None:
            self.draw_pins_name(painter, "output", self.outputs)


642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685
class ToolchainView(QGraphicsView):
    """Graphics view of the toolchain scene, with mouse wheel zoom and
    "F" key focus"""

    def __init__(self, toolchain):
        super().__init__()

        self.toolchain = toolchain

    def wheelEvent(self, event):
        """In/Out zoom view using the mouse wheel"""

        self.setTransformationAnchor(QGraphicsView.AnchorUnderMouse)
        # One standard wheel notch (120 units) zooms in/out by 10%
        factor = (event.angleDelta().y() / 120) * 0.1
        self.scale(1 + factor, 1 + factor)

    def keyPressEvent(self, event):
        """Focus on the toolchain when F key pressed"""
        if event.key() == Qt.Key_F:
            self.custom_focus()
        else:
            # Keep the default view handling (e.g. scrolling) for other keys
            super().keyPressEvent(event)

    def custom_focus(self):
        """Fit the view on the selected items, or on the whole scene when
        nothing is selected"""

        selected_items = self.scene().selectedItems()
        if selected_items:
            # Smallest rect (in scene coordinates) containing every selection
            toolchain_focus = selected_items[0].sceneBoundingRect()
            for item in selected_items[1:]:
                toolchain_focus = toolchain_focus.united(item.sceneBoundingRect())
        else:
            toolchain_focus = self.scene().itemsBoundingRect()

        self.fitInView(toolchain_focus, Qt.KeepAspectRatio)


686 687
class Toolchain(QWidget):
    """Toolchain designer

    Widget combining a vertical toolbar (to add blocks) and a graphics view
    on the drawing space holding the blocks and connections of a toolchain.
    """

    dataChanged = pyqtSignal()

    def __init__(self, parent=None):
        super().__init__(parent=parent)

        self.json_object = {}
        # Populated by load(); initialised here so that dump() and editing
        # callbacks also work on a toolchain that was never loaded.
        self.web_representation = None
        self.editor_gui = None
        self.blocks = []
        self.connections = []
        self.channels = []

        with open(
            self._resource_path("space_nodes_config.json"), encoding="utf-8"
        ) as json_file:
            config_data = json.load(json_file)

        self.scene_config = config_data["drawing_space_config"]
        self.scene = DrawingSpace(self.scene_config)

        self.block_config = config_data["block_config"]
        self.connection_config = config_data["connection_config"]

        self.view = ToolchainView(self)
        self.view.setScene(self.scene)

        self.toolbar = QToolBar()

        dataset_action = QAction(
            QIcon(self._resource_path("dataset_icon.png")), "&Dataset", self
        )
        dataset_action.triggered.connect(lambda: self.add_block(BlockType.DATASETS))
        block_action = QAction(
            QIcon(self._resource_path("block_icon.png")), "&Block", self
        )
        block_action.triggered.connect(lambda: self.add_block(BlockType.BLOCKS))
        analyzer_action = QAction(
            QIcon(self._resource_path("analyzer_icon.png")), "&Analyzer", self
        )
        analyzer_action.triggered.connect(lambda: self.add_block(BlockType.ANALYZERS))

        self.toolbar.addAction(dataset_action)
        self.toolbar.addAction(block_action)
        self.toolbar.addAction(analyzer_action)

        self.toolbar.setOrientation(Qt.Vertical)

        layout = QHBoxLayout(self)
        layout.addWidget(self.toolbar)
        layout.addWidget(self.view)

    @staticmethod
    def _resource_path(filename):
        """Return the absolute path of *filename*, located next to this module

        Resolving relative to the module (instead of "beat/editor/widgets/...")
        makes resources load whatever the current working directory is.
        """
        import os

        return os.path.join(os.path.dirname(os.path.abspath(__file__)), filename)

    def add_block(self, block_type):
        """Add a new, empty block of *block_type* (a BlockType) to the scene"""
        # NOTE(review): the new block is not loaded, named, nor added to
        # self.blocks, so it is not part of dump() yet.
        self.new_block = Block(
            block_type.name, self.block_config, self.connection_config
        )
        self.scene.addItem(self.new_block)

    def clear_space(self):
        """Remove every item from the scene and reset the model lists"""
        self.scene.clear()
        self.blocks = []
        self.connections = []
        self.channels = []

    def load(self, json_object):
        """Parse the json in parameter and generates a graph"""

        self.json_object = json_object

        self.web_representation = self.json_object.get("representation")
        self.editor_gui = self.json_object.get("editor_gui")

        self.clear_space()

        # Get datasets, blocks, analyzers
        for block_type in BlockType:
            for block_item in self.json_object[block_type.value]:
                block = Block(
                    block_type.name, self.block_config, self.connection_config
                )
                block.load(self, block_item)

                # Place blocks (x,y) if information is given
                position = (self.editor_gui or {}).get(block.name)
                if position is not None:
                    block.setPos(position["x"], position["y"])
                    block.position = block.scenePos()

                block.dataChanged.connect(self.dataChanged)
                self.blocks.append(block)
                self.scene.addItem(block)

        # Display connections
        connections = self.json_object["connections"]
        channel_colors = self.json_object["representation"]["channel_colors"]

        for connection_item in connections:
            connection = Connection(self.connection_config)
            connection.load(self, connection_item, channel_colors)
            self.connections.append(connection)
            self.scene.addItem(connection)

    @staticmethod
    def _dump_block(block):
        """Return the JSON description of a single block"""
        block_data = {"name": block.name}
        if block.synchronized_channel is not None:
            block_data["synchronized_channel"] = block.synchronized_channel
        if block.inputs is not None:
            block_data["inputs"] = block.inputs
        if block.outputs is not None:
            block_data["outputs"] = block.outputs
        return block_data

    def dump(self):
        """Returns the json used to load the widget"""

        data = {}

        if self.web_representation is not None:
            data["representation"] = self.web_representation

        # Position of every block, in block order
        data["editor_gui"] = {
            block.name: {"x": block.position.x(), "y": block.position.y()}
            for block in self.blocks
        }

        for block_type in BlockType:
            data[block_type.value] = [
                self._dump_block(block)
                for block in self.blocks
                if BlockType.from_name(block.type) == block_type
            ]

        data["connections"] = [
            {
                "channel": connection.channel,
                "from": connection.start_block_name + "." + connection.start_pin_name,
                "to": connection.end_block_name + "." + connection.end_pin_name,
            }
            for connection in self.connections
        ]

        return data
837 838


Samuel GAIST's avatar
Samuel GAIST committed
839
@frozen
class ToolchainEditor(AbstractAssetEditor):
    """Asset editor for toolchains, hosting a Toolchain designer widget."""

    def __init__(self, parent=None):
        super().__init__(AssetType.TOOLCHAIN, parent)

        self.setObjectName(self.__class__.__name__)
        self.set_title(self.tr("Toolchain"))

        model = AssetModel()
        self.toolchain_model = model
        model.asset_type = AssetType.TOOLCHAIN

        self.toolchain = Toolchain()

        main_layout = self.layout()
        main_layout.addWidget(self.toolchain, 2)
        main_layout.addStretch()

        # Any edit made in the designer marks the asset as modified
        self.toolchain.dataChanged.connect(self.dataChanged)

    def _load_json(self, json_object):
        """Load the json object passed as parameter into the designer"""
        self.toolchain.load(json_object)

    def _dump_json(self):
        """Returns the json representation of the asset, as produced by the
        designer"""
        return self.toolchain.dump()