From a1314e17d883279613a9465e946c9076944c8ee4 Mon Sep 17 00:00:00 2001
From: Philip ABBET <philip.abbet@idiap.ch>
Date: Fri, 20 Oct 2017 11:05:59 +0200
Subject: [PATCH] [unittests] Refactoring of the tests for the 'worker.py'
 script

---
 beat/core/test/test_worker.py | 367 ++++++++++++++++++----------------
 1 file changed, 199 insertions(+), 168 deletions(-)

diff --git a/beat/core/test/test_worker.py b/beat/core/test/test_worker.py
index e5b73e58..1e5779c6 100644
--- a/beat/core/test/test_worker.py
+++ b/beat/core/test/test_worker.py
@@ -37,6 +37,7 @@ import unittest
 import simplejson
 import multiprocessing
 from time import time
+from time import sleep
 
 from ..scripts import worker
 from ..worker import WorkerController
@@ -56,35 +57,35 @@ WORKER2 = 'worker2'
 
 
 CONFIGURATION1 = {
-  'queue': 'queue',
-  'inputs': {
-    'in_data': {
-      'set': 'double',
-      'protocol': 'double',
-      'database': 'integers_db/1',
-      'output': 'a',
-      'path': 'ec/89/e5/6e161d2cb012ef6ac8acf59bf453a6328766f90dc9baba9eb14ea23c55',
-      'endpoint': 'a',
-      'hash': 'ec89e56e161d2cb012ef6ac8acf59bf453a6328766f90dc9baba9eb14ea23c55',
-      'channel': 'integers'
-    }
-  },
-  'algorithm': 'user/integers_echo/1',
-  'parameters': {},
-  'environment': {
-    'name': 'Python 2.7',
-    'version': '1.2.0'
-  },
-  'outputs': {
-    'out_data': {
-      'path': '20/61/b6/2df3c3bedd5366f4a625c5d87ffbf5a26007c46c456e9abf21b46c6681',
-      'endpoint': 'out_data',
-      'hash': '2061b62df3c3bedd5366f4a625c5d87ffbf5a26007c46c456e9abf21b46c6681',
-      'channel': 'integers'
-    }
-  },
-  'nb_slots': 1,
-  'channel': 'integers'
+    'queue': 'queue',
+    'inputs': {
+        'in_data': {
+            'set': 'double',
+            'protocol': 'double',
+            'database': 'integers_db/1',
+            'output': 'a',
+            'path': 'ec/89/e5/6e161d2cb012ef6ac8acf59bf453a6328766f90dc9baba9eb14ea23c55',
+            'endpoint': 'a',
+            'hash': 'ec89e56e161d2cb012ef6ac8acf59bf453a6328766f90dc9baba9eb14ea23c55',
+            'channel': 'integers'
+        }
+    },
+    'algorithm': 'user/integers_echo/1',
+    'parameters': {},
+    'environment': {
+        'name': 'Python 2.7',
+        'version': '1.2.0'
+    },
+    'outputs': {
+        'out_data': {
+            'path': '20/61/b6/2df3c3bedd5366f4a625c5d87ffbf5a26007c46c456e9abf21b46c6681',
+            'endpoint': 'out_data',
+            'hash': '2061b62df3c3bedd5366f4a625c5d87ffbf5a26007c46c456e9abf21b46c6681',
+            'channel': 'integers'
+        }
+    },
+    'nb_slots': 1,
+    'channel': 'integers'
 }
 
 
@@ -92,35 +93,35 @@ CONFIGURATION1 = {
 
 
 CONFIGURATION2 = {
-  'queue': 'queue',
-  'inputs': {
-    'in_data': {
-      'set': 'double',
-      'protocol': 'double',
-      'database': 'integers_db/1',
-      'output': 'a',
-      'path': 'ec/89/e5/6e161d2cb012ef6ac8acf59bf453a6328766f90dc9baba9eb14ea23c55',
-      'endpoint': 'a',
-      'hash': 'ec89e56e161d2cb012ef6ac8acf59bf453a6328766f90dc9baba9eb14ea23c55',
-      'channel': 'integers'
-    }
-  },
-  'algorithm': 'user/integers_echo/1',
-  'parameters': {},
-  'environment': {
-    'name': 'Python 2.7',
-    'version': '1.2.0'
-  },
-  'outputs': {
-    'out_data': {
-      'path': '40/61/b6/2df3c3bedd5366f4a625c5d87ffbf5a26007c46c456e9abf21b46c6681',
-      'endpoint': 'out_data',
-      'hash': '4061b62df3c3bedd5366f4a625c5d87ffbf5a26007c46c456e9abf21b46c6681',
-      'channel': 'integers'
-    }
-  },
-  'nb_slots': 1,
-  'channel': 'integers'
+    'queue': 'queue',
+    'inputs': {
+        'in_data': {
+            'set': 'double',
+            'protocol': 'double',
+            'database': 'integers_db/1',
+            'output': 'a',
+            'path': 'ec/89/e5/6e161d2cb012ef6ac8acf59bf453a6328766f90dc9baba9eb14ea23c55',
+            'endpoint': 'a',
+            'hash': 'ec89e56e161d2cb012ef6ac8acf59bf453a6328766f90dc9baba9eb14ea23c55',
+            'channel': 'integers'
+        }
+    },
+    'algorithm': 'user/integers_echo/1',
+    'parameters': {},
+    'environment': {
+        'name': 'Python 2.7',
+        'version': '1.2.0'
+    },
+    'outputs': {
+        'out_data': {
+            'path': '40/61/b6/2df3c3bedd5366f4a625c5d87ffbf5a26007c46c456e9abf21b46c6681',
+            'endpoint': 'out_data',
+            'hash': '4061b62df3c3bedd5366f4a625c5d87ffbf5a26007c46c456e9abf21b46c6681',
+            'channel': 'integers'
+        }
+    },
+    'nb_slots': 1,
+    'channel': 'integers'
 }
 
 
@@ -144,65 +145,173 @@ class WorkerProcess(multiprocessing.Process):
 #----------------------------------------------------------
 
 
-class TestOneWorker(unittest.TestCase):
+class TestWorkerBase(unittest.TestCase):
 
     def __init__(self, methodName='runTest'):
-        super(TestOneWorker, self).__init__(methodName)
+        super(TestWorkerBase, self).__init__(methodName)
         self.controller = None
-        self.worker_process = None
+        self.connected_workers = []
+        self.worker_processes = {}
         self.docker = False
 
 
     def setUp(self):
-        self.tearDown()   # In case another test failed badly during its setUp()
+        self.shutdown_everything()  # In case another test failed badly during its setUp()
+
+
+    def tearDown(self):
+        self.shutdown_everything()
+
+
+    def shutdown_everything(self):
+        for name in list(self.worker_processes.keys()):
+            self.stop_worker(name)
+
+        self.worker_processes = {}
+        self.connected_workers = []
 
-        connected_workers = []
+        self.stop_controller()
+
+
+    def start_controller(self, port=None):
+        self.connected_workers = []
 
         def onWorkerReady(name):
-            connected_workers.append(name)
+            self.connected_workers.append(name)
+
+        def onWorkerGone(name):
+            self.connected_workers.remove(name)
 
         self.controller = WorkerController(
-          '127.0.0.1',
-          port=None,
-          callbacks=dict(
-            onWorkerReady = onWorkerReady,
-          )
+            '127.0.0.1',
+            port=port,
+            callbacks=dict(
+                onWorkerReady = onWorkerReady,
+                onWorkerGone = onWorkerGone,
+            )
         )
 
+        self.controller.process(100)
+
+
+    def stop_controller(self):
+        if self.controller is not None:
+            self.controller.destroy()
+            self.controller = None
+
+
+    def start_worker(self, name, address=None):
         args = [
           '--prefix=%s' % prefix,
           '--cache=%s' % tmp_prefix,
-          '--name=%s' % WORKER1,
+          '--name=%s' % name,
           # '-vvv',
-          self.controller.address,
+          self.controller.address if address is None else address,
         ]
 
         if self.docker:
             args.insert(3, '--docker')
 
-        self.worker_process = WorkerProcess(multiprocessing.Queue(), args)
-        self.worker_process.start()
+        worker_process = WorkerProcess(multiprocessing.Queue(), args)
+        worker_process.start()
+
+        worker_process.queue.get()
+
+        self.worker_processes[name] = worker_process
 
-        self.worker_process.queue.get()
 
+    def stop_worker(self, name):
+        if name in self.worker_processes:
+            self.worker_processes[name].terminate()
+            self.worker_processes[name].join()
+            del self.worker_processes[name]
+
+
+    def wait_for_worker_connection(self, name):
         start = time()
-        while len(self.controller.workers) == 0:
+        while name not in self.connected_workers:
             self.assertTrue(self.controller.process(100) is None)
             self.assertTrue(time() - start < 10)  # Exit after 10 seconds
 
+        self.assertTrue(name in self.controller.workers)
+
+
+    def wait_for_worker_disconnection(self, name):
+        start = time()
+        while name in self.connected_workers:
+            self.assertTrue(self.controller.process(100) is None)
+            self.assertTrue(time() - start < 10)  # Exit after 10 seconds
+
+        self.assertTrue(name not in self.controller.workers)
+
+
+#----------------------------------------------------------
+
+
+class TestConnection(TestWorkerBase):
+
+    def test_worker_connection(self):
+        self.start_controller()
+
+        self.assertEqual(len(self.connected_workers), 0)
+        self.assertEqual(len(self.controller.workers), 0)
+
+        self.start_worker(WORKER1)
+
+        self.wait_for_worker_connection(WORKER1)
+
+        self.assertEqual(len(self.connected_workers), 1)
         self.assertEqual(len(self.controller.workers), 1)
-        self.assertTrue(WORKER1 in self.controller.workers)
 
-        self.assertEqual(len(connected_workers), 1)
-        self.assertTrue(WORKER1 in connected_workers)
 
+    def test_worker_disconnection(self):
+        self.start_controller()
+        self.start_worker(WORKER1)
 
-    def tearDown(self):
-        self.stop_worker()
+        self.wait_for_worker_connection(WORKER1)
 
-        if self.controller is not None:
-            self.controller.destroy()
-            self.controller = None
+        self.stop_worker(WORKER1)
+
+        self.wait_for_worker_disconnection(WORKER1)
+
+
+    def test_two_workers_connection(self):
+        self.start_controller()
+
+        self.assertEqual(len(self.connected_workers), 0)
+        self.assertEqual(len(self.controller.workers), 0)
+
+        self.start_worker(WORKER1)
+        self.start_worker(WORKER2)
+
+        self.wait_for_worker_connection(WORKER1)
+        self.wait_for_worker_connection(WORKER2)
+
+        self.assertEqual(len(self.connected_workers), 2)
+        self.assertEqual(len(self.controller.workers), 2)
+
+
+    def test_scheduler_last(self):
+        self.start_worker(WORKER1, address='tcp://127.0.0.1:51000')
+        sleep(1)
+
+        self.start_controller(port=51000)
+
+        self.wait_for_worker_connection(WORKER1)
+
+
+#----------------------------------------------------------
+
+
+class TestOneWorker(TestWorkerBase):
+
+
+    def setUp(self):
+        super(TestOneWorker, self).setUp()
+
+        self.start_controller()
+        self.start_worker(WORKER1)
+        self.wait_for_worker_connection(WORKER1)
 
 
     def _wait(self, max=100):
@@ -216,13 +325,6 @@ class TestOneWorker(unittest.TestCase):
         return message
 
 
-    def stop_worker(self):
-        if self.worker_process is not None:
-            self.worker_process.terminate()
-            self.worker_process.join()
-            self.worker_process = None
-
-
     def _check_done(self, message, expected_worker, expected_job_id):
         self.assertTrue(message is not None)
 
@@ -291,21 +393,6 @@ class TestOneWorker(unittest.TestCase):
         self.assertTrue(len(data) > 0)
 
 
-    def test_worker_shutdown(self):
-        did_shutdown = True
-
-        def onWorkerGone(name):
-            did_shutdown = (name == WORKER1)
-
-        self.controller.callbacks.onWorkerGone = onWorkerGone
-
-        self.stop_worker()
-
-        self.assertTrue(self.controller.process(2000) is None)
-        self.assertEqual(len(self.controller.workers), 0)
-        self.assertTrue(did_shutdown)
-
-
     def test_multiple_jobs(self):
         config = dict(CONFIGURATION1)
         config['algorithm'] = 'user/integers_echo_slow/1'
@@ -374,74 +461,18 @@ class TestOneWorkerDocker(TestOneWorker):
 #----------------------------------------------------------
 
 
-class TestTwoWorkers(unittest.TestCase):
-
-    def __init__(self, methodName='runTest'):
-        super(TestTwoWorkers, self).__init__(methodName)
-        self.controller = None
-        self.worker_processes = None
-        self.docker = False
-
+class TestTwoWorkers(TestWorkerBase):
 
     def setUp(self):
         self.tearDown()   # In case another test failed badly during its setUp()
 
-        connected_workers = []
-
-        def onWorkerReady(name):
-            connected_workers.append(name)
-
-        self.controller = WorkerController(
-          '127.0.0.1',
-          port=None,
-          callbacks=dict(
-            onWorkerReady = onWorkerReady,
-          )
-        )
-
-        self.worker_processes = []
-
-        for name in [ WORKER1, WORKER2 ]:
-            args = [
-                '--prefix=%s' % prefix,
-                '--cache=%s' % tmp_prefix,
-                '--name=%s' % name,
-                self.controller.address,
-            ]
-
-            if self.docker:
-                args.insert(3, '--docker')
-
-            worker_process = WorkerProcess(multiprocessing.Queue(), args)
-            worker_process.start()
-
-            worker_process.queue.get()
-
-            self.worker_processes.append(worker_process)
+        super(TestTwoWorkers, self).setUp()
 
-        start = time()
-        while len(self.controller.workers) < 2:
-            self.assertTrue(self.controller.process(100) is None)
-            self.assertTrue(time() - start < 10)  # Exit after 10 seconds
-
-        self.assertTrue(WORKER1 in self.controller.workers)
-        self.assertTrue(WORKER2 in self.controller.workers)
-
-        self.assertEqual(len(connected_workers), 2)
-        self.assertTrue(WORKER1 in connected_workers)
-        self.assertTrue(WORKER2 in connected_workers)
-
-
-    def tearDown(self):
-        if self.worker_processes is not None:
-            for worker_process in self.worker_processes:
-                worker_process.terminate()
-                worker_process.join()
-            self.worker_processes = None
-
-        if self.controller is not None:
-            self.controller.destroy()
-            self.controller = None
+        self.start_controller()
+        self.start_worker(WORKER1)
+        self.start_worker(WORKER2)
+        self.wait_for_worker_connection(WORKER1)
+        self.wait_for_worker_connection(WORKER2)
 
 
     def _test_success_one_worker(self, worker_name):
-- 
GitLab