From 06d5d38686ec8f838af0929569597fb858f4a924 Mon Sep 17 00:00:00 2001 From: Amir MOHAMMADI <amir.mohammadi@idiap.ch> Date: Tue, 22 Mar 2022 18:22:57 +0100 Subject: [PATCH] close dask client to see if it prevents tests hanging --- bob/learn/em/test/test_gmm.py | 27 +++++++++++++++++++-------- 1 file changed, 19 insertions(+), 8 deletions(-) diff --git a/bob/learn/em/test/test_gmm.py b/bob/learn/em/test/test_gmm.py index 3ad780a..b62da1a 100644 --- a/bob/learn/em/test/test_gmm.py +++ b/bob/learn/em/test/test_gmm.py @@ -8,6 +8,7 @@ """Tests the GMM machine and the GMMStats container """ +import contextlib import os import tempfile @@ -32,6 +33,16 @@ def load_array(filename): return np.squeeze(array) +@contextlib.contextmanager +def multiprocess_dask_client(): + try: + client = Client() + with client.as_current(): + yield client + finally: + client.close() + + def test_GMMStats(): # Test a GMMStats # Initializes a GMMStats @@ -467,7 +478,7 @@ def test_gmm_kmeans_parallel_init(): data = np.array( [[1.5, 1], [1, 1.5], [-1, 0.5], [-1.5, 0], [2, 2], [2.5, 2.5]] ) - with Client().as_current(): + with multiprocess_dask_client(): for transform in (to_numpy, to_dask_array): data = transform(data) machine = machine.fit(data) @@ -778,7 +789,7 @@ def test_gmm_ML_1(): def test_gmm_ML_2(): - """Trains a GMMMachine with ML_GMMTrainer; compares to a reference""" + # Trains a GMMMachine with ML_GMMTrainer; compares to a reference ar = load_array( resource_filename("bob.learn.em", "data/dataNormalized.hdf5") ) @@ -829,7 +840,7 @@ def test_gmm_ML_2(): def test_gmm_MAP_1(): - """Train a GMMMachine with MAP_GMMTrainer""" + # Train a GMMMachine with MAP_GMMTrainer ar = load_array( resource_filename("bob.learn.em", "data/faithful.torch3_f64.hdf5") ) @@ -875,7 +886,7 @@ def test_gmm_MAP_1(): def test_gmm_MAP_2(): - """Train a GMMMachine with MAP_GMMTrainer and compare with matlab reference""" + # Train a GMMMachine with MAP_GMMTrainer and compare with matlab reference data = load_array(resource_filename("bob.learn.em", "data/data.hdf5")) data = data.reshape((1, -1)) # make a 2D array out of it @@ -915,7 +926,7 @@ def test_gmm_MAP_2(): def test_gmm_MAP_3(): - """Train a GMMMachine with MAP_GMMTrainer; compares to old reference""" + # Train a GMMMachine with MAP_GMMTrainer; compares to old reference ar = load_array(resource_filename("bob.learn.em", "data/dataforMAP.hdf5")) # Initialize GMMMachine @@ -976,7 +987,7 @@ def test_gmm_MAP_3(): def test_gmm_test(): - """Tests a GMMMachine by computing scores against a model and comparing to a reference""" + # Tests a GMMMachine by computing scores against a model and comparing to a reference ar = load_array(resource_filename("bob.learn.em", "data/dataforMAP.hdf5")) @@ -1006,7 +1017,7 @@ def test_gmm_test(): def test_gmm_ML_dask(): - """Trains a GMMMachine with dask array data; compares to a reference""" + # Trains a GMMMachine with dask array data; compares to a reference ar = da.array( load_array( @@ -1061,7 +1072,7 @@ def test_gmm_ML_dask(): def test_gmm_MAP_dask(): - """Test a GMMMachine for MAP with a dask array as data.""" + # Test a GMMMachine for MAP with a dask array as data. ar = da.array( load_array(resource_filename("bob.learn.em", "data/dataforMAP.hdf5")) ) -- GitLab