Commit 2be8c0cb authored by Samuel GAIST's avatar Samuel GAIST Committed by Flavio TARSETTI
Browse files

[widgets][toolchaineditor] Fix BlockType enum uses

parent 53ec2269
Pipeline #40521 passed with stage
in 11 minutes and 21 seconds
......@@ -28,6 +28,7 @@ from functools import partial
import simplejson as json
from beat.backend.python.algorithm import Algorithm
from PyQt5.QtCore import QFile
from PyQt5.QtCore import QPointF
from PyQt5.QtCore import QRect
......@@ -64,8 +65,6 @@ from PyQt5.QtWidgets import QToolBar
from PyQt5.QtWidgets import QVBoxLayout
from PyQt5.QtWidgets import QWidget
from beat.backend.python.algorithm import Algorithm
from ..backend.asset import Asset
from ..backend.asset import AssetType
from ..backend.assetmodel import AssetModel
......@@ -159,7 +158,7 @@ class BasePin(QGraphicsObject):
# Find the corresponding channel
connection_settings = {}
if start.block_object.type == BlockType.DATASETS.name:
if start.block_object.type == BlockType.DATASETS:
connection_settings["channel"] = start.block_object.name
else:
connection_settings[
......@@ -212,7 +211,7 @@ class InputPin(BasePin):
x = -(width / 2.0)
if self.block_object.type == BlockType.LOOPS.name:
if self.block_object.type == BlockType.LOOPS:
_idx = None
if self.pin in self.block_object.processor_inputs:
_idx = self.block_object.processor_inputs.index(self.pin)
......@@ -248,7 +247,7 @@ class OutputPin(BasePin):
width = height
x = self.block_object.custom_width - (width / 2.0)
if self.block_object.type == BlockType.LOOPS.name:
if self.block_object.type == BlockType.LOOPS:
_idx = None
if self.pin in self.block_object.processor_outputs:
_idx = self.block_object.processor_outputs.index(self.pin)
......@@ -419,7 +418,7 @@ class Connection(QGraphicsPathItem):
elif block.name == self.end_block_name:
self.end_block = block
if self.start_block.type == BlockType.LOOPS.name:
if self.start_block.type == BlockType.LOOPS:
if self.start_pin_name in self.start_block.pins["outputs"]["processor"]:
self.start_pin = self.start_block.pins["outputs"]["processor"][
self.start_pin_name
......@@ -431,7 +430,7 @@ class Connection(QGraphicsPathItem):
else:
self.start_pin = self.start_block.pins["outputs"][self.start_pin_name]
if self.end_block.type == BlockType.LOOPS.name:
if self.end_block.type == BlockType.LOOPS:
if self.end_pin_name in self.end_block.pins["inputs"]["processor"]:
self.end_pin = self.end_block.pins["inputs"]["processor"][
self.end_pin_name
......@@ -683,7 +682,7 @@ class BlockEditionDialog(QDialog):
channels = []
if block.type == BlockType.DATASETS.name:
if block.type == BlockType.DATASETS:
channel_color_button = QPushButton("Channel color", self)
channel_color_button.setToolTip("Opens color dialog")
channel_color_button.clicked.connect(self.on_color_click)
......@@ -764,7 +763,7 @@ class BlockType(Enum):
@classmethod
def from_name(cls, name):
try:
return cls[name]
return cls[name.upper()]
except KeyError:
raise KeyError("{} is not a valid block type".format(name))
......@@ -785,18 +784,18 @@ class Block(QGraphicsObject):
self.type = block_type
self.name = ""
if self.type == BlockType.LOOPS.name:
if self.type == BlockType.LOOPS:
self.processor_inputs = []
self.processor_outputs = []
self.evaluator_inputs = []
self.evaluator_outputs = []
else:
if self.type == BlockType.DATASETS.name:
if self.type == BlockType.DATASETS:
self.inputs = None
else:
self.inputs = []
if self.type == BlockType.ANALYZERS.name:
if self.type == BlockType.ANALYZERS:
self.outputs = None
else:
self.outputs = []
......@@ -811,7 +810,7 @@ class Block(QGraphicsObject):
self.pins["inputs"] = dict()
self.pins["outputs"] = dict()
if self.type == BlockType.LOOPS.name:
if self.type == BlockType.LOOPS:
self.pins["inputs"]["processor"] = dict()
self.pins["outputs"]["processor"] = dict()
self.pins["inputs"]["evaluator"] = dict()
......@@ -825,22 +824,16 @@ class Block(QGraphicsObject):
if "name" in block_details:
self.name = block_details["name"]
if self.type == BlockType.LOOPS.name:
if self.type == BlockType.LOOPS:
self.processor_inputs = block_details["processor_inputs"]
self.processor_outputs = block_details["processor_outputs"]
self.evaluator_inputs = block_details["evaluator_inputs"]
self.evaluator_outputs = block_details["evaluator_outputs"]
else:
if (
self.type != BlockType.DATASETS.name
and self.type != BlockType.LOOPS.name
):
if self.type != BlockType.DATASETS and self.type != BlockType.LOOPS:
self.inputs = block_details["inputs"]
if (
self.type != BlockType.ANALYZERS.name
and self.type != BlockType.LOOPS.name
):
if self.type != BlockType.ANALYZERS and self.type != BlockType.LOOPS:
self.outputs = block_details["outputs"]
if "synchronized_channel" in block_details:
......@@ -851,7 +844,7 @@ class Block(QGraphicsObject):
def create_pins(self):
if self.type == BlockType.LOOPS.name:
if self.type == BlockType.LOOPS:
for pin_name in self.processor_inputs:
input_pin = InputPin(
self, pin_name, self.name, self.pin_brush, self.pin_pen
......@@ -921,7 +914,7 @@ class Block(QGraphicsObject):
metrics = QFontMetrics(self.text_font)
text_width = metrics.boundingRect(self.name).width() + 24
if self.type == BlockType.LOOPS.name:
if self.type == BlockType.LOOPS:
if len(self.processor_inputs) > 0:
self.max_processor_inputs_width = (
metrics.boundingRect(max(self.processor_inputs, key=len)).width()
......@@ -1030,7 +1023,7 @@ class Block(QGraphicsObject):
metrics = QFontMetrics(self.text_font)
text_height = metrics.boundingRect(self.name).height() + 24
if self.type == BlockType.LOOPS.name:
if self.type == BlockType.LOOPS:
max_pin_height = max(
len(self.processor_inputs) + len(self.evaluator_inputs),
len(self.processor_outputs) + len(self.evaluator_outputs),
......@@ -1122,7 +1115,7 @@ class Block(QGraphicsObject):
if self.name != value["name"]:
old_name = self.name
self.name = value["name"]
if self.type == BlockType.DATASETS.name:
if self.type == BlockType.DATASETS:
self.toolchain.web_representation["channel_colors"][
self.name
] = self.toolchain.web_representation["channel_colors"].pop(
......@@ -1132,11 +1125,11 @@ class Block(QGraphicsObject):
block_updated = True
if (
self.synchronized_channel != value["channel"]
and self.type != BlockType.DATASETS.name
and self.type != BlockType.DATASETS
):
self.synchronized_channel = value["channel"]
block_updated = True
if self.type == BlockType.DATASETS.name:
if self.type == BlockType.DATASETS:
self.toolchain.web_representation["channel_colors"][self.name] = value[
"color"
]
......@@ -1144,7 +1137,7 @@ class Block(QGraphicsObject):
if block_updated:
block_item = {}
if self.type == BlockType.LOOPS.name:
if self.type == BlockType.LOOPS:
block_item["processor_inputs"] = self.processor_inputs
block_item["processor_outputs"] = self.processor_outputs
block_item["evaluator_inputs"] = self.evaluator_inputs
......@@ -1165,10 +1158,10 @@ class Block(QGraphicsObject):
self.toolchain.scene.addItem(self)
# if type is dataset: update sync channels everywhere
if self.type == BlockType.DATASETS.name:
if self.type == BlockType.DATASETS:
for block in self.toolchain.blocks:
if (
block.type != BlockType.DATASETS.name
block.type != BlockType.DATASETS
and block.synchronized_channel == old_name
):
block.synchronized_channel = self.name
......@@ -1253,17 +1246,17 @@ class Block(QGraphicsObject):
"""Paint the block"""
# Design tools
if self.type == BlockType.DATASETS.name:
if self.type == BlockType.DATASETS:
self.background_brush.setColor(self.background_color_datasets)
elif self.type == BlockType.ANALYZERS.name:
elif self.type == BlockType.ANALYZERS:
self.background_brush.setColor(self.background_color_analyzers)
elif self.type == BlockType.LOOPS.name:
elif self.type == BlockType.LOOPS:
self.background_brush.setColor(self.background_color_loops)
painter.setBrush(self.background_brush)
painter.setPen(self.border_pen)
if self.type == BlockType.LOOPS.name:
if self.type == BlockType.LOOPS:
max_pin_height = max(
len(self.processor_inputs) + len(self.evaluator_inputs),
len(self.processor_outputs) + len(self.evaluator_outputs),
......@@ -1311,7 +1304,7 @@ class Block(QGraphicsObject):
painter.drawText(text_rect, Qt.AlignCenter, self.name)
# Pin
if self.type == BlockType.LOOPS.name:
if self.type == BlockType.LOOPS:
self.draw_pins_name(painter, "input", self.processor_inputs, 0)
self.draw_pins_name(
painter, "input", self.evaluator_inputs, len(self.processor_inputs)
......@@ -1400,7 +1393,7 @@ class ToolchainView(QGraphicsView):
warning.setStandardButtons(QMessageBox.Ok)
warning.exec_()
else:
if item.type == BlockType.DATASETS.name:
if item.type == BlockType.DATASETS:
if (
item.name
in self.toolchain.web_representation["channel_colors"]
......@@ -1592,7 +1585,7 @@ class ToolchainWidget(QWidget):
block_item["name"] = new_block_name
block = Block(
BlockType.LOOPS.name, self.block_config, self.connection_config
BlockType.LOOPS, self.block_config, self.connection_config
)
block.load(self, block_item)
......@@ -1680,7 +1673,7 @@ class ToolchainWidget(QWidget):
new_block_name = init_name + "_" + str(init_name_count)
block_item["name"] = new_block_name
if block_type == BlockType.DATASETS.name:
if block_type == BlockType.DATASETS:
self.web_representation["channel_colors"][new_block_name] = "#000000"
block = Block(block_type, self.block_config, self.connection_config)
......@@ -1739,9 +1732,7 @@ class ToolchainWidget(QWidget):
# Get datasets, blocks, analyzers, loops
for block_type in BlockType:
for block_item in self.json_object.get(block_type.value, {}):
block = Block(
block_type.name, self.block_config, self.connection_config
)
block = Block(block_type, self.block_config, self.connection_config)
block.load(self, block_item)
# Place blocks (x,y) if information is given
if self.web_representation["blocks"] is not None:
......@@ -1782,11 +1773,11 @@ class ToolchainWidget(QWidget):
block_type_list = []
for block in self.blocks:
block_data = {}
if block_type == BlockType.from_name(block.type):
if block.type == block_type:
block_data["name"] = block.name
if block.synchronized_channel is not None:
block_data["synchronized_channel"] = block.synchronized_channel
if block.type == BlockType.LOOPS.name:
if block.type == BlockType.LOOPS:
block_data["processor_inputs"] = block.processor_inputs
block_data["evaluator_inputs"] = block.evaluator_inputs
block_data["processor_outputs"] = block.processor_outputs
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment