Commit f323cd1b authored by Philip ABBET's avatar Philip ABBET

[unittests] Add tests for the latest changes in beat.backend.python

parent 029625dd
......@@ -235,6 +235,8 @@ class BaseExecutor(object):
"""
logger.info("Start the execution of '%s'", self.algorithm.name)
self._prepare_inputs()
self._prepare_outputs()
......
File mode changed from 100755 to 100644
{
"language": "python",
"splittable": false,
"groups": [
{
"inputs": {
"a": {
"type": "user/single_integer/1"
},
"b": {
"type": "user/single_integer/1"
}
},
"outputs": {
"sum": {
"type": "user/single_integer/1"
}
}
}
]
}
#!/usr/bin/env python
# vim: set fileencoding=utf-8 :
###############################################################################
# #
# Copyright (c) 2017 Idiap Research Institute, http://www.idiap.ch/ #
# Contact: beat.support@idiap.ch #
# #
# This file is part of the beat.core module of the BEAT platform. #
# #
# Commercial License Usage #
# Licensees holding valid commercial BEAT licenses may use this file in #
# accordance with the terms contained in a written agreement between you #
# and Idiap. For further information contact tto@idiap.ch #
# #
# Alternatively, this file may be used under the terms of the GNU Affero #
# Public License version 3 as published by the Free Software and appearing #
# in the file LICENSE.AGPL included in the packaging of this file. #
# The BEAT platform is distributed in the hope that it will be useful, but #
# WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY #
# or FITNESS FOR A PARTICULAR PURPOSE. #
# #
# You should have received a copy of the GNU Affero Public License along #
# with the BEAT platform. If not, see http://www.gnu.org/licenses/. #
# #
###############################################################################
import numpy
class Algorithm:
def process(self, inputs, outputs):
total = 0
if inputs['a'].isDataUnitDone():
total += inputs['a'].data.value
if inputs['b'].isDataUnitDone():
total += inputs['b'].data.value
outputs['sum'].write({
'value': numpy.int32(total)
})
return True
......@@ -75,6 +75,21 @@
}
}
]
},
{
"name": "different_frequencies",
"template": "different_frequencies",
"sets": [
{
"name": "double",
"template": "double",
"view": "DifferentFrequencies",
"outputs": {
"a": "user/single_integer/1",
"b": "user/single_integer/1"
}
}
]
}
]
}
......@@ -140,3 +140,53 @@ class Labelled:
self.remaining = self.remaining[1:]
return True
class DifferentFrequencies:
def setup(self, root_folder, outputs, parameters):
self.outputs = outputs
self.values_a = [(1, 0, 3), (2, 4, 7)]
self.values_b = [(10, 0, 0), (20, 1, 1), (30, 2, 2), (40, 3, 3),
(50, 4, 4), (60, 5, 5), (70, 6, 6), (80, 7, 7)]
return True
def done(self):
if self.outputs['b'].isConnected():
return (self.outputs['b'].last_written_data_index == 7)
else:
return (self.outputs['a'].last_written_data_index == 7)
def next(self):
if self.outputs['b'].isConnected():
current_index = self.values_b[0][1]
if (len(self.values_a) > 0) and (self.values_a[0][1] != current_index):
self.outputs['a'].write({
'value': numpy.int32(self.values_a[0][0]),
},
end_data_index=self.values_a[0][2]
)
self.values_a = self.values_a[1:]
self.outputs['b'].write({
'value': numpy.int32(self.values_b[0][0]),
},
end_data_index=self.values_b[0][2]
)
self.values_b = self.values_b[1:]
else:
self.outputs['a'].write({
'value': numpy.int32(self.values_a[0][0]),
},
end_data_index=self.values_a[0][2]
)
self.values_a = self.values_a[1:]
return True
{
"datasets": {
"set": {
"database": "integers_db/1",
"protocol": "different_frequencies",
"set": "double"
}
},
"blocks": {
"echo": {
"algorithm": "user/integers_echo/1",
"inputs": {
"in_data": "in"
},
"outputs": {
"out_data": "out"
}
},
"add": {
"algorithm": "user/sum_only_done_data_units/1",
"inputs": {
"a": "a",
"b": "b"
},
"outputs": {
"sum": "sum"
}
}
},
"analyzers": {
"analysis": {
"algorithm": "user/integers_analysis/1",
"inputs": {
"input": "in"
}
}
},
"globals": {
"queue": "queue",
"environment": {
"name": "Python 2.7",
"version": "1.2.0"
}
}
}
{
"datasets": [
{
"outputs": [
"a",
"b"
],
"name": "set"
}
],
"blocks": [
{
"name": "echo",
"inputs": [
"in"
],
"outputs": [
"out"
],
"synchronized_channel": "set"
},
{
"name": "add",
"inputs": [
"a",
"b"
],
"outputs": [
"sum"
],
"synchronized_channel": "set"
}
],
"analyzers": [
{
"name": "analysis",
"inputs": [
"in"
],
"synchronized_channel": "set"
}
],
"connections": [
{
"from": "set.b",
"to": "echo.in",
"channel": "set"
},
{
"from": "set.a",
"to": "add.a",
"channel": "set"
},
{
"from": "echo.out",
"to": "add.b",
"channel": "set"
},
{
"from": "add.sum",
"to": "analysis.in",
"channel": "set"
}
],
"representation": {
"blocks": {
},
"connections": {
},
"channel_colors": {
}
}
}
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
......@@ -263,6 +263,10 @@ class TestExecution(unittest.TestCase):
}
]) is None
@slow
def test_preprocessing_1(self):
assert self.execute('user/user/preprocessing/1/different_frequencies', [{'sum': 363, 'nb': 8}]) is None
# For benchmark purposes
# def test_double_1_large(self):
# import time
......
......@@ -34,7 +34,6 @@ logger = logging.getLogger(__name__)
import unittest
import zmq
import nose.tools
from ..message_handler import MessageHandler
from ..dataformat import DataFormat
......@@ -50,39 +49,17 @@ from .mocks import MockDataSource_Crash
from . import prefix
#----------------------------------------------------------
class TestMessageHandler(unittest.TestCase):
def setUp(self):
dataformat = DataFormat(prefix, 'user/single_integer/1')
data_source_a = MockDataSource([
dataformat.type(value=10),
dataformat.type(value=20),
],
[
(0, 0),
(1, 1),
]
)
input_a = Input('a', 'user/single_integer/1', data_source_a)
data_source_b = MockDataSource([
dataformat.type(value=100),
dataformat.type(value=200),
],
[
(0, 0),
(1, 1),
]
)
input_b = Input('b', 'user/single_integer/1', data_source_b)
class TestMessageHandlerBase(unittest.TestCase):
def create_remote_inputs(self, dataformat, data_sources):
group = InputGroup('channel', restricted_access=False)
group.add(input_a)
group.add(input_b)
for name, data_source in data_sources.items():
input = Input(name, dataformat.name, data_source)
group.add(input)
self.input_list = InputList()
self.input_list.add(group)
......@@ -101,12 +78,11 @@ class TestMessageHandler(unittest.TestCase):
client_socket = self.client_context.socket(zmq.PAIR)
client_socket.connect(address)
self.remote_input_a = RemoteInput('a', dataformat, client_socket)
self.remote_input_b = RemoteInput('b', dataformat, client_socket)
self.remote_group = InputGroup('channel', restricted_access=False)
self.remote_group.add(self.remote_input_a)
self.remote_group.add(self.remote_input_b)
for name in data_sources.keys():
remote_input = RemoteInput(name, dataformat, client_socket)
self.remote_group.add(remote_input)
self.remote_input_list = InputList()
self.remote_input_list.add(self.remote_group)
......@@ -120,53 +96,178 @@ class TestMessageHandler(unittest.TestCase):
self.message_handler = None
#----------------------------------------------------------
class TestSameFrequencyInputs(TestMessageHandlerBase):
def setUp(self):
dataformat = DataFormat(prefix, 'user/single_integer/1')
self.create_remote_inputs(
DataFormat(prefix, 'user/single_integer/1'),
dict(
a = MockDataSource([
dataformat.type(value=10),
dataformat.type(value=20),
],
[
(0, 0),
(1, 1),
]
),
b = MockDataSource([
dataformat.type(value=100),
dataformat.type(value=200),
],
[
(0, 0),
(1, 1),
]
)
)
)
self.remote_input_a = self.remote_input_list['a']
self.remote_input_b = self.remote_input_list['b']
def test_input_has_more_data(self):
assert self.remote_input_a.hasMoreData()
self.assertTrue(self.remote_input_a.hasMoreData())
def test_input_next(self):
self.remote_input_a.next()
nose.tools.eq_(self.remote_input_a.data.value, 10)
self.assertEqual(self.remote_input_a.data.value, 10)
def test_input_full_cycle(self):
assert self.remote_input_a.hasMoreData()
self.assertTrue(self.remote_input_a.hasMoreData())
self.remote_input_a.next()
nose.tools.eq_(self.remote_input_a.data.value, 10)
self.assertEqual(self.remote_input_a.data.value, 10)
assert self.remote_input_a.hasDataChanged()
assert self.remote_input_a.hasMoreData()
assert self.remote_input_a.isDataUnitDone()
self.assertTrue(self.remote_input_a.hasDataChanged())
self.assertTrue(self.remote_input_a.hasMoreData())
self.assertTrue(self.remote_input_a.isDataUnitDone())
self.remote_input_a.next()
nose.tools.eq_(self.remote_input_a.data.value, 20)
self.assertEqual(self.remote_input_a.data.value, 20)
assert self.remote_input_a.hasDataChanged()
assert not self.remote_input_a.hasMoreData()
assert self.remote_input_a.isDataUnitDone()
self.assertTrue(self.remote_input_a.hasDataChanged())
self.assertTrue(not self.remote_input_a.hasMoreData())
self.assertTrue(self.remote_input_a.isDataUnitDone())
def test_group_has_more_data(self):
assert self.remote_group.hasMoreData()
self.assertTrue(self.remote_group.hasMoreData())
def test_group_next(self):
self.remote_group.next()
nose.tools.eq_(self.remote_input_a.data.value, 10)
nose.tools.eq_(self.remote_input_b.data.value, 100)
self.assertEqual(self.remote_input_a.data.value, 10)
self.assertEqual(self.remote_input_b.data.value, 100)
def test_group_full_cycle(self):
assert self.remote_group.hasMoreData()
self.assertTrue(self.remote_group.hasMoreData())
self.remote_group.next()
self.assertEqual(self.remote_input_a.data.value, 10)
self.assertEqual(self.remote_input_b.data.value, 100)
self.assertTrue(self.remote_group.hasMoreData())
self.remote_group.next()
nose.tools.eq_(self.remote_input_a.data.value, 10)
nose.tools.eq_(self.remote_input_b.data.value, 100)
self.assertEqual(self.remote_input_a.data.value, 20)
self.assertEqual(self.remote_input_b.data.value, 200)
self.assertTrue(not self.remote_group.hasMoreData())
#----------------------------------------------------------
class TestDifferentFrequenciesInputs(TestMessageHandlerBase):
def setUp(self):
dataformat = DataFormat(prefix, 'user/single_integer/1')
self.create_remote_inputs(
DataFormat(prefix, 'user/single_integer/1'),
dict(
a = MockDataSource([
dataformat.type(value=10),
dataformat.type(value=20),
],
[
(0, 3),
(4, 7),
]
),
b = MockDataSource([
dataformat.type(value=100),
dataformat.type(value=200),
dataformat.type(value=300),
dataformat.type(value=400),
dataformat.type(value=500),
dataformat.type(value=600),
dataformat.type(value=700),
dataformat.type(value=800),
],
[
(0, 0),
(1, 1),
(2, 2),
(3, 3),
(4, 4),
(5, 5),
(6, 6),
(7, 7),
]
)
)
)
self.remote_input_a = self.remote_input_list['a']
self.remote_input_b = self.remote_input_list['b']
def test_group_full_cycle(self):
self.assertTrue(self.remote_group.hasMoreData())
self.assertTrue(self.remote_input_a.hasMoreData())
self.assertTrue(self.remote_input_b.hasMoreData())
self.assertTrue(self.remote_input_a.isDataUnitDone())
self.assertTrue(self.remote_input_b.isDataUnitDone())
for i in range(0, 7):
self.remote_group.next()
self.assertEqual(self.remote_input_a.data.value, (i // 4 + 1) * 10)
self.assertEqual(self.remote_input_b.data.value, (i + 1) * 100)
self.assertTrue(self.remote_group.hasMoreData())
if i < 4:
self.assertTrue(self.remote_input_a.hasMoreData())
else:
self.assertFalse(self.remote_input_a.hasMoreData())
self.assertTrue(self.remote_input_b.hasMoreData())
if i == 3:
self.assertTrue(self.remote_input_a.isDataUnitDone())
else:
self.assertFalse(self.remote_input_a.isDataUnitDone())
self.assertTrue(self.remote_input_b.isDataUnitDone())
assert self.remote_group.hasMoreData()
self.remote_group.next()
nose.tools.eq_(self.remote_input_a.data.value, 20)
nose.tools.eq_(self.remote_input_b.data.value, 200)
assert not self.remote_group.hasMoreData()
self.assertEqual(self.remote_input_a.data.value, 20)
self.assertEqual(self.remote_input_b.data.value, 800)
self.assertFalse(self.remote_group.hasMoreData())
self.assertFalse(self.remote_input_a.hasMoreData())
self.assertFalse(self.remote_input_b.hasMoreData())
self.assertTrue(self.remote_input_a.isDataUnitDone())
self.assertTrue(self.remote_input_b.isDataUnitDone())
#----------------------------------------------------------
......
#! /bin/bash
IMAGES=(
docker.idiap.ch/beat/beat.env.system.python:1.2.0r0
docker.idiap.ch/beat/beat.env.db.examples:1.3.0r0
docker.idiap.ch/beat/beat.env.system.python:1.2.0r1
docker.idiap.ch/beat/beat.env.db.examples:1.3.0r1
docker.idiap.ch/beat/beat.env.cxx:1.0.2
docker.idiap.ch/beat/beat.env.client:1.2.0
)
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment