From 761414be8efbdf5c01a7735c4f59d31d0d8dce25 Mon Sep 17 00:00:00 2001 From: Amir MOHAMMADI <amir.mohammadi@idiap.ch> Date: Fri, 22 Apr 2022 12:30:48 +0200 Subject: [PATCH] [factor_analysis] make sure state changes work through dask as well --- bob/learn/em/factor_analysis.py | 15 +++++++++------ bob/learn/em/test/test_factor_analysis.py | 4 ++-- 2 files changed, 11 insertions(+), 8 deletions(-) diff --git a/bob/learn/em/factor_analysis.py b/bob/learn/em/factor_analysis.py index 6dd609e..6e8fdef 100644 --- a/bob/learn/em/factor_analysis.py +++ b/bob/learn/em/factor_analysis.py @@ -571,6 +571,7 @@ class FactorAnalysisBase(BaseEstimator): self._U = U_c.reshape( self.ubm.n_gaussians * self.feature_dimension, self.r_U ) + return self._U def _compute_uprod(self): """ @@ -1403,7 +1404,7 @@ class ISVMachine(FactorAnalysisBase): acc_U_A2 = [acc[1] for acc in acc_U_A1_acc_U_A2_list] acc_U_A2 = reduce_iadd(acc_U_A2) - self.update_U(acc_U_A1, acc_U_A2) + return self.update_U(acc_U_A1, acc_U_A2) def fit_using_stats(self, X, y): """ @@ -1444,7 +1445,7 @@ class ISVMachine(FactorAnalysisBase): for xx, yy in zip(X, y) ] delayed_em_step = dask.delayed(self.m_step)(e_step_output) - dask.compute(delayed_em_step) + self._U = dask.compute(delayed_em_step)[0] else: e_step_output = self.e_step( X=X, @@ -1724,6 +1725,7 @@ class JFAMachine(FactorAnalysisBase): self._V = V_c.reshape( (self.ubm.n_gaussians * self.feature_dimension, self.r_V) ) + return self._V def finalize_v(self, X, y, n_samples_per_class, n_acc, f_acc): """ @@ -1849,7 +1851,7 @@ class JFAMachine(FactorAnalysisBase): acc_U_A1 = reduce_iadd(acc_U_A1) acc_U_A2 = reduce_iadd(acc_U_A2) - self.update_U(acc_U_A1, acc_U_A2) + return self.update_U(acc_U_A1, acc_U_A2) def finalize_u( self, @@ -1984,6 +1986,7 @@ class JFAMachine(FactorAnalysisBase): acc_D_A2 = reduce_iadd(acc_D_A2) self._D = acc_D_A2 / acc_D_A1 + return self._D def enroll_using_stats(self, X, iterations=1): """ @@ -2119,7 +2122,7 @@ class JFAMachine(FactorAnalysisBase): for xx, yy in zip(X, y) ] delayed_em_step = dask.delayed(self.m_step_v)(e_step_output) - dask.compute(delayed_em_step) + self._V = dask.compute(delayed_em_step)[0] else: e_step_output = self.e_step_v( X=X, @@ -2151,7 +2154,7 @@ class JFAMachine(FactorAnalysisBase): for xx, yy in zip(X, y) ] delayed_em_step = dask.delayed(self.m_step_u)(e_step_output) - dask.compute(delayed_em_step) + self._U = dask.compute(delayed_em_step)[0] else: e_step_output = self.e_step_u( X=X, @@ -2182,7 +2185,7 @@ class JFAMachine(FactorAnalysisBase): for xx, yy in zip(X, y) ] delayed_em_step = dask.delayed(self.m_step_d)(e_step_output) - dask.compute(delayed_em_step) + self._D = dask.compute(delayed_em_step)[0] else: e_step_output = self.e_step_d( X=X, diff --git a/bob/learn/em/test/test_factor_analysis.py b/bob/learn/em/test/test_factor_analysis.py index 8d82d0c..524ad18 100644 --- a/bob/learn/em/test/test_factor_analysis.py +++ b/bob/learn/em/test/test_factor_analysis.py @@ -556,8 +556,8 @@ def test_ISV_JFA_fit(): test_attr = "V" err_msg = f"Test failed with prior={prior} and machine_type={machine_type} and transform={transform}" - # with multiprocess_dask_client(): - machine.fit(data, labels) + with multiprocess_dask_client(): + machine.fit(data, labels) arr = getattr(machine, test_attr) np.testing.assert_allclose( -- GitLab