Commit 83946657 authored by Samuel GAIST's avatar Samuel GAIST
Browse files

[widgets][experimenteditor] Refactor algorithm handling using new prefix modelisation

This also start the implementation for support "wrong"
experiments as in not yet fully configured.

It also now does updating of the IO mapping when selecting
a new algorithm.
parent 6cda0bf9
......@@ -23,7 +23,6 @@
# #
###############################################################################
import re
import copy
import pytest
......@@ -58,7 +57,6 @@ from ..widgets.experimenteditor import AnalyzerBlockEditor
from ..widgets.experimenteditor import LoopBlockEditor
from ..widgets.experimenteditor import GlobalParametersEditor
from ..widgets.experimenteditor import EnvironmentModel
from ..widgets.experimenteditor import FieldPresenceFilterProxyModel
from ..widgets.experimenteditor import ExperimentEditor
from ..widgets.experimenteditor import typed_user_property
......@@ -501,58 +499,6 @@ class TestAlgorithmParametersEditor:
assert widget.itemText(i) in choices
class TestFieldPresenceFilter:
"""Test that the field presence filter works as expected"""
@pytest.mark.parametrize("must_be_present", [True, False])
@pytest.mark.parametrize("field_name", ["results", "type"])
def test_field_presence(self, algorithm_model, field_name, must_be_present):
filter_model = FieldPresenceFilterProxyModel(field_name, must_be_present)
filter_model.setSourceModel(algorithm_model)
assert filter_model.rowCount() < algorithm_model.rowCount()
for i in range(filter_model.rowCount()):
algorithm_name = filter_model.index(i, 0).data()
declaration = get_algorithm_declaration(
algorithm_model.prefix_path, algorithm_name
)
if must_be_present:
assert field_name in declaration
else:
assert field_name not in declaration
@pytest.mark.parametrize("must_be_present", [True, False])
@pytest.mark.parametrize(
"field_value", ["autonomous", "[autonomous|sequential]", ".*loop.*"]
)
@pytest.mark.parametrize("field_name", ["type"])
def test_field_value(
self, algorithm_model, field_name, must_be_present, field_value
):
filter_model = FieldPresenceFilterProxyModel(
field_name, must_be_present, field_value
)
filter_model.setSourceModel(algorithm_model)
assert filter_model.rowCount() < algorithm_model.rowCount()
for i in range(filter_model.rowCount()):
algorithm_name = filter_model.index(i, 0).data()
declaration = get_algorithm_declaration(
algorithm_model.prefix_path, algorithm_name
)
if must_be_present:
value = declaration[field_name]
match = re.match(field_value, value)
assert match is not None
else:
if field_name in declaration:
value = declaration[field_name]
match = re.match(field_value, value)
assert match is None
class TestEnvironmentModel:
"""Test that the environment model shows and return value as expected"""
......@@ -760,12 +706,20 @@ class PropertiesEditorTestMixin:
class TestExecutionPropertiesEditor(PropertiesEditorTestMixin, ParameterTestMixin):
"""Test that the AlgorithmEdior works as expected"""
"""Test that the AlgorithmEditor works as expected"""
editor_klass = ExecutionPropertiesEditor
declaration_field = "blocks"
parameter_field = "parameters"
@pytest.fixture(autouse=True)
def prefix_model(self, beat_context):
return PrefixModel(beat_context)
@pytest.fixture()
def algorithm_model(self):
return AlgorithmModel()
@pytest.fixture()
def properties_editor(self, beat_context, test_prefix, algorithm_model):
environment_model = EnvironmentModel()
......@@ -785,12 +739,11 @@ class TestBlockEditor(TestExecutionPropertiesEditor):
declaration_field = "blocks"
@pytest.fixture()
def properties_editor(self, beat_context, test_prefix, algorithm_model):
def properties_editor(self, beat_context, test_prefix):
environment_model = EnvironmentModel()
environment_model.setContext(beat_context)
editor = self.editor_klass("block_name", test_prefix)
editor.setAlgorithmModel(algorithm_model)
editor.setEnvironmentModel(environment_model)
editor.setQueueModel(QStringListModel(["Test"]))
return editor
......@@ -803,12 +756,11 @@ class TestAnalyzerBlockEditor(PropertiesEditorTestMixin):
declaration_field = "analyzers"
@pytest.fixture()
def properties_editor(self, beat_context, test_prefix, algorithm_model):
def properties_editor(self, beat_context, test_prefix):
environment_model = EnvironmentModel()
environment_model.setContext(beat_context)
editor = self.editor_klass("block_name", test_prefix)
editor.setAlgorithmModel(algorithm_model)
editor.setEnvironmentModel(environment_model)
editor.setQueueModel(QStringListModel(["Test"]))
return editor
......@@ -826,7 +778,6 @@ class TestLoopBlockEditor(TestExecutionPropertiesEditor):
environment_model.setContext(beat_context)
editor = self.editor_klass("block_name", test_prefix)
editor.setAlgorithmModel(algorithm_model)
editor.setEnvironmentModel(environment_model)
editor.setQueueModel(QStringListModel(["Test"]))
return editor
......
......@@ -23,7 +23,6 @@
# #
###############################################################################
import re
import os
import copy
import simplejson as json
......@@ -33,7 +32,6 @@ from PyQt5.QtCore import pyqtProperty
from PyQt5.QtCore import pyqtSignal
from PyQt5.QtCore import pyqtSlot
from PyQt5.QtCore import QStringListModel
from PyQt5.QtCore import QSortFilterProxyModel
from PyQt5.QtGui import QIcon
from PyQt5.QtGui import QStandardItem
......@@ -61,6 +59,8 @@ from PyQt5.QtSql import QSqlTableModel
from beat.core.experiment import PROCESSOR_PREFIX
from beat.core.experiment import EVALUATOR_PREFIX
from beat.backend.python.algorithm import Algorithm
from ..backend.asset import Asset
from ..backend.asset import AssetType
from ..backend.assetmodel import AssetModel
......@@ -113,51 +113,6 @@ def write_user_property(widget, value):
# Helper classes
class FieldPresenceFilterProxyModel(QSortFilterProxyModel):
"""Filter proxy model showing asset filtered based on their content"""
def __init__(self, field, must_have, value=None, parent=None):
"""
:param field str: field to search
:param must_have bool: whether the field must be in the declaration
"""
super().__init__(parent=parent)
self.setFieldFilter(field, must_have, value)
def setFieldFilter(self, field, must_have, value=None):
self.field = field
self.must_have = must_have
self.value = value
self.prog = re.compile(value) if value is not None else None
self.invalidate()
def filterAcceptsRow(self, source_row, source_parent):
"""Filter assets based on whether the field configured must or must not
be found in the declaration.
"""
asset_model = self.sourceModel()
index = asset_model.index(source_row, 0, source_parent)
path = asset_model.json_path(index.data())
with open(path, "rt") as json_file:
try:
json_data = json.load(json_file)
except json.JSONDecodeError:
return False
else:
has_field = self.field in json_data
if has_field and self.value is not None:
match = self.prog.match(json_data[self.field])
if self.must_have:
return match is not None
else:
return match is None
return has_field == self.must_have
class EnvironmentModel(QStandardItemModel):
"""Model wrapping the processing environment available"""
......@@ -504,13 +459,15 @@ class DatasetEditor(AbstractBaseEditor):
self.reset_button.clicked.connect(self.reset)
def reset(self):
self.dataset_combobox.setCurrentText(
"{}/{}/{}".format(
self.json_object["database"],
self.json_object["protocol"],
self.json_object["set"],
)
dataset_name = "{}/{}/{}".format(
self.json_object["database"],
self.json_object["protocol"],
self.json_object["set"],
)
self.dataset_combobox.setCurrentText(dataset_name)
if self.dataset_combobox.currentText() != dataset_name:
self.dataset_combobox.setCurrentIndex(-1)
def setDatasetModel(self, dataset_model):
self.dataset_combobox.setModel(dataset_model)
......@@ -524,9 +481,15 @@ class DatasetEditor(AbstractBaseEditor):
self.reset()
def dump(self):
database, db_version, protocol, set_ = self.dataset_combobox.currentText().split(
"/"
)
current_dataset = self.dataset_combobox.currentText()
try:
database, db_version, protocol, set_ = current_dataset.split("/")
except ValueError:
database = None
db_version = None
protocol = None
set_ = None
return {
"database": "{}/{}".format(database, db_version),
"protocol": protocol,
......@@ -598,6 +561,10 @@ class AlgorithmParametersEditor(AbstractBaseEditor):
while self.form_layout.rowCount():
self.form_layout.removeRow(0)
if not algorithm_name:
self.parameterCountChanged.emit(0)
return
algorithm = AssetType.ALGORITHM.klass(self.prefix_path, self.algorithm_name)
parameters = algorithm.parameters
if parameters is not None:
......@@ -705,6 +672,7 @@ class ExecutionPropertiesEditor(AbstractBaseEditor):
self.queue_changed = False
self._queue_enabled = True
self.parameter_item = None
self.algorithm_model = None
self.algorithm_combobox = QComboBox()
self.algorithm_combobox.setObjectName("algorithms")
self.environment_combobox = QComboBox()
......@@ -741,10 +709,7 @@ class ExecutionPropertiesEditor(AbstractBaseEditor):
layout.addWidget(groupbox)
self.algorithm_combobox.currentIndexChanged.connect(self.dataChanged)
self.algorithm_combobox.currentTextChanged.connect(self.algorithmChanged)
self.algorithm_combobox.currentTextChanged.connect(
lambda algorithm_name: self.parameters_editor.setup(algorithm_name)
)
self.algorithm_combobox.currentTextChanged.connect(self.__onAlgorithmChanged)
self.environment_combobox.currentIndexChanged.connect(self.dataChanged)
self.environment_combobox.currentIndexChanged.connect(
lambda *args: setattr(self, "environment_changed", True)
......@@ -790,13 +755,48 @@ class ExecutionPropertiesEditor(AbstractBaseEditor):
self._queue_enabled
)
@pyqtSlot(str)
def __onAlgorithmChanged(self, algorithm_name):
if not algorithm_name or not self.json_object:
return
# redo the io mapping
asset = Asset(self.prefix_path, AssetType.ALGORITHM, algorithm_name)
declaration = asset.declaration
inputs = []
outputs = []
for group in declaration["groups"]:
inputs += group.get("inputs", [])
outputs += group.get("outputs", [])
# The key is the algorithm IO name and the value is toolchain IO name
io_mapping = {}
if sorted(inputs) != sorted(self.io_mapping["inputs"].keys()):
io_mapping = {"inputs": {}}
for key, value in zip(inputs, self.io_mapping["inputs"].values()):
io_mapping["inputs"][key] = value
if outputs and sorted(outputs) != sorted(self.io_mapping["outputs"].keys()):
io_mapping["outputs"] = {}
for key, value in zip(outputs, self.io_mapping["outputs"].values()):
io_mapping["outputs"][key] = value
if io_mapping:
self.io_mapping = io_mapping
self.algorithmChanged.emit(algorithm_name)
def algorithm(self):
return self.algorithm_combobox.currentText()
algorithm = pyqtProperty(str, fget=algorithm, notify=algorithmChanged)
def setAlgorithmModel(self, model):
self.algorithm_combobox.setModel(model)
self.algorithm_model = model
self.algorithm_combobox.setModel(self.algorithm_model)
def setEnvironmentModel(self, model):
self.environment_combobox.setModel(model)
......@@ -813,22 +813,34 @@ class ExecutionPropertiesEditor(AbstractBaseEditor):
def load(self, json_object):
self.json_object = copy.deepcopy(json_object)
self.io_mapping = {"inputs": json_object["inputs"]}
inputs = json_object.get("inputs", {})
outputs = json_object.get("outputs")
self.io_mapping = {"inputs": inputs}
outputs = json_object.get("outputs", {})
if outputs:
self.io_mapping["outputs"] = outputs
self.algorithm_model.setInputCount(len(inputs))
self.algorithm_model.setOutputCount(len(outputs))
algorithm_name = json_object["algorithm"]
environment = json_object.get("environment")
parameters = json_object.get("parameters")
self.algorithm_combobox.setCurrentText(algorithm_name)
if self.algorithm_combobox.currentText() != algorithm_name:
raise RuntimeError(
"Algorithm {} not found in prefix".format(algorithm_name)
)
self.parameters_editor.setup(algorithm_name)
if algorithm_name:
self.algorithm_combobox.setCurrentText(algorithm_name)
if self.algorithm_combobox.currentText() != algorithm_name:
available_algorithms = "\n".join(
self.algorithm_model.index(i, 0).data()
for i in range(self.algorithm_model.rowCount())
)
raise RuntimeError(
"Algorithm {} not available in list:\n {}".format(
algorithm_name, available_algorithms
)
)
self.parameters_editor.setup(algorithm_name)
if environment:
env_text = "{} ({})".format(environment["name"], environment["version"])
......@@ -893,9 +905,14 @@ class BlockEditor(AbstractBlockEditor):
def __init__(self, block_name, prefix_path, parent=None):
super().__init__(block_name, prefix_path, parent)
self.proxy_model = FieldPresenceFilterProxyModel("results", False)
self.algorithm_model = AlgorithmModel()
self.algorithm_model.setAnalyzerEnabled(False)
self.algorithm_model.setTypes(
[Algorithm.LEGACY, Algorithm.SEQUENTIAL, Algorithm.AUTONOMOUS]
)
self.properties_editor = ExecutionPropertiesEditor(prefix_path)
self.properties_editor.setAlgorithmModel(self.proxy_model)
self.properties_editor.setAlgorithmModel(self.algorithm_model)
layout = QVBoxLayout(self)
layout.addWidget(QLabel(self.block_name))
......@@ -904,10 +921,7 @@ class BlockEditor(AbstractBlockEditor):
self.properties_editor.algorithmChanged.connect(self.algorithmChanged)
self.properties_editor.dataChanged.connect(self.dataChanged)
self.prefixPathChanged.connect(self.proxy_model.invalidate)
def setAlgorithmModel(self, model):
self.proxy_model.setSourceModel(model)
self.prefixPathChanged.connect(self.algorithm_model.update)
def selectedAlgorithm(self):
return self.properties_editor.algorithm
......@@ -932,18 +946,23 @@ class AnalyzerBlockEditor(BlockEditor):
def __init__(self, block_name, prefix_path, parent=None):
super().__init__(block_name, prefix_path, parent)
self.proxy_model.setFieldFilter("results", True)
self.algorithm_model.setAnalyzerEnabled(True)
class LoopBlockEditor(AbstractBlockEditor):
def __init__(self, block_name, prefix_path, parent=None):
super().__init__(block_name, prefix_path, parent)
self.processor_model = FieldPresenceFilterProxyModel(
"type", True, ".*loop_processor"
self.processor_model = AlgorithmModel()
self.processor_model.setAnalyzerEnabled(False)
self.processor_model.setTypes(
[Algorithm.AUTONOMOUS_LOOP_PROCESSOR, Algorithm.SEQUENTIAL_LOOP_PROCESSOR]
)
self.evaluator_model = FieldPresenceFilterProxyModel(
"type", True, ".*loop_evaluator"
self.evaluator_model = AlgorithmModel()
self.evaluator_model.setAnalyzerEnabled(False)
self.evaluator_model.setTypes(
[Algorithm.AUTONOMOUS_LOOP_EVALUATOR, Algorithm.SEQUENTIAL_LOOP_EVALUATOR]
)
self.processor_properties_editor = ExecutionPropertiesEditor(prefix_path)
......@@ -971,12 +990,8 @@ class LoopBlockEditor(AbstractBlockEditor):
self.evaluator_properties_editor.algorithmChanged.connect(self.algorithmChanged)
self.evaluator_properties_editor.dataChanged.connect(self.dataChanged)
self.prefixPathChanged.connect(self.processor_model.invalidate)
self.prefixPathChanged.connect(self.evaluator_model.invalidate)
def setAlgorithmModel(self, model):
self.processor_model.setSourceModel(model)
self.evaluator_model.setSourceModel(model)
self.prefixPathChanged.connect(self.processor_model.update)
self.prefixPathChanged.connect(self.evaluator_model.update)
def selectedAlgorithms(self):
return {
......@@ -1113,6 +1128,8 @@ class GlobalParametersEditor(AbstractBaseEditor):
self.json_object.pop(algorithm_name, None)
for algorithm_name in to_add:
if not algorithm_name:
continue
algorithm = AssetType.ALGORITHM.klass(self.prefix_path, algorithm_name)
if algorithm.valid and algorithm.parameters:
editor = AlgorithmParametersEditor(self.prefix_path)
......@@ -1176,6 +1193,8 @@ class ExperimentEditor(AbstractAssetEditor):
self.setObjectName(self.__class__.__name__)
self.set_title(self.tr("Experiment"))
self.prefix_model = PrefixModel()
self.processing_env_model = EnvironmentModel()
self.queue_model = QStringListModel()
self.queue_model.setStringList(["Local"])
......@@ -1220,6 +1239,8 @@ class ExperimentEditor(AbstractAssetEditor):
@pyqtSlot()
def __update(self):
self.prefix_model.setContext(self.context)
for object_ in [
self.algorithm_model,
self.dataset_model,
......@@ -1272,105 +1293,38 @@ class ExperimentEditor(AbstractAssetEditor):
for dataset in toolchain_declaration["datasets"]:
dataset_data = {"database": None, "protocol": None, "set": None}
dataset_outputs = dataset["outputs"]
for i in range(database_model.rowCount()):
done = False
database_name = database_model.index(i, 0).data()
database = AssetType.DATABASE.klass(self.prefix_path, database_name)
if not database.valid:
continue
for db_protocol in database.protocols:
sets = database.sets(db_protocol)
for set_name, set_data in sets.items():
set_outputs = list(set_data["outputs"])
if all(output in dataset_outputs for output in set_outputs):
dataset_data["database"] = database_name
dataset_data["protocol"] = db_protocol
dataset_data["set"] = set_name
done = True
break
if done:
break
if done:
break
name = dataset["name"]
experiment_declaration["datasets"][name] = dataset_data
# blocks
def build_block_data(block, algorithm_filter):
block_data = {"algorithm": None, "inputs": {}}
input_count = len(block["inputs"])
output_count = len(block.get("outputs", []))
if output_count:
block_data["outputs"] = {}
for i in range(algorithm_filter.rowCount()):
asset_name = algorithm_filter.index(i, 0).data()
asset = Asset(self.prefix_path, AssetType.ALGORITHM, asset_name)
asset_declaration = asset.declaration
inputs = {}
outputs = {}
for group in asset_declaration["groups"]:
inputs.update(group["inputs"])
outputs.update(group.get("outputs", {}))
if len(inputs) == input_count and len(outputs) == output_count:
keys = list(inputs)
for index, input_ in enumerate(block["inputs"]):
algorithm_input = keys[index]
block_data["inputs"][algorithm_input] = input_
if output_count:
keys = list(outputs)
for index, output in enumerate(block["outputs"]):
algorithm_output = keys[index]
block_data["outputs"][algorithm_output] = output
block_data["algorithm"] = asset_name
break
def build_block_data(block):
block_data = {"algorithm": None}
block_data["inputs"] = {
f"{input_}_tbr": input_ for input_ in block["inputs"]
}
outputs = block.get("outputs", [])
if outputs:
block_data["outputs"] = {
f"{output}_tbr": output for output in outputs
}
return block_data
algorithm_filter = FieldPresenceFilterProxyModel(
"type", True, "[sequential|autonomous]"
)
algorithm_filter.setSourceModel(self.algorithm_model)
for block in toolchain_declaration["blocks"]:
name = block["name"]
experiment_declaration["blocks"][name] = build_block_data(
block, algorithm_filter
)
experiment_declaration["blocks"][name] = build_block_data(block)
# analyzers
algorithm_filter.setFieldFilter("results", True)
for analyzer in toolchain_declaration["analyzers"]:
name = analyzer["name"]
experiment_declaration["analyzers"][name] = build_block_data(
analyzer, algorithm_filter
)
experiment_declaration["analyzers"][name] = build_block_data(analyzer)
# loops
if "loops" in toolchain_declaration:
experiment_declaration["schema_version"] = 2
experiment_declaration["loops"] = {}
processor_filter = FieldPresenceFilterProxyModel(
"type", True, ".*loop_processor"
)
processor_filter.setSourceModel(self.algorithm_model)
evaluator_filter = FieldPresenceFilterProxyModel(
"type", True, ".*loop_evaluator"
)
evaluator_filter.setSourceModel(self.algorithm_model)
for loop in toolchain_declaration["loops"]:
loop_data = {}
......@@ -1395,8 +1349,8 @@ class ExperimentEditor(AbstractAssetEditor):
]
}
processor_data = build_block_data(processor_block, processor_filter)
evaluator_data = build_block_data(loop_block, evaluator_filter)
processor_data = build_block_data(processor_block)
evaluator_data = build_block_data(loop_block)
loop_data.update(
{
f"{PROCESSOR_PREFIX}{key}": value
......@@ -1429,6 +1383,8 @@ class ExperimentEditor(AbstractAssetEditor):
def _load_json(self, json_object):
"""Load the json object passed as parameter"""
self.prefix_model.refresh()
for widget in [
self.datasets_widget,
self.blocks_widget,
......@@ -1457,7 +1413,6 @@ class ExperimentEditor(AbstractAssetEditor):
if items: