Skip to content
Snippets Groups Projects
Commit f323cd1b authored by Philip ABBET's avatar Philip ABBET
Browse files

[unittests] Add tests for the latest changes in beat.backend.python

parent 029625dd
No related branches found
No related tags found
No related merge requests found
Showing
with 414 additions and 58 deletions
......@@ -235,6 +235,8 @@ class BaseExecutor(object):
"""
logger.info("Start the execution of '%s'", self.algorithm.name)
self._prepare_inputs()
self._prepare_outputs()
......
File mode changed from 100755 to 100644
{
"language": "python",
"splittable": false,
"groups": [
{
"inputs": {
"a": {
"type": "user/single_integer/1"
},
"b": {
"type": "user/single_integer/1"
}
},
"outputs": {
"sum": {
"type": "user/single_integer/1"
}
}
}
]
}
#!/usr/bin/env python
# vim: set fileencoding=utf-8 :
###############################################################################
# #
# Copyright (c) 2017 Idiap Research Institute, http://www.idiap.ch/ #
# Contact: beat.support@idiap.ch #
# #
# This file is part of the beat.core module of the BEAT platform. #
# #
# Commercial License Usage #
# Licensees holding valid commercial BEAT licenses may use this file in #
# accordance with the terms contained in a written agreement between you #
# and Idiap. For further information contact tto@idiap.ch #
# #
# Alternatively, this file may be used under the terms of the GNU Affero #
# Public License version 3 as published by the Free Software and appearing #
# in the file LICENSE.AGPL included in the packaging of this file. #
# The BEAT platform is distributed in the hope that it will be useful, but #
# WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY #
# or FITNESS FOR A PARTICULAR PURPOSE. #
# #
# You should have received a copy of the GNU Affero Public License along #
# with the BEAT platform. If not, see http://www.gnu.org/licenses/. #
# #
###############################################################################
import numpy
class Algorithm:
def process(self, inputs, outputs):
total = 0
if inputs['a'].isDataUnitDone():
total += inputs['a'].data.value
if inputs['b'].isDataUnitDone():
total += inputs['b'].data.value
outputs['sum'].write({
'value': numpy.int32(total)
})
return True
......@@ -75,6 +75,21 @@
}
}
]
},
{
"name": "different_frequencies",
"template": "different_frequencies",
"sets": [
{
"name": "double",
"template": "double",
"view": "DifferentFrequencies",
"outputs": {
"a": "user/single_integer/1",
"b": "user/single_integer/1"
}
}
]
}
]
}
......@@ -140,3 +140,53 @@ class Labelled:
self.remaining = self.remaining[1:]
return True
class DifferentFrequencies:
def setup(self, root_folder, outputs, parameters):
self.outputs = outputs
self.values_a = [(1, 0, 3), (2, 4, 7)]
self.values_b = [(10, 0, 0), (20, 1, 1), (30, 2, 2), (40, 3, 3),
(50, 4, 4), (60, 5, 5), (70, 6, 6), (80, 7, 7)]
return True
def done(self):
if self.outputs['b'].isConnected():
return (self.outputs['b'].last_written_data_index == 7)
else:
return (self.outputs['a'].last_written_data_index == 7)
def next(self):
if self.outputs['b'].isConnected():
current_index = self.values_b[0][1]
if (len(self.values_a) > 0) and (self.values_a[0][1] != current_index):
self.outputs['a'].write({
'value': numpy.int32(self.values_a[0][0]),
},
end_data_index=self.values_a[0][2]
)
self.values_a = self.values_a[1:]
self.outputs['b'].write({
'value': numpy.int32(self.values_b[0][0]),
},
end_data_index=self.values_b[0][2]
)
self.values_b = self.values_b[1:]
else:
self.outputs['a'].write({
'value': numpy.int32(self.values_a[0][0]),
},
end_data_index=self.values_a[0][2]
)
self.values_a = self.values_a[1:]
return True
{
"datasets": {
"set": {
"database": "integers_db/1",
"protocol": "different_frequencies",
"set": "double"
}
},
"blocks": {
"echo": {
"algorithm": "user/integers_echo/1",
"inputs": {
"in_data": "in"
},
"outputs": {
"out_data": "out"
}
},
"add": {
"algorithm": "user/sum_only_done_data_units/1",
"inputs": {
"a": "a",
"b": "b"
},
"outputs": {
"sum": "sum"
}
}
},
"analyzers": {
"analysis": {
"algorithm": "user/integers_analysis/1",
"inputs": {
"input": "in"
}
}
},
"globals": {
"queue": "queue",
"environment": {
"name": "Python 2.7",
"version": "1.2.0"
}
}
}
{
"datasets": [
{
"outputs": [
"a",
"b"
],
"name": "set"
}
],
"blocks": [
{
"name": "echo",
"inputs": [
"in"
],
"outputs": [
"out"
],
"synchronized_channel": "set"
},
{
"name": "add",
"inputs": [
"a",
"b"
],
"outputs": [
"sum"
],
"synchronized_channel": "set"
}
],
"analyzers": [
{
"name": "analysis",
"inputs": [
"in"
],
"synchronized_channel": "set"
}
],
"connections": [
{
"from": "set.b",
"to": "echo.in",
"channel": "set"
},
{
"from": "set.a",
"to": "add.a",
"channel": "set"
},
{
"from": "echo.out",
"to": "add.b",
"channel": "set"
},
{
"from": "add.sum",
"to": "analysis.in",
"channel": "set"
}
],
"representation": {
"blocks": {
},
"connections": {
},
"channel_colors": {
}
}
}
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
......@@ -263,6 +263,10 @@ class TestExecution(unittest.TestCase):
}
]) is None
@slow
def test_preprocessing_1(self):
assert self.execute('user/user/preprocessing/1/different_frequencies', [{'sum': 363, 'nb': 8}]) is None
# For benchmark purposes
# def test_double_1_large(self):
# import time
......
......@@ -34,7 +34,6 @@ logger = logging.getLogger(__name__)
import unittest
import zmq
import nose.tools
from ..message_handler import MessageHandler
from ..dataformat import DataFormat
......@@ -50,39 +49,17 @@ from .mocks import MockDataSource_Crash
from . import prefix
#----------------------------------------------------------
class TestMessageHandler(unittest.TestCase):
def setUp(self):
dataformat = DataFormat(prefix, 'user/single_integer/1')
data_source_a = MockDataSource([
dataformat.type(value=10),
dataformat.type(value=20),
],
[
(0, 0),
(1, 1),
]
)
input_a = Input('a', 'user/single_integer/1', data_source_a)
data_source_b = MockDataSource([
dataformat.type(value=100),
dataformat.type(value=200),
],
[
(0, 0),
(1, 1),
]
)
input_b = Input('b', 'user/single_integer/1', data_source_b)
class TestMessageHandlerBase(unittest.TestCase):
def create_remote_inputs(self, dataformat, data_sources):
group = InputGroup('channel', restricted_access=False)
group.add(input_a)
group.add(input_b)
for name, data_source in data_sources.items():
input = Input(name, dataformat.name, data_source)
group.add(input)
self.input_list = InputList()
self.input_list.add(group)
......@@ -101,12 +78,11 @@ class TestMessageHandler(unittest.TestCase):
client_socket = self.client_context.socket(zmq.PAIR)
client_socket.connect(address)
self.remote_input_a = RemoteInput('a', dataformat, client_socket)
self.remote_input_b = RemoteInput('b', dataformat, client_socket)
self.remote_group = InputGroup('channel', restricted_access=False)
self.remote_group.add(self.remote_input_a)
self.remote_group.add(self.remote_input_b)
for name in data_sources.keys():
remote_input = RemoteInput(name, dataformat, client_socket)
self.remote_group.add(remote_input)
self.remote_input_list = InputList()
self.remote_input_list.add(self.remote_group)
......@@ -120,53 +96,178 @@ class TestMessageHandler(unittest.TestCase):
self.message_handler = None
#----------------------------------------------------------
class TestSameFrequencyInputs(TestMessageHandlerBase):
def setUp(self):
dataformat = DataFormat(prefix, 'user/single_integer/1')
self.create_remote_inputs(
DataFormat(prefix, 'user/single_integer/1'),
dict(
a = MockDataSource([
dataformat.type(value=10),
dataformat.type(value=20),
],
[
(0, 0),
(1, 1),
]
),
b = MockDataSource([
dataformat.type(value=100),
dataformat.type(value=200),
],
[
(0, 0),
(1, 1),
]
)
)
)
self.remote_input_a = self.remote_input_list['a']
self.remote_input_b = self.remote_input_list['b']
def test_input_has_more_data(self):
assert self.remote_input_a.hasMoreData()
self.assertTrue(self.remote_input_a.hasMoreData())
def test_input_next(self):
self.remote_input_a.next()
nose.tools.eq_(self.remote_input_a.data.value, 10)
self.assertEqual(self.remote_input_a.data.value, 10)
def test_input_full_cycle(self):
assert self.remote_input_a.hasMoreData()
self.assertTrue(self.remote_input_a.hasMoreData())
self.remote_input_a.next()
nose.tools.eq_(self.remote_input_a.data.value, 10)
self.assertEqual(self.remote_input_a.data.value, 10)
assert self.remote_input_a.hasDataChanged()
assert self.remote_input_a.hasMoreData()
assert self.remote_input_a.isDataUnitDone()
self.assertTrue(self.remote_input_a.hasDataChanged())
self.assertTrue(self.remote_input_a.hasMoreData())
self.assertTrue(self.remote_input_a.isDataUnitDone())
self.remote_input_a.next()
nose.tools.eq_(self.remote_input_a.data.value, 20)
self.assertEqual(self.remote_input_a.data.value, 20)
assert self.remote_input_a.hasDataChanged()
assert not self.remote_input_a.hasMoreData()
assert self.remote_input_a.isDataUnitDone()
self.assertTrue(self.remote_input_a.hasDataChanged())
self.assertTrue(not self.remote_input_a.hasMoreData())
self.assertTrue(self.remote_input_a.isDataUnitDone())
def test_group_has_more_data(self):
assert self.remote_group.hasMoreData()
self.assertTrue(self.remote_group.hasMoreData())
def test_group_next(self):
self.remote_group.next()
nose.tools.eq_(self.remote_input_a.data.value, 10)
nose.tools.eq_(self.remote_input_b.data.value, 100)
self.assertEqual(self.remote_input_a.data.value, 10)
self.assertEqual(self.remote_input_b.data.value, 100)
def test_group_full_cycle(self):
assert self.remote_group.hasMoreData()
self.assertTrue(self.remote_group.hasMoreData())
self.remote_group.next()
self.assertEqual(self.remote_input_a.data.value, 10)
self.assertEqual(self.remote_input_b.data.value, 100)
self.assertTrue(self.remote_group.hasMoreData())
self.remote_group.next()
nose.tools.eq_(self.remote_input_a.data.value, 10)
nose.tools.eq_(self.remote_input_b.data.value, 100)
self.assertEqual(self.remote_input_a.data.value, 20)
self.assertEqual(self.remote_input_b.data.value, 200)
self.assertTrue(not self.remote_group.hasMoreData())
#----------------------------------------------------------
class TestDifferentFrequenciesInputs(TestMessageHandlerBase):
def setUp(self):
dataformat = DataFormat(prefix, 'user/single_integer/1')
self.create_remote_inputs(
DataFormat(prefix, 'user/single_integer/1'),
dict(
a = MockDataSource([
dataformat.type(value=10),
dataformat.type(value=20),
],
[
(0, 3),
(4, 7),
]
),
b = MockDataSource([
dataformat.type(value=100),
dataformat.type(value=200),
dataformat.type(value=300),
dataformat.type(value=400),
dataformat.type(value=500),
dataformat.type(value=600),
dataformat.type(value=700),
dataformat.type(value=800),
],
[
(0, 0),
(1, 1),
(2, 2),
(3, 3),
(4, 4),
(5, 5),
(6, 6),
(7, 7),
]
)
)
)
self.remote_input_a = self.remote_input_list['a']
self.remote_input_b = self.remote_input_list['b']
def test_group_full_cycle(self):
self.assertTrue(self.remote_group.hasMoreData())
self.assertTrue(self.remote_input_a.hasMoreData())
self.assertTrue(self.remote_input_b.hasMoreData())
self.assertTrue(self.remote_input_a.isDataUnitDone())
self.assertTrue(self.remote_input_b.isDataUnitDone())
for i in range(0, 7):
self.remote_group.next()
self.assertEqual(self.remote_input_a.data.value, (i // 4 + 1) * 10)
self.assertEqual(self.remote_input_b.data.value, (i + 1) * 100)
self.assertTrue(self.remote_group.hasMoreData())
if i < 4:
self.assertTrue(self.remote_input_a.hasMoreData())
else:
self.assertFalse(self.remote_input_a.hasMoreData())
self.assertTrue(self.remote_input_b.hasMoreData())
if i == 3:
self.assertTrue(self.remote_input_a.isDataUnitDone())
else:
self.assertFalse(self.remote_input_a.isDataUnitDone())
self.assertTrue(self.remote_input_b.isDataUnitDone())
assert self.remote_group.hasMoreData()
self.remote_group.next()
nose.tools.eq_(self.remote_input_a.data.value, 20)
nose.tools.eq_(self.remote_input_b.data.value, 200)
assert not self.remote_group.hasMoreData()
self.assertEqual(self.remote_input_a.data.value, 20)
self.assertEqual(self.remote_input_b.data.value, 800)
self.assertFalse(self.remote_group.hasMoreData())
self.assertFalse(self.remote_input_a.hasMoreData())
self.assertFalse(self.remote_input_b.hasMoreData())
self.assertTrue(self.remote_input_a.isDataUnitDone())
self.assertTrue(self.remote_input_b.isDataUnitDone())
#----------------------------------------------------------
......
#! /bin/bash
IMAGES=(
docker.idiap.ch/beat/beat.env.system.python:1.2.0r0
docker.idiap.ch/beat/beat.env.db.examples:1.3.0r0
docker.idiap.ch/beat/beat.env.system.python:1.2.0r1
docker.idiap.ch/beat/beat.env.db.examples:1.3.0r1
docker.idiap.ch/beat/beat.env.cxx:1.0.2
docker.idiap.ch/beat/beat.env.client:1.2.0
)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment