From 761414be8efbdf5c01a7735c4f59d31d0d8dce25 Mon Sep 17 00:00:00 2001
From: Amir MOHAMMADI <amir.mohammadi@idiap.ch>
Date: Fri, 22 Apr 2022 12:30:48 +0200
Subject: [PATCH] [factor_analysis] make sure state changes work through dask
 as well

---
 bob/learn/em/factor_analysis.py           | 15 +++++++++------
 bob/learn/em/test/test_factor_analysis.py |  4 ++--
 2 files changed, 11 insertions(+), 8 deletions(-)

diff --git a/bob/learn/em/factor_analysis.py b/bob/learn/em/factor_analysis.py
index 6dd609e..6e8fdef 100644
--- a/bob/learn/em/factor_analysis.py
+++ b/bob/learn/em/factor_analysis.py
@@ -571,6 +571,7 @@ class FactorAnalysisBase(BaseEstimator):
         self._U = U_c.reshape(
             self.ubm.n_gaussians * self.feature_dimension, self.r_U
         )
+        return self._U
 
     def _compute_uprod(self):
         """
@@ -1403,7 +1404,7 @@ class ISVMachine(FactorAnalysisBase):
         acc_U_A2 = [acc[1] for acc in acc_U_A1_acc_U_A2_list]
         acc_U_A2 = reduce_iadd(acc_U_A2)
 
-        self.update_U(acc_U_A1, acc_U_A2)
+        return self.update_U(acc_U_A1, acc_U_A2)
 
     def fit_using_stats(self, X, y):
         """
@@ -1444,7 +1445,7 @@ class ISVMachine(FactorAnalysisBase):
                     for xx, yy in zip(X, y)
                 ]
                 delayed_em_step = dask.delayed(self.m_step)(e_step_output)
-                dask.compute(delayed_em_step)
+                self._U = dask.compute(delayed_em_step)[0]
             else:
                 e_step_output = self.e_step(
                     X=X,
@@ -1724,6 +1725,7 @@ class JFAMachine(FactorAnalysisBase):
         self._V = V_c.reshape(
             (self.ubm.n_gaussians * self.feature_dimension, self.r_V)
         )
+        return self._V
 
     def finalize_v(self, X, y, n_samples_per_class, n_acc, f_acc):
         """
@@ -1849,7 +1851,7 @@ class JFAMachine(FactorAnalysisBase):
         acc_U_A1 = reduce_iadd(acc_U_A1)
         acc_U_A2 = reduce_iadd(acc_U_A2)
 
-        self.update_U(acc_U_A1, acc_U_A2)
+        return self.update_U(acc_U_A1, acc_U_A2)
 
     def finalize_u(
         self,
@@ -1984,6 +1986,7 @@ class JFAMachine(FactorAnalysisBase):
         acc_D_A2 = reduce_iadd(acc_D_A2)
 
         self._D = acc_D_A2 / acc_D_A1
+        return self._D
 
     def enroll_using_stats(self, X, iterations=1):
         """
@@ -2119,7 +2122,7 @@ class JFAMachine(FactorAnalysisBase):
                     for xx, yy in zip(X, y)
                 ]
                 delayed_em_step = dask.delayed(self.m_step_v)(e_step_output)
-                dask.compute(delayed_em_step)
+                self._V = dask.compute(delayed_em_step)[0]
             else:
                 e_step_output = self.e_step_v(
                     X=X,
@@ -2151,7 +2154,7 @@ class JFAMachine(FactorAnalysisBase):
                     for xx, yy in zip(X, y)
                 ]
                 delayed_em_step = dask.delayed(self.m_step_u)(e_step_output)
-                dask.compute(delayed_em_step)
+                self._U = dask.compute(delayed_em_step)[0]
             else:
                 e_step_output = self.e_step_u(
                     X=X,
@@ -2182,7 +2185,7 @@ class JFAMachine(FactorAnalysisBase):
                     for xx, yy in zip(X, y)
                 ]
                 delayed_em_step = dask.delayed(self.m_step_d)(e_step_output)
-                dask.compute(delayed_em_step)
+                self._D = dask.compute(delayed_em_step)[0]
             else:
                 e_step_output = self.e_step_d(
                     X=X,
diff --git a/bob/learn/em/test/test_factor_analysis.py b/bob/learn/em/test/test_factor_analysis.py
index 8d82d0c..524ad18 100644
--- a/bob/learn/em/test/test_factor_analysis.py
+++ b/bob/learn/em/test/test_factor_analysis.py
@@ -556,8 +556,8 @@ def test_ISV_JFA_fit():
                 test_attr = "V"
 
             err_msg = f"Test failed with prior={prior} and machine_type={machine_type} and transform={transform}"
-            # with multiprocess_dask_client():
-            machine.fit(data, labels)
+            with multiprocess_dask_client():
+                machine.fit(data, labels)
 
             arr = getattr(machine, test_attr)
             np.testing.assert_allclose(
-- 
GitLab