diff --git a/beat/backend/python/data.py b/beat/backend/python/data.py index ad303373f2a557c8ecd0fa8306e0b41b54d3b3b1..80f7547e3cfc9439510192318888a1e7538a4c9b 100644 --- a/beat/backend/python/data.py +++ b/beat/backend/python/data.py @@ -209,6 +209,14 @@ class DataSource(object): def close(self): self.infos = [] + def reset(self): + """Reset the state of the data source + + This shall only clear the current state, not require a new call + to setup the source. + """ + pass + def __del__(self): """Makes sure all resources are released when the object is deleted""" self.close() @@ -267,6 +275,12 @@ class DataSource(object): # ---------------------------------------------------------- +# helper to store file information +# required to be out of the CachedDataSource for pickling reasons +FileInfos = namedtuple( + "FileInfos", ["file_index", "start_index", "end_index", "offset", "size"] +) + class CachedDataSource(DataSource): """Utility class to load data from a file in the cache""" @@ -282,6 +296,29 @@ class CachedDataSource(DataSource): self.current_file_index = None self.unpack = True + def __getstate__(self): + # do not pass open files when being pickled/copied + state = self.__dict__.copy() + + if state["current_file"] is not None: + del state["current_file"] + state["__had_open_file__"] = True + + return state + + def __setstate__(self, state): + # restore the state after being pickled/copied + had_open_file_before_pickle = state.pop("__had_open_file__", False) + + self.__dict__.update(state) + + if had_open_file_before_pickle: + try: + path = self.filenames[self.current_file_index] + self.current_file = open(path, "rb") + except Exception as e: + raise IOError("Could not read `%s': %s" % (path, e)) + def _readHeader(self, file): """Read the header of the provided file""" @@ -420,11 +457,6 @@ class CachedDataSource(DataSource): check_consistency(self.filenames, data_checksum_filenames) - # Load all the needed infos from all the files - FileInfos = namedtuple( - 
"FileInfos", ["file_index", "start_index", "end_index", "offset", "size"] - ) - for file_index, current_filename in enumerate(self.filenames): try: f = open(current_filename, "rb") @@ -467,11 +499,18 @@ class CachedDataSource(DataSource): return True def close(self): - if self.current_file is not None: - self.current_file.close() + self.reset() super(CachedDataSource, self).close() + def reset(self): + """Rest the current state""" + + if self.current_file is not None: + self.current_file.close() + self.current_file = None + self.current_file_index = None + def __getitem__(self, index): """Retrieve a block of data diff --git a/beat/backend/python/data_loaders.py b/beat/backend/python/data_loaders.py index f8d1db7ffc7a5ec254f309ddc1a6517b28bffb61..429a26d4f29254b61d697f8f3aad8db8a0f37344 100644 --- a/beat/backend/python/data_loaders.py +++ b/beat/backend/python/data_loaders.py @@ -210,6 +210,14 @@ class DataLoader(object): self.data_index_start = -1 # Lower index across all inputs self.data_index_end = -1 # Bigger index across all inputs + def reset(self): + """Reset all the data sources""" + + for infos in self.infos.values(): + data_source = infos.get("data_source") + if data_source: + data_source.reset() + def add(self, input_name, data_source): self.infos[input_name] = dict( data_source=data_source, @@ -302,6 +310,17 @@ class DataLoader(object): return (result, indices[0], indices[1]) + def __getstate__(self): + state = self.__dict__.copy() + + # reset the data cached as its content is not pickable + for infos in state["infos"].values(): + infos["data"] = None + infos["start_index"] = -1 + infos["end_index"] = -1 + + return state + # ---------------------------------------------------------- diff --git a/beat/backend/python/test/prefix/algorithms/autonomous/multiprocess/1.json b/beat/backend/python/test/prefix/algorithms/autonomous/multiprocess/1.json new file mode 100644 index 0000000000000000000000000000000000000000..ae67a982673f361cfd3f2644f2101d49486b9f0a --- 
/dev/null +++ b/beat/backend/python/test/prefix/algorithms/autonomous/multiprocess/1.json @@ -0,0 +1,28 @@ +{ + "schema_version": 2, + "language": "python", + "api_version": 2, + "type": "autonomous", + "splittable": false, + "groups": [ + { + "inputs": { + "in1": { + "type": "user/single_integer/1" + } + }, + "outputs": { + "out": { + "type": "user/single_integer/1" + } + } + }, + { + "inputs": { + "in2": { + "type": "user/single_integer/1" + } + } + } + ] +} diff --git a/beat/backend/python/test/prefix/algorithms/autonomous/multiprocess/1.py b/beat/backend/python/test/prefix/algorithms/autonomous/multiprocess/1.py new file mode 100755 index 0000000000000000000000000000000000000000..c000d7f0fc82e5deb6462626a36367400efe3b05 --- /dev/null +++ b/beat/backend/python/test/prefix/algorithms/autonomous/multiprocess/1.py @@ -0,0 +1,90 @@ +#!/usr/bin/env python +# vim: set fileencoding=utf-8 : + +################################################################################### +# # +# Copyright (c) 2019 Idiap Research Institute, http://www.idiap.ch/ # +# Contact: beat.support@idiap.ch # +# # +# Redistribution and use in source and binary forms, with or without # +# modification, are permitted provided that the following conditions are met: # +# # +# 1. Redistributions of source code must retain the above copyright notice, this # +# list of conditions and the following disclaimer. # +# # +# 2. Redistributions in binary form must reproduce the above copyright notice, # +# this list of conditions and the following disclaimer in the documentation # +# and/or other materials provided with the distribution. # +# # +# 3. Neither the name of the copyright holder nor the names of its contributors # +# may be used to endorse or promote products derived from this software without # +# specific prior written permission. 
# +# # +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND # +# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED # +# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE # +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE # +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL # +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR # +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER # +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, # +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. # +# # +################################################################################### + +import multiprocessing + + +def foo(queue_in, queue_out, index): + text, data_loader = queue_in.get() + data, _, _ = data_loader[index] + value = data["in1"].value + + queue_out.put("hello " + text + " {}".format(value)) + queue_in.task_done() + + +class Algorithm: + def prepare(self, data_loaders): + data_loader = data_loaders.loaderOf("in2") + + data, _, _ = data_loader[0] + self.offset = data["in2"].value + + return True + + def process(self, data_loaders, outputs): + data_loader = data_loaders.loaderOf("in1") + + # ensure loader has been used before sending it + for i in range(data_loader.count()): + data, _, _ = data_loader[i] + data["in1"].value + + num_thread = data_loader.count() + + queue_in = multiprocessing.JoinableQueue(num_thread) + queue_out = [] + + # Start worker processes + jobs = [] + for i in range(num_thread): + queue_out.append(multiprocessing.Queue()) + p = multiprocessing.Process(target=foo, args=(queue_in, queue_out[i], i)) + jobs.append(p) + p.start() + + # Send one task per worker through the shared queue + for task in range(num_thread): + queue_in.put(("test 
{}".format(task), data_loader)) + + # Wait for all the tasks to finish + queue_in.join() + + for i in range(data_loader.count()): + data, _, end = data_loader[i] + + outputs["out"].write({"value": data["in1"].value + self.offset}, end) + + return True diff --git a/beat/backend/python/test/prefix/algorithms/autonomous/multiprocess/1.rst b/beat/backend/python/test/prefix/algorithms/autonomous/multiprocess/1.rst new file mode 100644 index 0000000000000000000000000000000000000000..e62c1c42520de03eed49b89f3ed95fce122745de --- /dev/null +++ b/beat/backend/python/test/prefix/algorithms/autonomous/multiprocess/1.rst @@ -0,0 +1 @@ +Test documentation diff --git a/beat/backend/python/test/prefix/algorithms/sequential/multiprocess/1.json b/beat/backend/python/test/prefix/algorithms/sequential/multiprocess/1.json new file mode 100644 index 0000000000000000000000000000000000000000..34eaad0afbd6ff649d39848a0c3eae0451da4a03 --- /dev/null +++ b/beat/backend/python/test/prefix/algorithms/sequential/multiprocess/1.json @@ -0,0 +1,28 @@ +{ + "schema_version": 2, + "language": "python", + "api_version": 2, + "type": "sequential", + "splittable": false, + "groups": [ + { + "inputs": { + "in1": { + "type": "user/single_integer/1" + } + }, + "outputs": { + "out": { + "type": "user/single_integer/1" + } + } + }, + { + "inputs": { + "in2": { + "type": "user/single_integer/1" + } + } + } + ] +} diff --git a/beat/backend/python/test/prefix/algorithms/sequential/multiprocess/1.py b/beat/backend/python/test/prefix/algorithms/sequential/multiprocess/1.py new file mode 100755 index 0000000000000000000000000000000000000000..f5389ebdee740a2303e4359c0770f454a5882e2d --- /dev/null +++ b/beat/backend/python/test/prefix/algorithms/sequential/multiprocess/1.py @@ -0,0 +1,87 @@ +#!/usr/bin/env python +# vim: set fileencoding=utf-8 : + +################################################################################### +# # +# Copyright (c) 2019 Idiap Research Institute, http://www.idiap.ch/ # +# 
Contact: beat.support@idiap.ch # +# # +# Redistribution and use in source and binary forms, with or without # +# modification, are permitted provided that the following conditions are met: # +# # +# 1. Redistributions of source code must retain the above copyright notice, this # +# list of conditions and the following disclaimer. # +# # +# 2. Redistributions in binary form must reproduce the above copyright notice, # +# this list of conditions and the following disclaimer in the documentation # +# and/or other materials provided with the distribution. # +# # +# 3. Neither the name of the copyright holder nor the names of its contributors # +# may be used to endorse or promote products derived from this software without # +# specific prior written permission. # +# # +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND # +# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED # +# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE # +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE # +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL # +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR # +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER # +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, # +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
# +# # +################################################################################### + +import multiprocessing + + +def foo(queue_in, queue_out, index): + text, data_loader = queue_in.get() + + data, _, _ = data_loader[index] + value = data["in2"].value + + queue_out.put("hello " + text + " {}".format(value)) + queue_in.task_done() + + +class Algorithm: + def prepare(self, data_loaders): + data_loader = data_loaders.loaderOf("in2") + + data, _, _ = data_loader[0] + self.offset = data["in2"].value + + return True + + def process(self, inputs, data_loaders, outputs): + data_loader = data_loaders.loaderOf("in2") + + for i in range(data_loader.count()): + data, _, _ = data_loader[i] + data["in2"].value + + num_thread = data_loader.count() + + queue_in = multiprocessing.JoinableQueue(num_thread) + queue_out = [] + + # Start worker processes + jobs = [] + for i in range(num_thread): + queue_out.append(multiprocessing.Queue()) + p = multiprocessing.Process(target=foo, args=(queue_in, queue_out[i], i)) + jobs.append(p) + p.start() + + # Send one task per worker through the shared queue + for task in range(num_thread): + queue_in.put(("test {}".format(task), data_loader)) + + # Wait for all the tasks to finish + queue_in.join() + + outputs["out"].write({"value": inputs["in1"].data.value + self.offset}) + + return True diff --git a/beat/backend/python/test/prefix/algorithms/sequential/multiprocess/1.rst b/beat/backend/python/test/prefix/algorithms/sequential/multiprocess/1.rst new file mode 100644 index 0000000000000000000000000000000000000000..e62c1c42520de03eed49b89f3ed95fce122745de --- /dev/null +++ b/beat/backend/python/test/prefix/algorithms/sequential/multiprocess/1.rst @@ -0,0 +1 @@ +Test documentation diff --git a/beat/backend/python/test/test_algorithm.py b/beat/backend/python/test/test_algorithm.py index 21bad303cafa4003cccf87a5a10d52c49f4562c2..8b90e6fe19dfad498c875f87a4789accb29227df 100644 --- a/beat/backend/python/test/test_algorithm.py +++ 
b/beat/backend/python/test/test_algorithm.py @@ -1078,6 +1078,36 @@ class TestSequentialAPI_Process(TestExecutionBase): self.assertEqual(data_unit.end, 3) self.assertEqual(data_unit.data.value, 2014) + def test_multiprocess(self): + self.writeData( + "in1", + [(0, 0), (1, 1), (2, 2), (3, 3), (4, 4), (5, 5), (6, 6), (7, 7)], + 1000, + ) + self.writeData("in2", [(0, 1), (2, 3)], 2000) + + (data_loaders, inputs, outputs, data_sink) = self.create_io( + {"group1": ["in1"], "group2": ["in2"]} + ) + + algorithm = Algorithm(prefix, "sequential/multiprocess/1") + runner = algorithm.runner() + + self.assertTrue(runner.setup({"sync": "in2"})) + self.assertTrue(runner.prepare(data_loaders=data_loaders)) + + while inputs.hasMoreData(): + inputs.restricted_access = False + inputs.next() + inputs.restricted_access = True + self.assertTrue( + runner.process( + inputs=inputs, data_loaders=data_loaders, outputs=outputs + ) + ) + + self.assertEqual(len(data_sink.written), 8) + # ---------------------------------------------------------- @@ -1270,3 +1300,25 @@ class TestAutonomousAPI_Process(TestExecutionBase): self.assertEqual(data_unit.start, 3) self.assertEqual(data_unit.end, 3) self.assertEqual(data_unit.data.value, 2014) + + def test_multiprocess(self): + self.writeData( + "in1", + [(0, 0), (1, 1), (2, 2), (3, 3), (4, 4), (5, 5), (6, 6), (7, 7)], + 1000, + ) + self.writeData("in2", [(0, 1), (2, 3)], 2000) + + (data_loaders, outputs, data_sink) = self.create_io( + {"group1": ["in1"], "group2": ["in2"]} + ) + + algorithm = Algorithm(prefix, "autonomous/multiprocess/1") + runner = algorithm.runner() + + self.assertTrue(runner.setup({"sync": "in2"})) + + self.assertTrue(runner.prepare(data_loaders=data_loaders.secondaries())) + self.assertTrue(runner.process(data_loaders=data_loaders, outputs=outputs)) + + self.assertEqual(len(data_sink.written), 8) diff --git a/beat/backend/python/test/test_data.py b/beat/backend/python/test/test_data.py index 
320dfdadc1710ad75a010f9db9f4ca99ed5c5d35..4bd14bf49e0900c92e17afdeba1e52fed026a237 100644 --- a/beat/backend/python/test/test_data.py +++ b/beat/backend/python/test/test_data.py @@ -37,6 +37,7 @@ import unittest import os import glob +import pickle # nosec import tempfile import shutil @@ -45,87 +46,70 @@ from ..data import CachedDataSource from ..data import CachedDataSink from ..data import getAllFilenames from ..data import foundSplitRanges -from ..hash import hashFileContents from ..dataformat import DataFormat from ..database import Database from . import prefix -#---------------------------------------------------------- +# ---------------------------------------------------------- class TestMixDataIndices(unittest.TestCase): - def test_one_list(self): - list_of_data_indices = [ - [(0, 2), (3, 4), (5, 6)] - ] + list_of_data_indices = [[(0, 2), (3, 4), (5, 6)]] result = mixDataIndices(list_of_data_indices) self.assertEqual([(0, 2), (3, 4), (5, 6)], result) - def test_two_identical_lists(self): - list_of_data_indices = [ - [(0, 2), (3, 4), (5, 6)], - [(0, 2), (3, 4), (5, 6)], - ] + list_of_data_indices = [[(0, 2), (3, 4), (5, 6)], [(0, 2), (3, 4), (5, 6)]] result = mixDataIndices(list_of_data_indices) self.assertEqual([(0, 2), (3, 4), (5, 6)], result) - def test_two_synchronized_lists(self): - list_of_data_indices = [ - [(0, 2), (3, 4), (5, 6)], - [(0, 4), (5, 6)], - ] + list_of_data_indices = [[(0, 2), (3, 4), (5, 6)], [(0, 4), (5, 6)]] result = mixDataIndices(list_of_data_indices) self.assertEqual([(0, 2), (3, 4), (5, 6)], result) - def test_two_desynchronized_lists(self): - list_of_data_indices = [ - [(0, 2), (3, 4), (5, 6)], - [(0, 1), (2, 4), (5, 6)], - ] + list_of_data_indices = [[(0, 2), (3, 4), (5, 6)], [(0, 1), (2, 4), (5, 6)]] result = mixDataIndices(list_of_data_indices) self.assertEqual([(0, 1), (2, 2), (3, 4), (5, 6)], result) -#---------------------------------------------------------- +# ---------------------------------------------------------- 
class TestCachedDataBase(unittest.TestCase): - def setUp(self): - testfile = tempfile.NamedTemporaryFile(prefix=__name__, suffix='.data') - testfile.close() # preserve only the name + testfile = tempfile.NamedTemporaryFile(prefix=__name__, suffix=".data") + testfile.close() # preserve only the name self.filename = testfile.name - def tearDown(self): basename, ext = os.path.splitext(self.filename) filenames = [self.filename] - filenames += glob.glob(basename + '*') + filenames += glob.glob(basename + "*") for filename in filenames: if os.path.exists(filename): os.unlink(filename) - def writeData(self, dataformat_name, start_index=0, end_index=10): dataformat = DataFormat(prefix, dataformat_name) self.assertTrue(dataformat.valid) data_sink = CachedDataSink() - self.assertTrue(data_sink.setup(self.filename, dataformat, start_index, end_index)) + self.assertTrue( + data_sink.setup(self.filename, dataformat, start_index, end_index) + ) all_data = [] for i in range(start_index, end_index + 1): @@ -143,84 +127,102 @@ class TestCachedDataBase(unittest.TestCase): return all_data -#---------------------------------------------------------- +# ---------------------------------------------------------- class TestGetAllFilenames(TestCachedDataBase): - def test_one_complete_data_file(self): - self.writeData('user/single_integer/1', 0, 9) + self.writeData("user/single_integer/1", 0, 9) - (data_filenames, indices_filenames, data_checksum_filenames, indices_checksum_filenames) = \ - getAllFilenames(self.filename) + ( + data_filenames, + indices_filenames, + data_checksum_filenames, + indices_checksum_filenames, + ) = getAllFilenames(self.filename) self.assertEqual(1, len(data_filenames)) self.assertEqual(1, len(indices_filenames)) self.assertEqual(1, len(data_checksum_filenames)) self.assertEqual(1, len(indices_checksum_filenames)) - def test_three_complete_data_files(self): - self.writeData('user/single_integer/1', 0, 9) - self.writeData('user/single_integer/1', 10, 19) - 
self.writeData('user/single_integer/1', 20, 29) + self.writeData("user/single_integer/1", 0, 9) + self.writeData("user/single_integer/1", 10, 19) + self.writeData("user/single_integer/1", 20, 29) - (data_filenames, indices_filenames, data_checksum_filenames, indices_checksum_filenames) = \ - getAllFilenames(self.filename) + ( + data_filenames, + indices_filenames, + data_checksum_filenames, + indices_checksum_filenames, + ) = getAllFilenames(self.filename) self.assertEqual(3, len(data_filenames)) self.assertEqual(3, len(indices_filenames)) self.assertEqual(3, len(data_checksum_filenames)) self.assertEqual(3, len(indices_checksum_filenames)) - def test_one_partial_data_file(self): - self.writeData('user/single_integer/1', 0, 9) + self.writeData("user/single_integer/1", 0, 9) - (data_filenames, indices_filenames, data_checksum_filenames, indices_checksum_filenames) = \ - getAllFilenames(self.filename, 2, 6) + ( + data_filenames, + indices_filenames, + data_checksum_filenames, + indices_checksum_filenames, + ) = getAllFilenames(self.filename, 2, 6) self.assertEqual(1, len(data_filenames)) self.assertEqual(1, len(indices_filenames)) self.assertEqual(1, len(data_checksum_filenames)) self.assertEqual(1, len(indices_checksum_filenames)) - def test_three_partial_data_files_1(self): - self.writeData('user/single_integer/1', 0, 9) - self.writeData('user/single_integer/1', 10, 19) - self.writeData('user/single_integer/1', 20, 29) + self.writeData("user/single_integer/1", 0, 9) + self.writeData("user/single_integer/1", 10, 19) + self.writeData("user/single_integer/1", 20, 29) - (data_filenames, indices_filenames, data_checksum_filenames, indices_checksum_filenames) = \ - getAllFilenames(self.filename, 14, 18) + ( + data_filenames, + indices_filenames, + data_checksum_filenames, + indices_checksum_filenames, + ) = getAllFilenames(self.filename, 14, 18) self.assertEqual(1, len(data_filenames)) self.assertEqual(1, len(indices_filenames)) self.assertEqual(1, 
len(data_checksum_filenames)) self.assertEqual(1, len(indices_checksum_filenames)) - def test_three_partial_data_files_2(self): - self.writeData('user/single_integer/1', 0, 9) - self.writeData('user/single_integer/1', 10, 19) - self.writeData('user/single_integer/1', 20, 29) + self.writeData("user/single_integer/1", 0, 9) + self.writeData("user/single_integer/1", 10, 19) + self.writeData("user/single_integer/1", 20, 29) - (data_filenames, indices_filenames, data_checksum_filenames, indices_checksum_filenames) = \ - getAllFilenames(self.filename, 4, 18) + ( + data_filenames, + indices_filenames, + data_checksum_filenames, + indices_checksum_filenames, + ) = getAllFilenames(self.filename, 4, 18) self.assertEqual(2, len(data_filenames)) self.assertEqual(2, len(indices_filenames)) self.assertEqual(2, len(data_checksum_filenames)) self.assertEqual(2, len(indices_checksum_filenames)) - def test_three_partial_data_files_3(self): - self.writeData('user/single_integer/1', 0, 9) - self.writeData('user/single_integer/1', 10, 19) - self.writeData('user/single_integer/1', 20, 29) + self.writeData("user/single_integer/1", 0, 9) + self.writeData("user/single_integer/1", 10, 19) + self.writeData("user/single_integer/1", 20, 29) - (data_filenames, indices_filenames, data_checksum_filenames, indices_checksum_filenames) = \ - getAllFilenames(self.filename, 4, 28) + ( + data_filenames, + indices_filenames, + data_checksum_filenames, + indices_checksum_filenames, + ) = getAllFilenames(self.filename, 4, 28) self.assertEqual(3, len(data_filenames)) self.assertEqual(3, len(indices_filenames)) @@ -228,11 +230,10 @@ class TestGetAllFilenames(TestCachedDataBase): self.assertEqual(3, len(indices_checksum_filenames)) -#---------------------------------------------------------- +# ---------------------------------------------------------- class TestCachedDataSource(TestCachedDataBase): - def check_valid_indices(self, cached_file): for i in range(0, len(cached_file)): (data, start_index, 
end_index) = cached_file[i] @@ -240,15 +241,15 @@ class TestCachedDataSource(TestCachedDataBase): self.assertEqual(i + cached_file.first_data_index(), start_index) self.assertEqual(i + cached_file.first_data_index(), end_index) - def check_valid_data_indices(self, cached_file): for i in range(0, len(cached_file)): - (data, start_index, end_index) = cached_file.getAtDataIndex(i + cached_file.first_data_index()) + (data, start_index, end_index) = cached_file.getAtDataIndex( + i + cached_file.first_data_index() + ) self.assertTrue(data is not None) self.assertEqual(i + cached_file.first_data_index(), start_index) self.assertEqual(i + cached_file.first_data_index(), end_index) - def check_invalid_indices(self, cached_file): # Invalid indices (data, start_index, end_index) = cached_file[-1] @@ -258,15 +259,18 @@ class TestCachedDataSource(TestCachedDataBase): self.assertTrue(data is None) # Invalid data indices - (data, start_index, end_index) = cached_file.getAtDataIndex(cached_file.first_data_index() - 1) + (data, start_index, end_index) = cached_file.getAtDataIndex( + cached_file.first_data_index() - 1 + ) self.assertTrue(data is None) - (data, start_index, end_index) = cached_file.getAtDataIndex(cached_file.last_data_index() + 1) + (data, start_index, end_index) = cached_file.getAtDataIndex( + cached_file.last_data_index() + 1 + ) self.assertTrue(data is None) - def test_one_complete_data_file(self): - self.writeData('user/single_integer/1', 0, 9) + self.writeData("user/single_integer/1", 0, 9) cached_file = CachedDataSource() cached_file.setup(self.filename, prefix) @@ -277,11 +281,10 @@ class TestCachedDataSource(TestCachedDataBase): self.check_valid_data_indices(cached_file) self.check_invalid_indices(cached_file) - def test_three_complete_data_files(self): - self.writeData('user/single_integer/1', 0, 9) - self.writeData('user/single_integer/1', 10, 19) - self.writeData('user/single_integer/1', 20, 29) + self.writeData("user/single_integer/1", 0, 9) + 
self.writeData("user/single_integer/1", 10, 19) + self.writeData("user/single_integer/1", 20, 29) cached_file = CachedDataSource() cached_file.setup(self.filename, prefix) @@ -292,9 +295,8 @@ class TestCachedDataSource(TestCachedDataBase): self.check_valid_data_indices(cached_file) self.check_invalid_indices(cached_file) - def test_one_partial_data_file(self): - self.writeData('user/single_integer/1', 0, 9) + self.writeData("user/single_integer/1", 0, 9) cached_file = CachedDataSource() cached_file.setup(self.filename, prefix, 2, 6) @@ -305,11 +307,10 @@ class TestCachedDataSource(TestCachedDataBase): self.check_valid_data_indices(cached_file) self.check_invalid_indices(cached_file) - def test_three_partial_data_files_1(self): - self.writeData('user/single_integer/1', 0, 9) - self.writeData('user/single_integer/1', 10, 19) - self.writeData('user/single_integer/1', 20, 29) + self.writeData("user/single_integer/1", 0, 9) + self.writeData("user/single_integer/1", 10, 19) + self.writeData("user/single_integer/1", 20, 29) cached_file = CachedDataSource() cached_file.setup(self.filename, prefix, 14, 18) @@ -320,11 +321,10 @@ class TestCachedDataSource(TestCachedDataBase): self.check_valid_data_indices(cached_file) self.check_invalid_indices(cached_file) - def test_three_partial_data_files_2(self): - self.writeData('user/single_integer/1', 0, 9) - self.writeData('user/single_integer/1', 10, 19) - self.writeData('user/single_integer/1', 20, 29) + self.writeData("user/single_integer/1", 0, 9) + self.writeData("user/single_integer/1", 10, 19) + self.writeData("user/single_integer/1", 20, 29) cached_file = CachedDataSource() cached_file.setup(self.filename, prefix, 4, 18) @@ -335,11 +335,10 @@ class TestCachedDataSource(TestCachedDataBase): self.check_valid_data_indices(cached_file) self.check_invalid_indices(cached_file) - def test_three_partial_data_files_3(self): - self.writeData('user/single_integer/1', 0, 9) - self.writeData('user/single_integer/1', 10, 19) - 
self.writeData('user/single_integer/1', 20, 29) + self.writeData("user/single_integer/1", 0, 9) + self.writeData("user/single_integer/1", 10, 19) + self.writeData("user/single_integer/1", 20, 29) cached_file = CachedDataSource() cached_file.setup(self.filename, prefix, 4, 28) @@ -350,20 +349,52 @@ class TestCachedDataSource(TestCachedDataBase): self.check_valid_data_indices(cached_file) self.check_invalid_indices(cached_file) + def test_reset(self): + self.writeData("user/single_integer/1", 0, 9) -#---------------------------------------------------------- + cached_source = CachedDataSource() + cached_source.setup(self.filename, prefix) + _, _, _ = cached_source[0] + self.assertIsNotNone(cached_source.current_file_index) + self.assertIsNotNone(cached_source.current_file) + cached_source.reset() + self.assertIsNone(cached_source.current_file_index) + self.assertIsNone(cached_source.current_file) + def test_picklability(self): + self.writeData("user/single_integer/1", 0, 9) + + cached_source = CachedDataSource() + cached_source.setup(self.filename, prefix) + + # test pickle before accessing a file + cached_source2 = pickle.loads(pickle.dumps(cached_source)) # nosec + data, start, end = cached_source[0] + data2, start2, end2 = cached_source2[0] + self.assertEqual(data.value, data2.value) + self.assertEqual(start, start2) + self.assertEqual(end, end2) + + # access one file and try again + data, start, end = cached_source[0] + cached_source2 = pickle.loads(pickle.dumps(cached_source)) # nosec + + data2, start2, end2 = cached_source2[0] + self.assertEqual(data.value, data2.value) + self.assertEqual(start, start2) + self.assertEqual(end, end2) + + +# ---------------------------------------------------------- -class TestDatabaseOutputDataSource(unittest.TestCase): +class TestDatabaseOutputDataSource(unittest.TestCase): def setUp(self): self.cache_root = tempfile.mkdtemp(prefix=__name__) - def tearDown(self): shutil.rmtree(self.cache_root) - def check_valid_indices(self, 
data_source): for i in range(0, len(data_source)): (data, start_index, end_index) = data_source[i] @@ -371,15 +402,15 @@ class TestDatabaseOutputDataSource(unittest.TestCase): self.assertEqual(i + data_source.first_data_index(), start_index) self.assertEqual(i + data_source.first_data_index(), end_index) - def check_valid_data_indices(self, data_source): for i in range(0, len(data_source)): - (data, start_index, end_index) = data_source.getAtDataIndex(i + data_source.first_data_index()) + (data, start_index, end_index) = data_source.getAtDataIndex( + i + data_source.first_data_index() + ) self.assertTrue(data is not None) self.assertEqual(i + data_source.first_data_index(), start_index) self.assertEqual(i + data_source.first_data_index(), end_index) - def check_invalid_indices(self, data_source): # Invalid indices (data, start_index, end_index) = data_source[-1] @@ -389,20 +420,23 @@ class TestDatabaseOutputDataSource(unittest.TestCase): self.assertTrue(data is None) # Invalid data indices - (data, start_index, end_index) = data_source.getAtDataIndex(data_source.first_data_index() - 1) + (data, start_index, end_index) = data_source.getAtDataIndex( + data_source.first_data_index() - 1 + ) self.assertTrue(data is None) - (data, start_index, end_index) = data_source.getAtDataIndex(data_source.last_data_index() + 1) + (data, start_index, end_index) = data_source.getAtDataIndex( + data_source.last_data_index() + 1 + ) self.assertTrue(data is None) - def test(self): - db = Database(prefix, 'integers_db/1') + db = Database(prefix, "integers_db/1") self.assertTrue(db.valid) - view = db.view('double', 'double') - view.index(os.path.join(self.cache_root, 'data.db')) - view.setup(os.path.join(self.cache_root, 'data.db'), pack=False) + view = db.view("double", "double") + view.index(os.path.join(self.cache_root, "data.db")) + view.setup(os.path.join(self.cache_root, "data.db"), pack=False) self.assertTrue(view.data_sources is not None) self.assertEqual(len(view.data_sources), 
3) @@ -415,35 +449,32 @@ class TestDatabaseOutputDataSource(unittest.TestCase): self.check_invalid_indices(data_source) -#---------------------------------------------------------- +# ---------------------------------------------------------- class TestDataSink(TestCachedDataBase): - def test_creation(self): - dataformat = DataFormat(prefix, 'user/single_integer/1') + dataformat = DataFormat(prefix, "user/single_integer/1") self.assertTrue(dataformat.valid) data_sink = CachedDataSink() self.assertTrue(data_sink.setup(self.filename, dataformat, 0, 10)) -#---------------------------------------------------------- +# ---------------------------------------------------------- class TestFoundSplitRanges(unittest.TestCase): - def test_2_splits(self): - l = [[0,2,4,6,8,10,12], [0,3,6,9,12]] + splits = [[0, 2, 4, 6, 8, 10, 12], [0, 3, 6, 9, 12]] n_split = 2 ref = [(0, 5), (6, 11)] - res = foundSplitRanges(l, n_split) + res = foundSplitRanges(splits, n_split) self.assertEqual(res, ref) - def test_5_splits(self): - l = [[0,2,4,6,8,10,12,15], [0,3,6,9,12,15]] + splits = [[0, 2, 4, 6, 8, 10, 12, 15], [0, 3, 6, 9, 12, 15]] n_split = 5 ref = [(0, 5), (6, 11), (12, 14)] - res = foundSplitRanges(l, n_split) + res = foundSplitRanges(splits, n_split) self.assertEqual(res, ref) diff --git a/beat/backend/python/test/test_data_loaders.py b/beat/backend/python/test/test_data_loaders.py index f8e3cfc4811e65f625db3ea2af00d3b82d770b2e..5c7ad1ec81aad5ddb3b188e6401adc6be8a8ed33 100644 --- a/beat/backend/python/test/test_data_loaders.py +++ b/beat/backend/python/test/test_data_loaders.py @@ -49,36 +49,38 @@ from ..data import CachedDataSource from . 
import prefix -#---------------------------------------------------------- +# ---------------------------------------------------------- class DataLoaderBaseTest(unittest.TestCase): - def setUp(self): self.filenames = {} - def tearDown(self): for f in self.filenames.values(): basename, ext = os.path.splitext(f) filenames = [f] - filenames += glob.glob(basename + '*') + filenames += glob.glob(basename + "*") for filename in filenames: if os.path.exists(filename): os.unlink(filename) - def writeData(self, input_name, indices, start_value): - testfile = tempfile.NamedTemporaryFile(prefix=__name__ + input_name, suffix='.data') - testfile.close() # preserve only the name + testfile = tempfile.NamedTemporaryFile( + prefix=__name__ + input_name, suffix=".data" + ) + testfile.close() # preserve only the name self.filenames[input_name] = testfile.name - dataformat = DataFormat(prefix, 'user/single_integer/1') + dataformat = DataFormat(prefix, "user/single_integer/1") self.assertTrue(dataformat.valid) data_sink = CachedDataSink() - self.assertTrue(data_sink.setup(self.filenames[input_name], dataformat, - indices[0][0], indices[-1][1])) + self.assertTrue( + data_sink.setup( + self.filenames[input_name], dataformat, indices[0][0], indices[-1][1] + ) + ) for i in indices: data = dataformat.type() @@ -93,49 +95,44 @@ class DataLoaderBaseTest(unittest.TestCase): del data_sink -#---------------------------------------------------------- +# ---------------------------------------------------------- class DataLoaderTest(DataLoaderBaseTest): - def test_creation(self): - data_loader = DataLoader('channel1') + data_loader = DataLoader("channel1") - self.assertEqual(data_loader.channel, 'channel1') + self.assertEqual(data_loader.channel, "channel1") self.assertEqual(data_loader.nb_data_units, 0) self.assertEqual(data_loader.data_index_start, -1) self.assertEqual(data_loader.data_index_end, -1) self.assertEqual(data_loader.count(), 0) - def test_empty(self): - data_loader = 
DataLoader('channel1') + data_loader = DataLoader("channel1") self.assertEqual(data_loader.count(), 0) self.assertEqual(data_loader[0], (None, None, None)) - self.assertTrue(data_loader.view('unknown', 0) is None) - + self.assertTrue(data_loader.view("unknown", 0) is None) def test_one_input(self): # Setup - self.writeData('input1', [(0, 0), (1, 1), (2, 2)], 1000) + self.writeData("input1", [(0, 0), (1, 1), (2, 2)], 1000) - data_loader = DataLoader('channel1') + data_loader = DataLoader("channel1") cached_file = CachedDataSource() - cached_file.setup(self.filenames['input1'], prefix) - data_loader.add('input1', cached_file) - + cached_file.setup(self.filenames["input1"], prefix) + data_loader.add("input1", cached_file) # Global checks self.assertEqual(data_loader.count(), 3) - self.assertEqual(data_loader.count('input1'), 3) + self.assertEqual(data_loader.count("input1"), 3) self.assertEqual(data_loader.data_index_start, 0) self.assertEqual(data_loader.data_index_end, 2) - # Indexing (data, start, end) = data_loader[-1] self.assertTrue(data is None) @@ -144,34 +141,32 @@ class DataLoaderTest(DataLoaderBaseTest): self.assertTrue(data is not None) self.assertEqual(start, 0) self.assertEqual(end, 0) - self.assertEqual(data['input1'].value, 1000) + self.assertEqual(data["input1"].value, 1000) (data, start, end) = data_loader[1] self.assertTrue(data is not None) self.assertEqual(start, 1) self.assertEqual(end, 1) - self.assertEqual(data['input1'].value, 1001) + self.assertEqual(data["input1"].value, 1001) (data, start, end) = data_loader[2] self.assertTrue(data is not None) self.assertEqual(start, 2) self.assertEqual(end, 2) - self.assertEqual(data['input1'].value, 1002) + self.assertEqual(data["input1"].value, 1002) (data, start, end) = data_loader[3] self.assertTrue(data is None) - # View 'input1', index -1 - view = data_loader.view('input1', -1) + view = data_loader.view("input1", -1) self.assertTrue(view is None) - # View 'input1', index 0 - view = 
data_loader.view('input1', 0) + view = data_loader.view("input1", 0) self.assertEqual(view.count(), 1) - self.assertEqual(view.count('input1'), 1) + self.assertEqual(view.count("input1"), 1) self.assertEqual(view.data_index_start, 0) self.assertEqual(view.data_index_end, 0) @@ -183,17 +178,16 @@ class DataLoaderTest(DataLoaderBaseTest): self.assertTrue(data is not None) self.assertEqual(start, 0) self.assertEqual(end, 0) - self.assertEqual(data['input1'].value, 1000) + self.assertEqual(data["input1"].value, 1000) (data, start, end) = view[1] self.assertTrue(data is None) - # View 'input1', index 1 - view = data_loader.view('input1', 1) + view = data_loader.view("input1", 1) self.assertEqual(view.count(), 1) - self.assertEqual(view.count('input1'), 1) + self.assertEqual(view.count("input1"), 1) self.assertEqual(view.data_index_start, 1) self.assertEqual(view.data_index_end, 1) @@ -205,17 +199,16 @@ class DataLoaderTest(DataLoaderBaseTest): self.assertTrue(data is not None) self.assertEqual(start, 1) self.assertEqual(end, 1) - self.assertEqual(data['input1'].value, 1001) + self.assertEqual(data["input1"].value, 1001) (data, start, end) = view[1] self.assertTrue(data is None) - # View 'input1', index 2 - view = data_loader.view('input1', 2) + view = data_loader.view("input1", 2) self.assertEqual(view.count(), 1) - self.assertEqual(view.count('input1'), 1) + self.assertEqual(view.count("input1"), 1) self.assertEqual(view.data_index_start, 2) self.assertEqual(view.data_index_end, 2) @@ -227,43 +220,39 @@ class DataLoaderTest(DataLoaderBaseTest): self.assertTrue(data is not None) self.assertEqual(start, 2) self.assertEqual(end, 2) - self.assertEqual(data['input1'].value, 1002) + self.assertEqual(data["input1"].value, 1002) (data, start, end) = view[1] self.assertTrue(data is None) - # View 'input1', index 3 - view = data_loader.view('input1', 3) + view = data_loader.view("input1", 3) self.assertTrue(view is None) - def test_two_synchronized_inputs(self): # Setup - 
self.writeData('input1', [(0, 0), (1, 1), (2, 2), (3, 3)], 1000) - self.writeData('input2', [(0, 1), (2, 3)], 2000) + self.writeData("input1", [(0, 0), (1, 1), (2, 2), (3, 3)], 1000) + self.writeData("input2", [(0, 1), (2, 3)], 2000) - data_loader = DataLoader('channel1') + data_loader = DataLoader("channel1") cached_file = CachedDataSource() - cached_file.setup(self.filenames['input1'], prefix) - data_loader.add('input1', cached_file) + cached_file.setup(self.filenames["input1"], prefix) + data_loader.add("input1", cached_file) cached_file = CachedDataSource() - cached_file.setup(self.filenames['input2'], prefix) - data_loader.add('input2', cached_file) - + cached_file.setup(self.filenames["input2"], prefix) + data_loader.add("input2", cached_file) # Global checks self.assertEqual(data_loader.count(), 4) - self.assertEqual(data_loader.count('input1'), 4) - self.assertEqual(data_loader.count('input2'), 2) + self.assertEqual(data_loader.count("input1"), 4) + self.assertEqual(data_loader.count("input2"), 2) self.assertEqual(data_loader.data_index_start, 0) self.assertEqual(data_loader.data_index_end, 3) - # Indexing (data, start, end) = data_loader[-1] self.assertTrue(data is None) @@ -272,45 +261,43 @@ class DataLoaderTest(DataLoaderBaseTest): self.assertTrue(data is not None) self.assertEqual(start, 0) self.assertEqual(end, 0) - self.assertEqual(data['input1'].value, 1000) - self.assertEqual(data['input2'].value, 2000) + self.assertEqual(data["input1"].value, 1000) + self.assertEqual(data["input2"].value, 2000) (data, start, end) = data_loader[1] self.assertTrue(data is not None) self.assertEqual(start, 1) self.assertEqual(end, 1) - self.assertEqual(data['input1'].value, 1001) - self.assertEqual(data['input2'].value, 2000) + self.assertEqual(data["input1"].value, 1001) + self.assertEqual(data["input2"].value, 2000) (data, start, end) = data_loader[2] self.assertTrue(data is not None) self.assertEqual(start, 2) self.assertEqual(end, 2) - 
self.assertEqual(data['input1'].value, 1002) - self.assertEqual(data['input2'].value, 2002) + self.assertEqual(data["input1"].value, 1002) + self.assertEqual(data["input2"].value, 2002) (data, start, end) = data_loader[3] self.assertTrue(data is not None) self.assertEqual(start, 3) self.assertEqual(end, 3) - self.assertEqual(data['input1'].value, 1003) - self.assertEqual(data['input2'].value, 2002) + self.assertEqual(data["input1"].value, 1003) + self.assertEqual(data["input2"].value, 2002) (data, start, end) = data_loader[4] self.assertTrue(data is None) - # View 'input1', index -1 - view = data_loader.view('input1', -1) + view = data_loader.view("input1", -1) self.assertTrue(view is None) - # View 'input1', index 0 - view = data_loader.view('input1', 0) + view = data_loader.view("input1", 0) self.assertEqual(view.count(), 1) - self.assertEqual(view.count('input1'), 1) - self.assertEqual(view.count('input2'), 1) + self.assertEqual(view.count("input1"), 1) + self.assertEqual(view.count("input2"), 1) self.assertEqual(view.data_index_start, 0) self.assertEqual(view.data_index_end, 0) @@ -322,19 +309,18 @@ class DataLoaderTest(DataLoaderBaseTest): self.assertTrue(data is not None) self.assertEqual(start, 0) self.assertEqual(end, 0) - self.assertEqual(data['input1'].value, 1000) - self.assertEqual(data['input2'].value, 2000) + self.assertEqual(data["input1"].value, 1000) + self.assertEqual(data["input2"].value, 2000) (data, start, end) = view[1] self.assertTrue(data is None) - # View 'input1', index 1 - view = data_loader.view('input1', 1) + view = data_loader.view("input1", 1) self.assertEqual(view.count(), 1) - self.assertEqual(view.count('input1'), 1) - self.assertEqual(view.count('input2'), 1) + self.assertEqual(view.count("input1"), 1) + self.assertEqual(view.count("input2"), 1) self.assertEqual(view.data_index_start, 1) self.assertEqual(view.data_index_end, 1) @@ -346,19 +332,18 @@ class DataLoaderTest(DataLoaderBaseTest): self.assertTrue(data is not None) 
self.assertEqual(start, 1) self.assertEqual(end, 1) - self.assertEqual(data['input1'].value, 1001) - self.assertEqual(data['input2'].value, 2000) + self.assertEqual(data["input1"].value, 1001) + self.assertEqual(data["input2"].value, 2000) (data, start, end) = view[1] self.assertTrue(data is None) - # View 'input1', index 2 - view = data_loader.view('input1', 2) + view = data_loader.view("input1", 2) self.assertEqual(view.count(), 1) - self.assertEqual(view.count('input1'), 1) - self.assertEqual(view.count('input2'), 1) + self.assertEqual(view.count("input1"), 1) + self.assertEqual(view.count("input2"), 1) self.assertEqual(view.data_index_start, 2) self.assertEqual(view.data_index_end, 2) @@ -370,19 +355,18 @@ class DataLoaderTest(DataLoaderBaseTest): self.assertTrue(data is not None) self.assertEqual(start, 2) self.assertEqual(end, 2) - self.assertEqual(data['input1'].value, 1002) - self.assertEqual(data['input2'].value, 2002) + self.assertEqual(data["input1"].value, 1002) + self.assertEqual(data["input2"].value, 2002) (data, start, end) = view[1] self.assertTrue(data is None) - # View 'input1', index 3 - view = data_loader.view('input1', 3) + view = data_loader.view("input1", 3) self.assertEqual(view.count(), 1) - self.assertEqual(view.count('input1'), 1) - self.assertEqual(view.count('input2'), 1) + self.assertEqual(view.count("input1"), 1) + self.assertEqual(view.count("input2"), 1) self.assertEqual(view.data_index_start, 3) self.assertEqual(view.data_index_end, 3) @@ -394,29 +378,26 @@ class DataLoaderTest(DataLoaderBaseTest): self.assertTrue(data is not None) self.assertEqual(start, 3) self.assertEqual(end, 3) - self.assertEqual(data['input1'].value, 1003) - self.assertEqual(data['input2'].value, 2002) + self.assertEqual(data["input1"].value, 1003) + self.assertEqual(data["input2"].value, 2002) (data, start, end) = view[1] self.assertTrue(data is None) - # View 'input1', index 4 - view = data_loader.view('input1', 4) + view = data_loader.view("input1", 4) 
self.assertTrue(view is None) - # View 'input2', index -1 - view = data_loader.view('input2', -1) + view = data_loader.view("input2", -1) self.assertTrue(view is None) - # View 'input2', index 0 - view = data_loader.view('input2', 0) + view = data_loader.view("input2", 0) self.assertEqual(view.count(), 2) - self.assertEqual(view.count('input1'), 2) - self.assertEqual(view.count('input2'), 1) + self.assertEqual(view.count("input1"), 2) + self.assertEqual(view.count("input2"), 1) self.assertEqual(view.data_index_start, 0) self.assertEqual(view.data_index_end, 1) @@ -428,26 +409,25 @@ class DataLoaderTest(DataLoaderBaseTest): self.assertTrue(data is not None) self.assertEqual(start, 0) self.assertEqual(end, 0) - self.assertEqual(data['input1'].value, 1000) - self.assertEqual(data['input2'].value, 2000) + self.assertEqual(data["input1"].value, 1000) + self.assertEqual(data["input2"].value, 2000) (data, start, end) = view[1] self.assertTrue(data is not None) self.assertEqual(start, 1) self.assertEqual(end, 1) - self.assertEqual(data['input1'].value, 1001) - self.assertEqual(data['input2'].value, 2000) + self.assertEqual(data["input1"].value, 1001) + self.assertEqual(data["input2"].value, 2000) (data, start, end) = view[2] self.assertTrue(data is None) - # View 'input2', index 1 - view = data_loader.view('input2', 1) + view = data_loader.view("input2", 1) self.assertEqual(view.count(), 2) - self.assertEqual(view.count('input1'), 2) - self.assertEqual(view.count('input2'), 1) + self.assertEqual(view.count("input1"), 2) + self.assertEqual(view.count("input2"), 1) self.assertEqual(view.data_index_start, 2) self.assertEqual(view.data_index_end, 3) @@ -459,51 +439,47 @@ class DataLoaderTest(DataLoaderBaseTest): self.assertTrue(data is not None) self.assertEqual(start, 2) self.assertEqual(end, 2) - self.assertEqual(data['input1'].value, 1002) - self.assertEqual(data['input2'].value, 2002) + self.assertEqual(data["input1"].value, 1002) + self.assertEqual(data["input2"].value, 
2002) (data, start, end) = view[1] self.assertTrue(data is not None) self.assertEqual(start, 3) self.assertEqual(end, 3) - self.assertEqual(data['input1'].value, 1003) - self.assertEqual(data['input2'].value, 2002) + self.assertEqual(data["input1"].value, 1003) + self.assertEqual(data["input2"].value, 2002) (data, start, end) = view[2] self.assertTrue(data is None) - # View 'input2', index 2 - view = data_loader.view('input2', 2) + view = data_loader.view("input2", 2) self.assertTrue(view is None) - def test_two_desynchronized_inputs(self): # Setup - self.writeData('input1', [(0, 2), (3, 3)], 1000) - self.writeData('input2', [(0, 1), (2, 3)], 2000) + self.writeData("input1", [(0, 2), (3, 3)], 1000) + self.writeData("input2", [(0, 1), (2, 3)], 2000) - data_loader = DataLoader('channel1') + data_loader = DataLoader("channel1") cached_file = CachedDataSource() - cached_file.setup(self.filenames['input1'], prefix) - data_loader.add('input1', cached_file) + cached_file.setup(self.filenames["input1"], prefix) + data_loader.add("input1", cached_file) cached_file = CachedDataSource() - cached_file.setup(self.filenames['input2'], prefix) - data_loader.add('input2', cached_file) - + cached_file.setup(self.filenames["input2"], prefix) + data_loader.add("input2", cached_file) # Global checks self.assertEqual(data_loader.count(), 3) - self.assertEqual(data_loader.count('input1'), 2) - self.assertEqual(data_loader.count('input2'), 2) + self.assertEqual(data_loader.count("input1"), 2) + self.assertEqual(data_loader.count("input2"), 2) self.assertEqual(data_loader.data_index_start, 0) self.assertEqual(data_loader.data_index_end, 3) - # Indexing (data, start, end) = data_loader[-1] self.assertTrue(data is None) @@ -512,38 +488,36 @@ class DataLoaderTest(DataLoaderBaseTest): self.assertTrue(data is not None) self.assertEqual(start, 0) self.assertEqual(end, 1) - self.assertEqual(data['input1'].value, 1000) - self.assertEqual(data['input2'].value, 2000) + 
self.assertEqual(data["input1"].value, 1000) + self.assertEqual(data["input2"].value, 2000) (data, start, end) = data_loader[1] self.assertTrue(data is not None) self.assertEqual(start, 2) self.assertEqual(end, 2) - self.assertEqual(data['input1'].value, 1000) - self.assertEqual(data['input2'].value, 2002) + self.assertEqual(data["input1"].value, 1000) + self.assertEqual(data["input2"].value, 2002) (data, start, end) = data_loader[2] self.assertTrue(data is not None) self.assertEqual(start, 3) self.assertEqual(end, 3) - self.assertEqual(data['input1'].value, 1003) - self.assertEqual(data['input2'].value, 2002) + self.assertEqual(data["input1"].value, 1003) + self.assertEqual(data["input2"].value, 2002) (data, start, end) = data_loader[3] self.assertTrue(data is None) - # View 'input1', index -1 - view = data_loader.view('input1', -1) + view = data_loader.view("input1", -1) self.assertTrue(view is None) - # View 'input1', index 0 - view = data_loader.view('input1', 0) + view = data_loader.view("input1", 0) self.assertEqual(view.count(), 2) - self.assertEqual(view.count('input1'), 1) - self.assertEqual(view.count('input2'), 2) + self.assertEqual(view.count("input1"), 1) + self.assertEqual(view.count("input2"), 2) self.assertEqual(view.data_index_start, 0) self.assertEqual(view.data_index_end, 2) @@ -555,26 +529,25 @@ class DataLoaderTest(DataLoaderBaseTest): self.assertTrue(data is not None) self.assertEqual(start, 0) self.assertEqual(end, 1) - self.assertEqual(data['input1'].value, 1000) - self.assertEqual(data['input2'].value, 2000) + self.assertEqual(data["input1"].value, 1000) + self.assertEqual(data["input2"].value, 2000) (data, start, end) = view[1] self.assertTrue(data is not None) self.assertEqual(start, 2) self.assertEqual(end, 2) - self.assertEqual(data['input1'].value, 1000) - self.assertEqual(data['input2'].value, 2002) + self.assertEqual(data["input1"].value, 1000) + self.assertEqual(data["input2"].value, 2002) (data, start, end) = view[2] 
self.assertTrue(data is None) - # View 'input1', index 1 - view = data_loader.view('input1', 1) + view = data_loader.view("input1", 1) self.assertEqual(view.count(), 1) - self.assertEqual(view.count('input1'), 1) - self.assertEqual(view.count('input2'), 1) + self.assertEqual(view.count("input1"), 1) + self.assertEqual(view.count("input2"), 1) self.assertEqual(view.data_index_start, 3) self.assertEqual(view.data_index_end, 3) @@ -586,29 +559,26 @@ class DataLoaderTest(DataLoaderBaseTest): self.assertTrue(data is not None) self.assertEqual(start, 3) self.assertEqual(end, 3) - self.assertEqual(data['input1'].value, 1003) - self.assertEqual(data['input2'].value, 2002) + self.assertEqual(data["input1"].value, 1003) + self.assertEqual(data["input2"].value, 2002) (data, start, end) = view[1] self.assertTrue(data is None) - # View 'input1', index 2 - view = data_loader.view('input1', 2) + view = data_loader.view("input1", 2) self.assertTrue(view is None) - # View 'input2', index -1 - view = data_loader.view('input2', -1) + view = data_loader.view("input2", -1) self.assertTrue(view is None) - # View 'input2', index 0 - view = data_loader.view('input2', 0) + view = data_loader.view("input2", 0) self.assertEqual(view.count(), 1) - self.assertEqual(view.count('input1'), 1) - self.assertEqual(view.count('input2'), 1) + self.assertEqual(view.count("input1"), 1) + self.assertEqual(view.count("input2"), 1) self.assertEqual(view.data_index_start, 0) self.assertEqual(view.data_index_end, 1) @@ -620,19 +590,18 @@ class DataLoaderTest(DataLoaderBaseTest): self.assertTrue(data is not None) self.assertEqual(start, 0) self.assertEqual(end, 1) - self.assertEqual(data['input1'].value, 1000) - self.assertEqual(data['input2'].value, 2000) + self.assertEqual(data["input1"].value, 1000) + self.assertEqual(data["input2"].value, 2000) (data, start, end) = view[1] self.assertTrue(data is None) - # View 'input2', index 1 - view = data_loader.view('input2', 1) + view = data_loader.view("input2", 1) 
self.assertEqual(view.count(), 2) - self.assertEqual(view.count('input1'), 2) - self.assertEqual(view.count('input2'), 1) + self.assertEqual(view.count("input1"), 2) + self.assertEqual(view.count("input2"), 1) self.assertEqual(view.data_index_start, 2) self.assertEqual(view.data_index_end, 3) @@ -644,60 +613,73 @@ class DataLoaderTest(DataLoaderBaseTest): self.assertTrue(data is not None) self.assertEqual(start, 2) self.assertEqual(end, 2) - self.assertEqual(data['input1'].value, 1000) - self.assertEqual(data['input2'].value, 2002) + self.assertEqual(data["input1"].value, 1000) + self.assertEqual(data["input2"].value, 2002) (data, start, end) = view[1] self.assertTrue(data is not None) self.assertEqual(start, 3) self.assertEqual(end, 3) - self.assertEqual(data['input1'].value, 1003) - self.assertEqual(data['input2'].value, 2002) + self.assertEqual(data["input1"].value, 1003) + self.assertEqual(data["input2"].value, 2002) (data, start, end) = view[2] self.assertTrue(data is None) - # View 'input2', index 2 - view = data_loader.view('input2', 2) + view = data_loader.view("input2", 2) self.assertTrue(view is None) + def test_reset(self): + # Setup + input_name = "input1" + self.writeData(input_name, [(0, 0), (1, 1), (2, 2)], 1000) -#---------------------------------------------------------- + data_loader = DataLoader("channel1") + cached_file = CachedDataSource() + cached_file.setup(self.filenames[input_name], prefix) + data_loader.add(input_name, cached_file) + + _, _, _ = data_loader[0] + cached_source = data_loader.infos[input_name]["data_source"] + self.assertIsNotNone(cached_source.current_file_index) + self.assertIsNotNone(cached_source.current_file) + data_loader.reset() + self.assertIsNone(cached_source.current_file_index) + self.assertIsNone(cached_source.current_file) -class DataLoaderListTest(DataLoaderBaseTest): +# ---------------------------------------------------------- + + +class DataLoaderListTest(DataLoaderBaseTest): def test_creation(self): 
data_loaders = DataLoaderList() self.assertTrue(data_loaders.main_loader is None) self.assertEqual(len(data_loaders), 0) - def test_list_unkown_loader_retrieval(self): data_loaders = DataLoaderList() - self.assertTrue(data_loaders['unknown'] is None) - + self.assertTrue(data_loaders["unknown"] is None) def test_list_invalid_index_retrieval(self): data_loaders = DataLoaderList() self.assertTrue(data_loaders[10] is None) - def test_list_loader_of_unknown_input_retrieval(self): data_loaders = DataLoaderList() - self.assertTrue(data_loaders.loaderOf('unknown') is None) - + self.assertTrue(data_loaders.loaderOf("unknown") is None) def test_list_one_loader_one_input(self): - self.writeData('input1', [(0, 0), (1, 1), (2, 2)], 1000) + self.writeData("input1", [(0, 0), (1, 1), (2, 2)], 1000) - data_loader = DataLoader('channel1') + data_loader = DataLoader("channel1") cached_file = CachedDataSource() - cached_file.setup(self.filenames['input1'], prefix) - data_loader.add('input1', cached_file) + cached_file.setup(self.filenames["input1"], prefix) + data_loader.add("input1", cached_file) data_loaders = DataLoaderList() data_loaders.add(data_loader) @@ -705,25 +687,24 @@ class DataLoaderListTest(DataLoaderBaseTest): self.assertEqual(data_loaders.main_loader, data_loader) self.assertEqual(len(data_loaders), 1) - self.assertEqual(data_loaders['channel1'], data_loader) + self.assertEqual(data_loaders["channel1"], data_loader) self.assertEqual(data_loaders[0], data_loader) - self.assertEqual(data_loaders.loaderOf('input1'), data_loader) - + self.assertEqual(data_loaders.loaderOf("input1"), data_loader) def test_list_one_loader_two_inputs(self): - self.writeData('input1', [(0, 0), (1, 1), (2, 2)], 1000) - self.writeData('input2', [(0, 2)], 2000) + self.writeData("input1", [(0, 0), (1, 1), (2, 2)], 1000) + self.writeData("input2", [(0, 2)], 2000) - data_loader = DataLoader('channel1') + data_loader = DataLoader("channel1") cached_file = CachedDataSource() - 
cached_file.setup(self.filenames['input1'], prefix) - data_loader.add('input1', cached_file) + cached_file.setup(self.filenames["input1"], prefix) + data_loader.add("input1", cached_file) cached_file = CachedDataSource() - cached_file.setup(self.filenames['input2'], prefix) - data_loader.add('input2', cached_file) + cached_file.setup(self.filenames["input2"], prefix) + data_loader.add("input2", cached_file) data_loaders = DataLoaderList() data_loaders.add(data_loader) @@ -731,33 +712,32 @@ class DataLoaderListTest(DataLoaderBaseTest): self.assertEqual(data_loaders.main_loader, data_loader) self.assertEqual(len(data_loaders), 1) - self.assertEqual(data_loaders['channel1'], data_loader) + self.assertEqual(data_loaders["channel1"], data_loader) self.assertEqual(data_loaders[0], data_loader) - self.assertEqual(data_loaders.loaderOf('input1'), data_loader) - self.assertEqual(data_loaders.loaderOf('input2'), data_loader) - + self.assertEqual(data_loaders.loaderOf("input1"), data_loader) + self.assertEqual(data_loaders.loaderOf("input2"), data_loader) def test_list_two_loaders_three_inputs(self): - self.writeData('input1', [(0, 0), (1, 1), (2, 2)], 1000) - self.writeData('input2', [(0, 2)], 2000) - self.writeData('input3', [(0, 1), (2, 2)], 3000) + self.writeData("input1", [(0, 0), (1, 1), (2, 2)], 1000) + self.writeData("input2", [(0, 2)], 2000) + self.writeData("input3", [(0, 1), (2, 2)], 3000) - data_loader1 = DataLoader('channel1') + data_loader1 = DataLoader("channel1") cached_file = CachedDataSource() - cached_file.setup(self.filenames['input1'], prefix) - data_loader1.add('input1', cached_file) + cached_file.setup(self.filenames["input1"], prefix) + data_loader1.add("input1", cached_file) cached_file = CachedDataSource() - cached_file.setup(self.filenames['input2'], prefix) - data_loader1.add('input2', cached_file) + cached_file.setup(self.filenames["input2"], prefix) + data_loader1.add("input2", cached_file) - data_loader2 = DataLoader('channel2') + data_loader2 = 
DataLoader("channel2") cached_file = CachedDataSource() - cached_file.setup(self.filenames['input3'], prefix) - data_loader2.add('input3', cached_file) + cached_file.setup(self.filenames["input3"], prefix) + data_loader2.add("input3", cached_file) data_loaders = DataLoaderList() data_loaders.add(data_loader1) @@ -766,12 +746,12 @@ class DataLoaderListTest(DataLoaderBaseTest): self.assertEqual(data_loaders.main_loader, data_loader1) self.assertEqual(len(data_loaders), 2) - self.assertEqual(data_loaders['channel1'], data_loader1) - self.assertEqual(data_loaders['channel2'], data_loader2) + self.assertEqual(data_loaders["channel1"], data_loader1) + self.assertEqual(data_loaders["channel2"], data_loader2) self.assertEqual(data_loaders[0], data_loader1) self.assertEqual(data_loaders[1], data_loader2) - self.assertEqual(data_loaders.loaderOf('input1'), data_loader1) - self.assertEqual(data_loaders.loaderOf('input2'), data_loader1) - self.assertEqual(data_loaders.loaderOf('input3'), data_loader2) + self.assertEqual(data_loaders.loaderOf("input1"), data_loader1) + self.assertEqual(data_loaders.loaderOf("input2"), data_loader1) + self.assertEqual(data_loaders.loaderOf("input3"), data_loader2)