From 6b84f7ddbb9524e4d2b7e88ad9d350175355dfb0 Mon Sep 17 00:00:00 2001
From: Amir MOHAMMADI <amir.mohammadi@idiap.ch>
Date: Wed, 16 Mar 2022 11:18:32 +0100
Subject: [PATCH] [gmm] fix gmm training when dask workers are in a different
 process

When dask workers run in separate processes, the GMMMachine passed to the
delayed m_step call is serialized and only the worker's copy is updated, so
the trained parameters never reach the caller. Return the machine from
m_step alongside the average log likelihood, and copy its weights, means,
and variances back onto self after dask.compute. The parallel k-means
initialization test now runs against a distributed Client with both numpy
and dask arrays.
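
A minimal standalone sketch of the behaviour being fixed (not part of this
patch; Box and mutate are made-up names used only for illustration):

    import dask
    from dask.distributed import Client

    class Box:
        value = 0

    def mutate(box):
        box.value = 1   # updates the worker's copy only
        return box      # returning it is the only way to send the update back

    if __name__ == "__main__":
        with Client():  # local cluster with worker processes by default
            b = Box()
            dask.compute(dask.delayed(mutate)(b))
            print(b.value)      # still 0: the in-place update was lost
            (new_b,) = dask.compute(dask.delayed(mutate)(b))
            print(new_b.value)  # 1: read the result from the returned copy
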
---
 bob/learn/em/gmm.py           |  8 +++++---
 bob/learn/em/test/test_gmm.py | 26 +++++++++++++++++++-------
 2 files changed, 24 insertions(+), 10 deletions(-)

diff --git a/bob/learn/em/gmm.py b/bob/learn/em/gmm.py
index 436d169..88efe81 100644
--- a/bob/learn/em/gmm.py
+++ b/bob/learn/em/gmm.py
@@ -119,7 +119,7 @@ def m_step(
         relevance_factor=map_relevance_factor,
     )
     average_output = float(statistics.log_likelihood / statistics.t)
-    return average_output
+    return machine, average_output
 
 
 class GMMStats:
@@ -888,9 +888,11 @@ class GMMMachine(BaseEstimator):
                     )
                     for xx in X
                 ]
-                average_output = dask.compute(
+                new_machine, average_output = dask.compute(
                     dask.delayed(m_step_func)(self, stats)
                 )[0]
+                for attr in ["weights", "means", "variances"]:
+                    setattr(self, attr, getattr(new_machine, attr))
             else:
                 stats = [
                     e_step(
@@ -902,7 +904,7 @@ class GMMMachine(BaseEstimator):
                         log_weights=self.log_weights,
                     )
                 ]
-                average_output = m_step_func(self, stats)
+                _, average_output = m_step_func(self, stats)
 
             logger.debug(f"log likelihood = {average_output}")
             if step > 1:
diff --git a/bob/learn/em/test/test_gmm.py b/bob/learn/em/test/test_gmm.py
index 7eda1c6..3ad780a 100644
--- a/bob/learn/em/test/test_gmm.py
+++ b/bob/learn/em/test/test_gmm.py
@@ -16,11 +16,14 @@ from copy import deepcopy
 import dask.array as da
 import numpy as np
 
+from dask.distributed import Client
 from h5py import File as HDF5File
 from pkg_resources import resource_filename
 
 from bob.learn.em import GMMMachine, GMMStats, KMeansMachine
 
+from .test_kmeans import to_dask_array, to_numpy
+
 
 def load_array(filename):
     with HDF5File(filename, "r") as f:
@@ -464,13 +467,22 @@ def test_gmm_kmeans_parallel_init():
     data = np.array(
         [[1.5, 1], [1, 1.5], [-1, 0.5], [-1.5, 0], [2, 2], [2.5, 2.5]]
     )
-    machine = machine.fit(data)
-    expected_means = np.array([[1.25, 1.25], [-1.25, 0.25], [2.25, 2.25]])
-    expected_variances = np.array(
-        [[1 / 16, 1 / 16], [1 / 16, 1 / 16], [1 / 16, 1 / 16]]
-    )
-    np.testing.assert_almost_equal(machine.means, expected_means, decimal=3)
-    np.testing.assert_almost_equal(machine.variances, expected_variances)
+    with Client().as_current():
+        for transform in (to_numpy, to_dask_array):
+            data = transform(data)
+            machine = machine.fit(data)
+            expected_means = np.array(
+                [[1.25, 1.25], [-1.25, 0.25], [2.25, 2.25]]
+            )
+            expected_variances = np.array(
+                [[1 / 16, 1 / 16], [1 / 16, 1 / 16], [1 / 16, 1 / 16]]
+            )
+            np.testing.assert_almost_equal(
+                machine.means, expected_means, decimal=3
+            )
+            np.testing.assert_almost_equal(
+                machine.variances, expected_variances
+            )
 
 
 def test_likelihood():
-- 
GitLab