# vim: set fileencoding=utf-8 :
###############################################################################
#                                                                             #
# Copyright (c) 2019 Idiap Research Institute, http://www.idiap.ch/           #
# Contact: beat.support@idiap.ch                                              #
#                                                                             #
# This file is part of the beat.editor module of the BEAT platform.           #
#                                                                             #
# Commercial License Usage                                                    #
# Licensees holding valid commercial BEAT licenses may use this file in       #
# accordance with the terms contained in a written agreement between you      #
# and Idiap. For further information contact tto@idiap.ch                     #
#                                                                             #
# Alternatively, this file may be used under the terms of the GNU Affero      #
# Public License version 3 as published by the Free Software and appearing    #
# in the file LICENSE.AGPL included in the packaging of this file.            #
# The BEAT platform is distributed in the hope that it will be useful, but    #
# WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY  #
# or FITNESS FOR A PARTICULAR PURPOSE.                                        #
#                                                                             #
# You should have received a copy of the GNU Affero Public License along      #
# with the BEAT platform. If not, see http://www.gnu.org/licenses/.           #
#                                                                             #
###############################################################################

import simplejson as json

from enum import Enum

from PyQt5.QtCore import Qt
from PyQt5.QtCore import QRect
from PyQt5.QtCore import QRectF
from PyQt5.QtCore import QPointF
from PyQt5.QtCore import pyqtSignal

from PyQt5.QtGui import QColor
from PyQt5.QtGui import QBrush
from PyQt5.QtGui import QPen
from PyQt5.QtGui import QFont
from PyQt5.QtGui import QFontMetrics
from PyQt5.QtGui import QPainterPath
from PyQt5.QtGui import QTransform

from PyQt5.QtWidgets import QVBoxLayout
from PyQt5.QtWidgets import QWidget
from PyQt5.QtWidgets import QGraphicsView
from PyQt5.QtWidgets import QGraphicsItem
from PyQt5.QtWidgets import QGraphicsPathItem
from PyQt5.QtWidgets import QGraphicsObject

from ..backend.asset import AssetType
from ..backend.assetmodel import AssetModel

from ..decorators import frozen

from .editor import AbstractAssetEditor
from .drawing_space import DrawingSpace


