Commit 3dae08ac authored by Samuel GAIST's avatar Samuel GAIST

[test][message_handler] Pre-commit cleanup

parent b1199866
......@@ -35,8 +35,6 @@
import logging
logger = logging.getLogger(__name__)
import unittest
import zmq
import os
......@@ -46,9 +44,6 @@ import numpy as np
from ..execution import MessageHandler
from ..dataformat import DataFormat
from ..inputs import Input
from ..inputs import InputGroup
from ..inputs import InputList
from ..data import RemoteException
from ..data import CachedDataSource
from ..data import RemoteDataSource
......@@ -59,21 +54,22 @@ from .mocks import CrashingDataSource
from . import prefix
#----------------------------------------------------------
logger = logging.getLogger(__name__)
class TestMessageHandlerBase(unittest.TestCase):
# ----------------------------------------------------------
class TestMessageHandlerBase(unittest.TestCase):
def setUp(self):
self.filenames = []
self.data_loader = None
def tearDown(self):
for filename in self.filenames:
basename, ext = os.path.splitext(filename)
filenames = [filename]
filenames += glob.glob(basename + '*')
filenames += glob.glob(basename + "*")
for filename in filenames:
if os.path.exists(filename):
os.unlink(filename)
......@@ -87,32 +83,34 @@ class TestMessageHandlerBase(unittest.TestCase):
self.data_loader = None
def create_data_loader(self, data_sources):
self.client_context = zmq.Context()
self.message_handler = MessageHandler('127.0.0.1', data_sources=data_sources, context=self.client_context)
self.message_handler = MessageHandler(
"127.0.0.1", data_sources=data_sources, context=self.client_context
)
self.message_handler.start()
self.client_socket = self.client_context.socket(zmq.PAIR)
self.client_socket.connect(self.message_handler.address)
self.data_loader = DataLoader('channel')
self.data_loader = DataLoader("channel")
for input_name in data_sources.keys():
data_source = RemoteDataSource()
data_source.setup(self.client_socket, input_name, 'user/single_integer/1', prefix)
data_source.setup(
self.client_socket, input_name, "user/single_integer/1", prefix
)
self.data_loader.add(input_name, data_source)
def writeData(self, start_index=0, end_index=10, step=1, base=0):
testfile = tempfile.NamedTemporaryFile(prefix=__name__, suffix='.data')
testfile.close() # preserve only the name
testfile = tempfile.NamedTemporaryFile(prefix=__name__, suffix=".data")
testfile.close() # preserve only the name
filename = testfile.name
self.filenames.append(filename)
dataformat = DataFormat(prefix, 'user/single_integer/1')
dataformat = DataFormat(prefix, "user/single_integer/1")
self.assertTrue(dataformat.valid)
data_sink = CachedDataSink()
......@@ -140,102 +138,94 @@ class TestMessageHandlerBase(unittest.TestCase):
return cached_file
#----------------------------------------------------------
# ----------------------------------------------------------
class TestOneDataSource(TestMessageHandlerBase):
def setUp(self):
super(TestOneDataSource, self).setUp()
data_sources = {}
data_sources['a'] = self.writeData(start_index=0, end_index=9)
data_sources["a"] = self.writeData(start_index=0, end_index=9)
self.create_data_loader(data_sources)
def test_iteration(self):
self.assertEqual(self.data_loader.count('a'), 10)
self.assertEqual(self.data_loader.count("a"), 10)
for i in range(10):
(result, start, end) = self.data_loader[i]
self.assertEqual(start, i)
self.assertEqual(end, i)
self.assertEqual(result['a'].value, i)
self.assertEqual(result["a"].value, i)
#----------------------------------------------------------
# ----------------------------------------------------------
class TestSameFrequencyDataSources(TestMessageHandlerBase):
def setUp(self):
super(TestSameFrequencyDataSources, self).setUp()
data_sources = {}
data_sources['a'] = self.writeData(start_index=0, end_index=9)
data_sources['b'] = self.writeData(start_index=0, end_index=9, base=10)
data_sources["a"] = self.writeData(start_index=0, end_index=9)
data_sources["b"] = self.writeData(start_index=0, end_index=9, base=10)
self.create_data_loader(data_sources)
def test_iteration(self):
self.assertEqual(self.data_loader.count('a'), 10)
self.assertEqual(self.data_loader.count('b'), 10)
self.assertEqual(self.data_loader.count("a"), 10)
self.assertEqual(self.data_loader.count("b"), 10)
for i in range(10):
(result, start, end) = self.data_loader[i]
self.assertEqual(start, i)
self.assertEqual(end, i)
self.assertEqual(result['a'].value, i)
self.assertEqual(result['b'].value, 10 + i)
self.assertEqual(result["a"].value, i)
self.assertEqual(result["b"].value, 10 + i)
#----------------------------------------------------------
# ----------------------------------------------------------
class TestDifferentFrequenciesDataSources(TestMessageHandlerBase):
def setUp(self):
super(TestDifferentFrequenciesDataSources, self).setUp()
data_sources = {}
data_sources['a'] = self.writeData(start_index=0, end_index=9)
data_sources['b'] = self.writeData(start_index=0, end_index=9, base=10, step=5)
data_sources["a"] = self.writeData(start_index=0, end_index=9)
data_sources["b"] = self.writeData(start_index=0, end_index=9, base=10, step=5)
self.create_data_loader(data_sources)
def test_iteration(self):
self.assertEqual(self.data_loader.count('a'), 10)
self.assertEqual(self.data_loader.count('b'), 2)
self.assertEqual(self.data_loader.count("a"), 10)
self.assertEqual(self.data_loader.count("b"), 2)
for i in range(10):
(result, start, end) = self.data_loader[i]
self.assertEqual(start, i)
self.assertEqual(end, i)
self.assertEqual(result['a'].value, i)
self.assertEqual(result["a"].value, i)
if i < 5:
self.assertEqual(result['b'].value, 10)
self.assertEqual(result["b"].value, 10)
else:
self.assertEqual(result['b'].value, 15)
self.assertEqual(result["b"].value, 15)
#----------------------------------------------------------
# ----------------------------------------------------------
class TestCrashingDataSource(TestMessageHandlerBase):
def setUp(self):
super(TestCrashingDataSource, self).setUp()
data_sources = {}
data_sources['a'] = CrashingDataSource()
data_sources["a"] = CrashingDataSource()
self.create_data_loader(data_sources)
def test_crash(self):
with self.assertRaises(RemoteException):
(result, start, end) = self.data_loader[0]
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment