From 06d5d38686ec8f838af0929569597fb858f4a924 Mon Sep 17 00:00:00 2001
From: Amir MOHAMMADI <amir.mohammadi@idiap.ch>
Date: Tue, 22 Mar 2022 18:22:57 +0100
Subject: [PATCH] close dask client to see if it prevents tests hanging

---
 bob/learn/em/test/test_gmm.py | 27 +++++++++++++++++++--------
 1 file changed, 19 insertions(+), 8 deletions(-)

diff --git a/bob/learn/em/test/test_gmm.py b/bob/learn/em/test/test_gmm.py
index 3ad780a..b62da1a 100644
--- a/bob/learn/em/test/test_gmm.py
+++ b/bob/learn/em/test/test_gmm.py
@@ -8,6 +8,7 @@
 """Tests the GMM machine and the GMMStats container
 """
 
+import contextlib
 import os
 import tempfile
 
@@ -32,6 +33,16 @@ def load_array(filename):
     return np.squeeze(array)
 
 
+@contextlib.contextmanager
+def multiprocess_dask_client():
+    try:
+        client = Client()
+        with client.as_current():
+            yield client
+    finally:
+        client.close()
+
+
 def test_GMMStats():
     # Test a GMMStats
     # Initializes a GMMStats
@@ -467,7 +478,7 @@ def test_gmm_kmeans_parallel_init():
     data = np.array(
         [[1.5, 1], [1, 1.5], [-1, 0.5], [-1.5, 0], [2, 2], [2.5, 2.5]]
     )
-    with Client().as_current():
+    with multiprocess_dask_client():
         for transform in (to_numpy, to_dask_array):
             data = transform(data)
             machine = machine.fit(data)
@@ -778,7 +789,7 @@ def test_gmm_ML_1():
 
 
 def test_gmm_ML_2():
-    """Trains a GMMMachine with ML_GMMTrainer; compares to a reference"""
+    # Trains a GMMMachine with ML_GMMTrainer; compares to a reference
     ar = load_array(
         resource_filename("bob.learn.em", "data/dataNormalized.hdf5")
     )
@@ -829,7 +840,7 @@ def test_gmm_ML_2():
 
 
 def test_gmm_MAP_1():
-    """Train a GMMMachine with MAP_GMMTrainer"""
+    # Train a GMMMachine with MAP_GMMTrainer
     ar = load_array(
         resource_filename("bob.learn.em", "data/faithful.torch3_f64.hdf5")
     )
@@ -875,7 +886,7 @@ def test_gmm_MAP_1():
 
 
 def test_gmm_MAP_2():
-    """Train a GMMMachine with MAP_GMMTrainer and compare with matlab reference"""
+    # Train a GMMMachine with MAP_GMMTrainer and compare with matlab reference
 
     data = load_array(resource_filename("bob.learn.em", "data/data.hdf5"))
     data = data.reshape((1, -1))  # make a 2D array out of it
@@ -915,7 +926,7 @@ def test_gmm_MAP_2():
 
 
 def test_gmm_MAP_3():
-    """Train a GMMMachine with MAP_GMMTrainer; compares to old reference"""
+    # Train a GMMMachine with MAP_GMMTrainer; compares to old reference
     ar = load_array(resource_filename("bob.learn.em", "data/dataforMAP.hdf5"))
 
     # Initialize GMMMachine
@@ -976,7 +987,7 @@ def test_gmm_MAP_3():
 
 
 def test_gmm_test():
-    """Tests a GMMMachine by computing scores against a model and comparing to a reference"""
+    # Tests a GMMMachine by computing scores against a model and comparing to a reference
 
     ar = load_array(resource_filename("bob.learn.em", "data/dataforMAP.hdf5"))
 
@@ -1006,7 +1017,7 @@ def test_gmm_test():
 
 
 def test_gmm_ML_dask():
-    """Trains a GMMMachine with dask array data; compares to a reference"""
+    # Trains a GMMMachine with dask array data; compares to a reference
 
     ar = da.array(
         load_array(
@@ -1061,7 +1072,7 @@ def test_gmm_ML_dask():
 
 
 def test_gmm_MAP_dask():
-    """Test a GMMMachine for MAP with a dask array as data."""
+    # Test a GMMMachine for MAP with a dask array as data.
     ar = da.array(
         load_array(resource_filename("bob.learn.em", "data/dataforMAP.hdf5"))
     )
-- 
GitLab