class BasePin(QGraphicsObject):
    """Base class for pin graphics

    A pin is a child item of a :class:`Block`. Dragging from a pin draws a
    temporary :class:`Connection`; releasing over a compatible pin creates a
    permanent one and emits ``dataChanged``.
    """

    dataChanged = pyqtSignal()

    def __init__(self, parent, pin, block, pin_brush, pin_pen):
        """
        :param parent: the Block graphics item owning this pin
        :param pin: name of the pin
        :param block: name of the owning block
        :param pin_brush: QBrush used to fill the pin
        :param pin_pen: QPen used to outline the pin
        """

        super().__init__(parent=parent)

        # Highlight
        self.setAcceptHoverEvents(True)

        # Storage
        self.pin_type = None
        self.pin = pin
        self.block = block
        self.brush = pin_brush
        self.pen = pin_pen
        self.block_object = parent

        # Temporary connection drawn while the user drags from this pin
        self.new_connection = None

    def shape(self):
        """Define the hit area of a Pin object (its bounding rectangle)"""

        path = QPainterPath()
        path.addRect(self.boundingRect())
        return path

    def paint(self, painter, option, widget):
        """Paint the Pin as an ellipse filling its bounding rect"""

        painter.setBrush(self.brush)
        painter.setPen(self.pen)

        painter.drawEllipse(self.boundingRect())

    def mousePressEvent(self, event):
        """Painting connection initiated"""

        self.new_connection = Connection(self.block_object.connection_style)
        self.block_object.scene().addItem(self.new_connection)

    def mouseMoveEvent(self, event):
        """Painting connection in progress"""

        # Only one single connection allowed from input pin
        if isinstance(self, InputPin):
            connections = self.block_object.toolchain.connections
            # Iterate over a copy: removing from the list being iterated
            # would silently skip the element following each removal.
            for connection in list(connections):
                if (
                    connection.end_block == self.block_object
                    and connection.end_pin == self
                ):
                    connections.remove(connection)
                    self.block_object.scene().removeItem(connection)
                    self.dataChanged.emit()

        mouse_position = self.mapToScene(event.pos())
        self.new_connection.set_new_connection_pins_coordinates(self, mouse_position)

    def mouseReleaseEvent(self, event):
        """Painting connection ended - validation required"""

        # The temporary connection is always discarded; a proper one is
        # created below if the release happened over a valid pin.
        self.block_object.scene().removeItem(self.new_connection)
        target = self.block_object.scene().itemAt(
            event.scenePos().toPoint(), QTransform()
        )

        if not isinstance(target, BasePin):
            return

        if isinstance(self, OutputPin):
            start = self
            end = target
        else:
            start = target
            end = self

        connection = Connection(self.block_object.connection_style)
        if not connection.check_validity(start, end):
            return

        # Find the corresponding channel
        connection_settings = {}
        if start.block_object.type == BlockType.DATASETS.name:
            connection_settings["channel"] = start.block_object.name
        else:
            connection_settings["channel"] = start.block_object.synchronized_channel

        # Create the connection
        connection_settings["from"] = start.block + "." + start.pin
        connection_settings["to"] = end.block + "." + end.pin

        toolchain = self.block_object.toolchain
        channel_colors = toolchain.json_object["representation"]["channel_colors"]

        connection.load(toolchain, connection_settings, channel_colors)

        toolchain.connections.append(connection)
        toolchain.scene.addItem(connection)

        # Notify only once the model is up to date so that listeners dumping
        # the toolchain on this signal see the new connection.
        self.dataChanged.emit()

    def get_center_point(self):
        """Get the center coordinates of the Pin(x,y) in scene coordinates"""

        rect = self.boundingRect()
        pin_center_point = QPointF(
            rect.x() + rect.width() / 2.0, rect.y() + rect.height() / 2.0
        )

        return self.mapToScene(pin_center_point)


class InputPin(BasePin):
    """Pin drawn on the left edge of a block"""

    def __init__(self, parent, pin, block, pin_brush, pin_pen):
        super().__init__(parent, pin, block, pin_brush, pin_pen)

    def boundingRect(self):
        """Bounding rect around pin object"""

        height = self.block_object.height / 2.0
        width = height

        # Centered horizontally on the left border of the block
        x = -(width / 2.0)
        y = (
            self.block_object.height
            + self.block_object.inputs.index(self.pin) * self.block_object.height
        )

        # QRectF is built directly: the values are floats and QRect only
        # accepts integers.
        return QRectF(x, y, width, height)


class OutputPin(BasePin):
    """Pin drawn on the right edge of a block"""

    def __init__(self, parent, pin, block, pin_brush, pin_pen):
        super().__init__(parent, pin, block, pin_brush, pin_pen)

    def boundingRect(self):
        """Bounding rect around pin object"""

        height = self.block_object.height / 2.0
        width = height

        # Centered horizontally on the right border of the block
        x = self.block_object.custom_width - (width / 2.0)
        y = (
            self.block_object.height
            + self.block_object.outputs.index(self.pin) * self.block_object.height
        )

        # QRectF is built directly: the values are floats and QRect only
        # accepts integers.
        return QRectF(x, y, width, height)


class Connection(QGraphicsPathItem):
    """Curve linking an output pin to an input pin"""

    def __init__(self, style):
        """
        :param style: mapping providing the ``color`` and ``width`` entries
        """

        super().__init__()

        self.start_block_name = None
        self.start_pin_name = None
        self.start_pin_center_point = None
        self.end_block_name = None
        self.end_pin_name = None
        self.end_pin_center_point = None
        self.channel = None
        self.connection_color = []

        self.set_style(style)

    def set_style(self, config):
        """Configure the pen from the ``color`` and ``width`` entries"""

        # Highlight
        self.setAcceptHoverEvents(True)

        # Geometry and color settings
        self.connection_color = config["color"]
        self.connection_pen = QPen()
        self.connection_pen.setColor(QColor(*self.connection_color))
        self.connection_pen.setWidth(config["width"])

    def drawCubicBezierCurve(self):
        """Rebuild the path between the stored start and end points"""

        self.setPen(self.connection_pen)

        delta_x = self.end_pin_center_point.x() - self.start_pin_center_point.x()
        delta_y = self.end_pin_center_point.y() - self.start_pin_center_point.y()

        control_point = QPointF(delta_x / 2.0, delta_y / 2.0)
        second_control_point = QPointF(delta_x / 2.0, delta_y / 4.0)

        path = QPainterPath()
        path.moveTo(self.start_pin_center_point)
        path.cubicTo(
            self.start_pin_center_point + control_point,
            self.end_pin_center_point - second_control_point,
            self.end_pin_center_point,
        )

        self.setPath(path)

    def set_moved_block_pins_coordinates(self):
        """Refresh both end points from the pins and redraw"""

        self.start_pin_center_point = self.start_pin.get_center_point()
        self.end_pin_center_point = self.end_pin.get_center_point()

        self.drawCubicBezierCurve()

    def set_new_connection_pins_coordinates(self, selected_pin, mouse_position):
        """Anchor one end on ``selected_pin`` and the other on the mouse"""

        if isinstance(selected_pin, OutputPin):
            self.start_block_name = selected_pin.block
            self.start_pin_name = selected_pin.pin
            self.start_pin = selected_pin
            self.start_pin_center_point = self.start_pin.get_center_point()
            self.end_pin_center_point = mouse_position

        if isinstance(selected_pin, InputPin):
            self.end_block_name = selected_pin.block
            self.end_pin_name = selected_pin.pin
            self.end_pin = selected_pin
            self.end_pin_center_point = self.end_pin.get_center_point()
            self.start_pin_center_point = mouse_position

        self.drawCubicBezierCurve()

    def check_validity(self, start, end):
        """Return True if a connection from ``start`` to ``end`` is allowed"""

        # remove input-input and output-output connection
        if type(start) is type(end):
            return False

        # remove connection to same block object
        if start.block_object == end.block_object:
            return False

        # check if end block pin is free
        toolchain = end.block_object.toolchain
        return not any(
            connection.end_block == end.block_object and connection.end_pin == end
            for connection in toolchain.connections
        )

    def load(self, toolchain, connection_details, channel_colors):
        """Bind the connection to the blocks and pins it describes

        :param toolchain: object exposing the ``blocks`` list to search
        :param connection_details: mapping with ``from`` and ``to``
            ("block.pin" strings) and ``channel`` entries
        :param channel_colors: mapping from channel name to a "#"-prefixed
            hexadecimal color string
        """

        start_parts = connection_details["from"].split(".")
        end_parts = connection_details["to"].split(".")

        self.start_block_name = start_parts[0]
        self.start_pin_name = start_parts[1]
        self.end_block_name = end_parts[0]
        self.end_pin_name = end_parts[1]
        self.channel = connection_details["channel"]

        # Split the hexadecimal string in three equal parts (R, G, B) and
        # append a fully opaque alpha value.
        hexadecimal = channel_colors[self.channel].lstrip("#")
        hlen = len(hexadecimal)
        self.connection_color = [
            int(hexadecimal[i : i + hlen // 3], 16) for i in range(0, hlen, hlen // 3)
        ]
        self.connection_color.append(255)

        self.connection_pen.setColor(QColor(*self.connection_color))

        self.blocks = toolchain.blocks

        for block in self.blocks:
            # Two independent tests: with an "elif", a connection whose two
            # ends name the same block would never get its end_block set.
            if block.name == self.start_block_name:
                self.start_block = block
            if block.name == self.end_block_name:
                self.end_block = block

        self.start_pin = self.start_block.pins["outputs"][self.start_pin_name]
        self.end_pin = self.end_block.pins["inputs"][self.end_pin_name]

        self.start_pin_center_point = self.start_pin.get_center_point()
        self.end_pin_center_point = self.end_pin.get_center_point()

        self.drawCubicBezierCurve()

        # Keep the curve attached to the pins when either block moves
        self.start_block.blockMoved.connect(self.set_moved_block_pins_coordinates)
        self.end_block.blockMoved.connect(self.set_moved_block_pins_coordinates)


class BlockType(Enum):
    """All possible block types"""

    BLOCKS = "blocks"
    ANALYZERS = "analyzers"
    DATASETS = "datasets"

    @classmethod
    def from_name(cls, name):
        """Return the member called ``name``

        :raises ValueError: if ``name`` is not the name of a member
        """

        try:
            return cls[name]
        except KeyError:
            # Lookup by name raises KeyError, not ValueError
            raise ValueError("{} is not a valid block type".format(name)) from None


class Block(QGraphicsObject):
    """Block item"""

    dataChanged = pyqtSignal()
    blockMoved = pyqtSignal()

    def __init__(self, toolchain, block_details, block_type, style, connection_style):
        """
        :param toolchain: the Toolchain widget owning this block
        :param block_details: mapping with ``name`` and, depending on the
            type, ``inputs``, ``outputs`` and ``synchronized_channel``
        :param block_type: name of a :class:`BlockType` member
        :param style: block style configuration mapping
        :param connection_style: style handed to connections created from
            this block's pins
        """

        super().__init__()

        # Block information
        self.toolchain = toolchain
        self.name = block_details["name"]

        # Datasets have no inputs, analyzers have no outputs
        if block_type == BlockType.DATASETS.name:
            self.inputs = None
        else:
            self.inputs = block_details["inputs"]

        if block_type == BlockType.ANALYZERS.name:
            self.outputs = None
        else:
            self.outputs = block_details["outputs"]

        self.synchronized_channel = block_details.get("synchronized_channel")

        self.type = block_type
        self.style = style
        self.connection_style = connection_style

        self.position = QPointF(0, 0)
        self.pins = dict()
        self.pins["inputs"] = dict()
        self.pins["outputs"] = dict()

        self.set_style(style)
        self.create_pins()

    def create_pins(self):
        """Create one pin item per input/output and relay their signal"""

        if self.inputs is not None:
            for pin_name in self.inputs:
                input_pin = InputPin(
                    self, pin_name, self.name, self.pin_brush, self.pin_pen
                )
                self.pins["inputs"][pin_name] = input_pin
                input_pin.dataChanged.connect(self.dataChanged)

        if self.outputs is not None:
            for pin_name in self.outputs:
                output_pin = OutputPin(
                    self, pin_name, self.name, self.pin_brush, self.pin_pen
                )
                self.pins["outputs"][pin_name] = output_pin
                output_pin.dataChanged.connect(self.dataChanged)

    def set_style(self, config):
        """Set up geometry, fonts, pens and brushes from ``config``"""

        self.setAcceptHoverEvents(True)
        self.setFlag(QGraphicsItem.ItemIsSelectable, True)
        self.setFlag(QGraphicsItem.ItemIsMovable)

        # Geometry settings
        self.width = config["width"]
        self.height = config["height"]
        self.border = config["border"]
        self.radius = config["radius"]
        self.pin_height = config["pin_height"]

        self.text_font = QFont(config["font"], config["font_size"], QFont.Bold)
        self.pin_font = QFont(config["pin_font"], config["pin_font_size"], QFont.Normal)

        metrics = QFontMetrics(self.text_font)
        text_width = metrics.boundingRect(self.name).width() + 14

        # default="" keeps a block declaring an empty pin list drawable
        if self.inputs is not None:
            longest_input = max(self.inputs, key=len, default="")
            self.max_inputs_width = metrics.boundingRect(longest_input).width() + 14
        else:
            self.max_inputs_width = 14

        if self.outputs is not None:
            longest_output = max(self.outputs, key=len, default="")
            self.max_outputs_width = metrics.boundingRect(longest_output).width() + 14
        else:
            self.max_outputs_width = 14

        self.custom_width = max(
            self.max_outputs_width + self.max_inputs_width, text_width
        )

        self.center = QPointF()
        self.center.setX(self.custom_width / 2.0)
        self.center.setY(self.height / 2.0)

        self.background_brush = QBrush()
        self.background_brush.setStyle(Qt.SolidPattern)

        self.background_color_datasets = QColor(*config["background_color_datasets"])
        self.background_color_analyzers = QColor(*config["background_color_analyzers"])
        self.background_color_blocks = QColor(*config["background_color_blocks"])

        self.background_brush.setColor(self.background_color_blocks)

        self.background_pen = QPen()
        self.background_pen.setStyle(Qt.SolidLine)
        self.background_pen.setWidth(0)
        self.background_pen.setColor(QColor(*config["background_color"]))

        self.border_pen = QPen()
        self.border_pen.setStyle(Qt.SolidLine)
        self.border_pen.setWidth(self.border)
        self.border_pen.setColor(QColor(*config["border_color"]))

        self.selection_border_pen = QPen()
        self.selection_border_pen.setStyle(Qt.SolidLine)
        self.selection_border_pen.setWidth(self.border)
        self.selection_border_pen.setColor(QColor(*config["selection_border_color"]))

        self.text_pen = QPen()
        self.text_pen.setStyle(Qt.SolidLine)
        self.text_pen.setColor(QColor(*config["text_color"]))

        self._pin_brush = QBrush()
        self._pin_brush.setStyle(Qt.SolidPattern)

        self.pin_pen = QPen()
        self.pin_pen.setStyle(Qt.SolidLine)

        self.pin_brush = QBrush()
        self.pin_brush.setStyle(Qt.SolidPattern)
        self.pin_brush.setColor(QColor(*config["pin_color"]))

    def _max_pin_count(self):
        """Return the number of pin rows needed (0 when there is no pin)"""

        nb_inputs = len(self.inputs) if self.inputs is not None else 0
        nb_outputs = len(self.outputs) if self.outputs is not None else 0
        return max(nb_inputs, nb_outputs)

    def boundingRect(self):
        """Bounding rect of the block object width by height"""

        metrics = QFontMetrics(self.text_font)
        text_height = metrics.boundingRect(self.name).height() + 14

        # The block name is drawn above the body, hence the negative top
        rect = QRect(
            0,
            -text_height,
            self.custom_width,
            text_height + self.height + self._max_pin_count() * self.pin_height,
        )
        return QRectF(rect)

    def draw_pins_name(self, painter, _type, data):
        """Paint pin with name

        :param painter: QPainter to draw with
        :param _type: "output" for right-aligned names, anything else is
            drawn left-aligned as inputs
        :param data: iterable of pin names
        """

        offset = 0
        for pin_name in data:
            # Pin rect
            painter.setBrush(self.background_brush)
            painter.setPen(self.background_pen)
            painter.setFont(self.pin_font)

            # Explicit int(): QRect only accepts integers and "/" always
            # yields a float.
            if _type == "output":
                coord_x = int(
                    self.custom_width - self.max_outputs_width - self.border / 2
                )
                max_width = self.max_outputs_width
                alignement = Qt.AlignRight
            else:
                coord_x = int(self.border / 2)
                max_width = self.max_inputs_width
                alignement = Qt.AlignLeft

            rect = QRect(
                coord_x, self.height - self.radius + offset, max_width, self.pin_height
            )

            textRect = QRect(
                rect.left() + self.radius,
                rect.top() + self.radius,
                rect.width() - 2 * self.radius,
                rect.height(),
            )

            painter.setPen(self.pin_pen)
            painter.drawText(textRect, alignement, pin_name)

            offset += self.pin_height

    def mouseMoveEvent(self, event):
        """Update connections due to new block position"""

        super().mouseMoveEvent(event)

        self.position = self.scenePos()

        self.blockMoved.emit()
        self.dataChanged.emit()

    def paint(self, painter, option, widget):
        """Paint the block"""

        # Design tools
        if self.type == BlockType.DATASETS.name:
            self.background_brush.setColor(self.background_color_datasets)
        elif self.type == BlockType.ANALYZERS.name:
            self.background_brush.setColor(self.background_color_analyzers)

        painter.setBrush(self.background_brush)

        if self.isSelected():
            self.selection_border_pen.setWidth(3)
            painter.setPen(self.selection_border_pen)
        else:
            self.border_pen.setWidth(0)
            painter.setPen(self.border_pen)

        painter.drawRoundedRect(
            0,
            0,
            self.custom_width,
            self.height + self._max_pin_count() * self.pin_height,
            self.radius,
            self.radius,
        )

        # Block name
        painter.setPen(self.text_pen)
        painter.setFont(self.text_font)

        metrics = QFontMetrics(painter.font())
        name_rect = metrics.boundingRect(self.name)
        text_width = name_rect.width() + 14
        text_height = name_rect.height() + 14
        margin = (text_width - self.custom_width) * 0.5

        # margin is a float, so a QRectF is required here
        text_rect = QRectF(-margin, -text_height, text_width, text_height)
        painter.drawText(text_rect, Qt.AlignCenter, self.name)

        # Pin
        if self.inputs is not None:
            self.draw_pins_name(painter, "input", self.inputs)
        if self.outputs is not None:
            self.draw_pins_name(painter, "output", self.outputs)


class ToolchainView(QGraphicsView):
    """Graphics view with wheel zoom and F key focus"""

    def __init__(self, toolchain):
        super().__init__()

        self.toolchain = toolchain

    def wheelEvent(self, event):
        """In/Out zoom view using the mouse wheel"""

        self.setTransformationAnchor(QGraphicsView.AnchorUnderMouse)

        factor = (event.angleDelta().y() / 120) * 0.1
        self.scale(1 + factor, 1 + factor)

    def keyPressEvent(self, event):
        """Focus on the toolchain when F key pressed"""

        if event.key() == Qt.Key_F:
            self.custom_focus()
        else:
            # Let the view process every other key as usual
            super().keyPressEvent(event)

    def custom_focus(self):
        """Custom focus on toolchain

        Fits the view on the selected items or, without selection, on the
        whole scene content.
        """

        selected_blocks = self.scene().selectedItems()
        if selected_blocks:
            x_list = [block.scenePos().x() for block in selected_blocks]
            y_list = [block.scenePos().y() for block in selected_blocks]
            width_list = [block.boundingRect().width() for block in selected_blocks]
            height_list = [block.boundingRect().height() for block in selected_blocks]

            # NOTE(review): the vertical offset is taken from the last
            # selected item only, as before - confirm this is intended.
            top_offset = selected_blocks[-1].boundingRect().y()

            min_x = min(x_list)
            min_y = min(y_list) + top_offset
            max_width = max(x_list) + max(width_list) - min_x
            max_height = max(y_list) + max(height_list) - min_y

            # Scene positions are floats, so a QRectF is built directly
            toolchain_focus = QRectF(min_x, min_y, max_width, max_height)
        else:
            toolchain_focus = self.scene().itemsBoundingRect()

        self.fitInView(toolchain_focus, Qt.KeepAspectRatio)


class Toolchain(QWidget):
    """Toolchain designer"""

    dataChanged = pyqtSignal()

    def __init__(self, parent=None):
        super().__init__(parent=parent)

        self.json_object = {}

        # Model state, defined here so that dump() works before any load()
        self.web_representation = None
        self.editor_gui = None
        self.blocks = []
        self.connections = []
        self.channels = []

        # NOTE(review): this path is relative to the current working
        # directory - confirm the application always starts from there.
        with open("beat/editor/widgets/space_nodes_config.json") as json_file:
            config_data = json.load(json_file)

        self.scene_config = config_data["drawing_space_config"]
        self.scene = DrawingSpace(self.scene_config)

        self.block_config = config_data["block_config"]
        self.connection_config = config_data["connection_config"]

        self.view = ToolchainView(self)
        self.view.setScene(self.scene)

        layout = QVBoxLayout(self)
        layout.addWidget(self.view)

    def clear_space(self):
        """Remove every item from the scene and reset the model lists"""

        self.scene.clear()
        self.blocks = []
        self.connections = []
        self.channels = []

    def load(self, json_object):
        """Parse the json in parameter and generates a graph"""

        self.json_object = json_object

        self.web_representation = self.json_object.get("representation")
        self.editor_gui = self.json_object.get("editor_gui")

        self.clear_space()

        # Get datasets, blocks, analyzers
        for block_type in BlockType:
            for block_item in self.json_object[block_type.value]:
                block = Block(
                    self,
                    block_item,
                    block_type.name,
                    self.block_config,
                    self.connection_config,
                )

                # Place blocks (x,y) if information is given
                if self.editor_gui is not None and block.name in self.editor_gui:
                    block.setPos(
                        self.editor_gui[block.name]["x"],
                        self.editor_gui[block.name]["y"],
                    )
                    block.position = block.scenePos()

                block.dataChanged.connect(self.dataChanged)
                self.blocks.append(block)
                self.scene.addItem(block)

        # Display connections
        connections = self.json_object["connections"]

        # "representation" is optional: only fail on a missing color when a
        # connection actually needs one.
        channel_colors = (self.web_representation or {}).get("channel_colors", {})

        for connection_item in connections:
            connection = Connection(self.connection_config)
            connection.load(self, connection_item, channel_colors)
            self.connections.append(connection)
            self.scene.addItem(connection)

    def dump(self):
        """Returns the json used to load the widget"""

        data = {}

        if self.web_representation is not None:
            data["representation"] = self.web_representation

        data["editor_gui"] = {}

        for block_type in BlockType:
            block_type_list = []
            for block in self.blocks:
                if block_type != BlockType.from_name(block.type):
                    continue

                block_data = {"name": block.name}

                if block.synchronized_channel is not None:
                    block_data["synchronized_channel"] = block.synchronized_channel
                if block.inputs is not None:
                    block_data["inputs"] = block.inputs
                if block.outputs is not None:
                    block_data["outputs"] = block.outputs

                block_type_list.append(block_data)

                data["editor_gui"][block.name] = {
                    "x": block.position.x(),
                    "y": block.position.y(),
                }

            data[block_type.value] = block_type_list

        connection_list = []
        for connection in self.connections:
            connection_data = {}
            connection_data["channel"] = connection.channel
            connection_data["from"] = (
                connection.start_block_name + "." + connection.start_pin_name
            )
            connection_data["to"] = (
                connection.end_block_name + "." + connection.end_pin_name
            )
            connection_list.append(connection_data)

        data["connections"] = connection_list

        return data


@frozen
class ToolchainEditor(AbstractAssetEditor):
    """Asset editor embedding the Toolchain designer widget"""

    def __init__(self, parent=None):
        super().__init__(AssetType.TOOLCHAIN, parent)
        self.setObjectName(self.__class__.__name__)
        self.set_title(self.tr("Toolchain"))

        self.toolchain_model = AssetModel()
        self.toolchain_model.asset_type = AssetType.TOOLCHAIN

        self.toolchain = Toolchain()
        self.layout().addWidget(self.toolchain, 2)
        self.layout().addStretch()

        self.toolchain.dataChanged.connect(self.dataChanged)

    def _load_json(self, json_object):
        """Load the json object passed as parameter"""

        self.toolchain.load(json_object)

    def _dump_json(self):
        """Returns the json representation of the asset"""

        return self.toolchain.dump()