Commit b756cd0c authored by Amir MOHAMMADI

New interactive API

parent 0fd5ca8d
Pipeline #41901 failed with stage
in 2 minutes and 49 seconds
from collections import namedtuple
import numpy as np
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from beat.backend.python.database import View
# One database row: a feature vector paired with its class label.
Entry = namedtuple(typename="Entry", field_names=("features", "labels"))
class IrisView(View):
    """Database view over one split of the iris dataset, with binary labels.

    Subclasses select the split through the ``IRIS_GROUP`` class attribute:
    ``"train"`` picks the training split, any other value picks the test
    split. This base class does not define ``IRIS_GROUP`` itself.
    """

    def __init__(self, **kwargs):
        """Load iris, binarize the targets and keep the split for this group.

        Args:
            **kwargs: forwarded unchanged to ``View.__init__``.
        """
        super().__init__(**kwargs)

        dataset = load_iris()
        # Clamp the targets into [0, 1] (every label above 1 becomes 1) so
        # that the task becomes a binary classification problem.
        binary_targets = np.clip(dataset.target, 0, 1)

        # A fixed random_state keeps the split the same on every call.
        train_X, test_X, train_y, test_y = train_test_split(
            dataset.data,
            binary_targets,
            test_size=1 / 3,
            random_state=0,
            shuffle=True,
        )

        use_train = self.IRIS_GROUP == "train"
        self.X = train_X if use_train else test_X
        self.y = train_y if use_train else test_y

    def index(self, root_folder, parameters):
        """Return one ``Entry(features, labels)`` per sample of the split.

        ``root_folder`` and ``parameters`` are accepted but not used here.
        """
        return [Entry(sample, label) for sample, label in zip(self.X, self.y)]

    def get(self, output, index):
        """Return ``{"value": ...}`` for field ``output`` of object ``index``.

        Reads from ``self.objs``, which is not assigned anywhere in this
        class -- presumably the base ``View`` populates it; verify there.
        """
        entry = self.objs[index]
        return {"value": getattr(entry, output)}
class Train(IrisView):
    """Iris view that serves the training split."""

    IRIS_GROUP = "train"
class Test(IrisView):
    """Iris view that serves the test split."""

    IRIS_GROUP = "test"
