diff --git a/beat/core/test/test_worker.py b/beat/core/test/test_worker.py index 2d8196a1cae0b4de6654810138fb61b3ac4f918d..6249c4707d0edfa4a082ade6f6d253e22043686c 100644 --- a/beat/core/test/test_worker.py +++ b/beat/core/test/test_worker.py @@ -46,6 +46,9 @@ import queue from time import time from time import sleep +from ddt import ddt +from ddt import idata + from ..scripts import worker from ..worker import WorkerController from ..database import Database @@ -66,6 +69,8 @@ PORT = find_free_port() # ---------------------------------------------------------- +DATABASES = [f"integers_db/{i}" for i in range(1, 3)] + CONFIGURATION1 = { "queue": "queue", @@ -133,6 +138,19 @@ CONFIGURATION2 = { # ---------------------------------------------------------- +def prepare_database(db_name): + CONFIGURATION1["inputs"]["in"]["database"] = db_name + CONFIGURATION2["inputs"]["in"]["database"] = db_name + + for _, input_cfg in CONFIGURATION1["inputs"].items(): + database = Database(prefix, input_cfg["database"]) + view = database.view(input_cfg["protocol"], input_cfg["set"]) + view.index(os.path.join(tmp_prefix, input_cfg["path"])) + + +# ---------------------------------------------------------- + + class ControllerProcess(multiprocessing.Process): def __init__(self, queue): super(ControllerProcess, self).__init__() @@ -271,12 +289,6 @@ class TestWorkerBase(unittest.TestCase): self.assertTrue(name not in self.controller.workers) - def prepare_databases(self, configuration): - for _, input_cfg in configuration["inputs"].items(): - database = Database(prefix, input_cfg["database"]) - view = database.view(input_cfg["protocol"], input_cfg["set"]) - view.index(os.path.join(tmp_prefix, input_cfg["path"])) - # ---------------------------------------------------------- @@ -361,6 +373,7 @@ class TestConnection(TestWorkerBase): # ---------------------------------------------------------- +@ddt class TestOneWorker(TestWorkerBase): def setUp(self): super(TestOneWorker, self).setUp() @@ -370,8 +383,6 @@ class TestOneWorker(TestWorkerBase): self.wait_for_worker_connection(WORKER1) - self.prepare_databases(CONFIGURATION1) - def _wait(self, max=200): message = None nb = 0 @@ -398,14 +409,20 @@ class TestOneWorker(TestWorkerBase): self.assertEqual(result["status"], 0) - def test_success(self): + @idata(DATABASES) + def test_success(self, db_name): + prepare_database(db_name) + self.controller.execute(WORKER1, 1, CONFIGURATION1) message = self._wait() self._check_done(message, WORKER1, 1) - def test_processing_error(self): + @idata(DATABASES) + def test_processing_error(self, db_name): + prepare_database(db_name) + config = dict(CONFIGURATION1) config["algorithm"] = "legacy/process_crash/1" @@ -424,7 +441,10 @@ class TestOneWorker(TestWorkerBase): self.assertEqual(result["status"], 1) self.assertTrue("a = b" in result["user_error"]) - def test_error_unknown_algorithm(self): + @idata(DATABASES) + def test_error_unknown_algorithm(self, db_name): + prepare_database(db_name) + config = dict(CONFIGURATION1) config["algorithm"] = "user/unknown/1" @@ -439,7 +459,10 @@ class TestOneWorker(TestWorkerBase): self.assertEqual(job_id, 1) self.assertTrue(len(data) > 0) - def test_error_syntax_error(self): + @idata(DATABASES) + def test_error_syntax_error(self, db_name): + prepare_database(db_name) + config = dict(CONFIGURATION1) config["algorithm"] = "legacy/syntax_error/1" @@ -454,7 +477,10 @@ class TestOneWorker(TestWorkerBase): self.assertEqual(job_id, 1) self.assertTrue(len(data) > 0) - def test_multiple_jobs(self): + @idata(DATABASES) + def test_multiple_jobs(self, db_name): + prepare_database(db_name) + config = dict(CONFIGURATION1) config["algorithm"] = "user/integers_echo_slow/1" @@ -467,7 +493,10 @@ class TestOneWorker(TestWorkerBase): message = self._wait() self._check_done(message, WORKER1, 2) - def test_reuse(self): + @idata(DATABASES) + def test_reuse(self, db_name): + prepare_database(db_name) + self.controller.execute(WORKER1, 1, CONFIGURATION1) message = self._wait() self._check_done(message, WORKER1, 1) @@ -476,7 +505,10 @@ class TestOneWorker(TestWorkerBase): message = self._wait() self._check_done(message, WORKER1, 2) - def test_cancel(self): + @idata(DATABASES) + def test_cancel(self, db_name): + prepare_database(db_name) + config = dict(CONFIGURATION1) config["algorithm"] = "user/integers_echo_slow/1" @@ -508,6 +540,7 @@ class TestOneWorker(TestWorkerBase): # ---------------------------------------------------------- +@ddt class TestTwoWorkers(TestWorkerBase): def setUp(self): self.tearDown() # In case another test failed badly during its setUp() @@ -520,7 +553,9 @@ class TestTwoWorkers(TestWorkerBase): self.wait_for_worker_connection(WORKER1) self.wait_for_worker_connection(WORKER2) - def _test_success_one_worker(self, worker_name): + def _test_success_one_worker(self, worker_name, db_name): + prepare_database(db_name) + self.controller.execute(worker_name, 1, CONFIGURATION1) message = None @@ -538,13 +573,16 @@ class TestTwoWorkers(TestWorkerBase): self.assertEqual(result["status"], 0) - def test_success_worker1(self): - self._test_success_one_worker(WORKER1) + @idata(DATABASES) + def test_success_worker1(self, db_name): + self._test_success_one_worker(WORKER1, db_name) - def test_success_worker2(self): - self._test_success_one_worker(WORKER2) + @idata(DATABASES) + def test_success_worker2(self, db_name): + self._test_success_one_worker(WORKER2, db_name) - def test_success_both_workers(self): + @idata(DATABASES) + def test_success_both_workers(self, db_name): def _check(worker, status, job_id, data): self.assertEqual(status, WorkerController.DONE) @@ -557,6 +595,8 @@ class TestTwoWorkers(TestWorkerBase): result = json.loads(data[0]) self.assertEqual(result["status"], 0) + prepare_database(db_name) + self.controller.execute(WORKER1, 1, CONFIGURATION1) self.controller.execute(WORKER2, 2, CONFIGURATION2)