......@@ -2,17 +2,15 @@ import logging
import os
import tempfile
import unittest
from collections import namedtuple
import numpy as np
import pkg_resources
import simplejson
from beat.backend.python.algorithm import Algorithm_
from beat.backend.python.algorithm import Analyzer_
from beat.backend.python.database import Database_
from beat.backend.python.database import View
from beat.backend.python.database import ViewRunner
from beat.backend.python.dataformat import DataFormat_
from beat.backend.python.algorithm import Algorithm
from beat.backend.python.algorithm import Analyzer
from beat.backend.python.database import Database
from beat.backend.python.dataformat import DataFormat
from beat.backend.python.protocoltemplate import ProtocolTemplate
from beat.core.data import CachedDataSource
from beat.core.execution import LocalExecutor
from beat.core.experiment import Experiment
......@@ -68,71 +66,55 @@ def print_results(executor, analyzer):
print(f" Results:\n{r}")
config
config.set("user", "amir")
config.set("prefix", "/somewhere")
class IrisLDATest(unittest.TestCase):
def setup_data_formats(self):
self.features_type = DataFormat_(
definition={"value": [0, float]}, name="user/1d_float_arrays/1"
)
self.labels_type = DataFormat_(
definition={"value": int}, name="user/integers/1"
)
self.model_type = DataFormat_(definition={"text": str}, name="user/strings/1")
self.scores_type = DataFormat_(
definition={"value": float}, name="user/floats/1"
)
def setup_database(self):
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
iris = load_iris()
X, y = iris.data, iris.target
# this will convert our problem to a binary classification problem
y = np.clip(y, 0, 1)
X_train, X_test, y_train, y_test = train_test_split(
X, y, test_size=1 / 3, random_state=0, shuffle=True
self.features_type = DataFormat.new(
definition={"value": [0, float]}, name="array_1d_floats"
)
self.labels_type = DataFormat.new(definition={"value": int}, name="integers")
self.model_type = DataFormat.new(definition={"text": str}, name="strings")
self.scores_type = DataFormat.new(definition={"value": float}, name="floats")
Entry = namedtuple("Entry", ["features", "labels"])
def get_view(X, y, name):
class IrisView(View):
def index(self, root_folder, parameters):
return [Entry(x_, y_) for x_, y_ in zip(X, y)]
def get(self, output, index):
obj = self.objs[index]
return {"value": getattr(obj, output)}
view = ViewRunner(
module=IrisView(),
definition={
def setup_protocol_templates(self):
self.protocoltemplate = ProtocolTemplate.new(
sets=[
{
"name": "train",
"outputs": {
"features": self.features_type,
"labels": self.labels_type,
}
},
},
name=name,
)
return view
train_view = get_view(X_train, y_train, "train")
test_view = get_view(X_test, y_test, "test")
{
"name": "test",
"outputs": {
"features": self.features_type,
"labels": self.labels_type,
},
},
],
name="iris_two_class",
)
self.validate("protocoltemplate", self.protocoltemplate)
self.database = Database_(
protocols={"main": {"sets": {"train": train_view, "test": test_view}}},
name="iris_two_class/1",
def setup_database(self):
self.database = Database.new(
code_path=pkg_resources.resource_filename(__name__, "data/iris_views.py"),
protocols=[
{
"name": "main",
"template": self.protocoltemplate,
"views": {"test": {"view": "Test"}, "train": {"view": "Train"}},
}
],
name="iris_two_class",
)
self.validate("database", self.database)
def setup_algorithms(self):
import pickle
import base64
import pickle
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis as LDA
class TrainLDA:
......@@ -161,7 +143,7 @@ class IrisLDATest(unittest.TestCase):
return True
self.train_algorithm = Algorithm_(
self.train_algorithm = Algorithm(
algorithm=TrainLDA(),
groups=[
{
......@@ -172,7 +154,7 @@ class IrisLDATest(unittest.TestCase):
"outputs": {"model": {"type": self.model_type}},
}
],
type=Algorithm_.SEQUENTIAL,
type=Algorithm.SEQUENTIAL,
name="user/train_lda/1",
)
......@@ -200,7 +182,7 @@ class IrisLDATest(unittest.TestCase):
outputs["scores"].write({"value": out})
return True
self.test_algorithm = Algorithm_(
self.test_algorithm = Algorithm(
algorithm=TestLDA(),
groups=[
{
......@@ -210,7 +192,7 @@ class IrisLDATest(unittest.TestCase):
},
{"name": "model", "inputs": {"model": {"type": self.model_type}}},
],
type=Algorithm_.SEQUENTIAL,
type=Algorithm.SEQUENTIAL,
name="user/score_with_lda/1",
)
self.validate("test_algorithm", self.test_algorithm)
......@@ -239,7 +221,7 @@ class IrisLDATest(unittest.TestCase):
return True
self.analyzer = Analyzer_(
self.analyzer = Analyzer(
algorithm=AccuracyAnalyzer(),
input_groups=[
{
......@@ -250,7 +232,7 @@ class IrisLDATest(unittest.TestCase):
},
],
results={"accuracy": {"type": self.scores_type, "display": True}},
type=Algorithm_.SEQUENTIAL,
type=Algorithm.SEQUENTIAL,
name="user/analyzer/1",
)
self.validate("analyzer", self.analyzer)
......@@ -406,6 +388,7 @@ class IrisLDATest(unittest.TestCase):
def run_test(self):
self.setup_data_formats()
self.setup_protocol_templates()
self.setup_database()
self.setup_algorithms()
self.setup_analyzer()
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment