Skip to content
Snippets Groups Projects
Commit 761414be authored by Amir MOHAMMADI's avatar Amir MOHAMMADI
Browse files

[factor_analysis] make sure state changes work through dask as well

parent 5f49f8e2
Branches
No related tags found
1 merge request!53Factor Analysis on pure python
Pipeline #60443 failed
...@@ -571,6 +571,7 @@ class FactorAnalysisBase(BaseEstimator): ...@@ -571,6 +571,7 @@ class FactorAnalysisBase(BaseEstimator):
self._U = U_c.reshape( self._U = U_c.reshape(
self.ubm.n_gaussians * self.feature_dimension, self.r_U self.ubm.n_gaussians * self.feature_dimension, self.r_U
) )
return self._U
def _compute_uprod(self): def _compute_uprod(self):
""" """
...@@ -1403,7 +1404,7 @@ class ISVMachine(FactorAnalysisBase): ...@@ -1403,7 +1404,7 @@ class ISVMachine(FactorAnalysisBase):
acc_U_A2 = [acc[1] for acc in acc_U_A1_acc_U_A2_list] acc_U_A2 = [acc[1] for acc in acc_U_A1_acc_U_A2_list]
acc_U_A2 = reduce_iadd(acc_U_A2) acc_U_A2 = reduce_iadd(acc_U_A2)
self.update_U(acc_U_A1, acc_U_A2) return self.update_U(acc_U_A1, acc_U_A2)
def fit_using_stats(self, X, y): def fit_using_stats(self, X, y):
""" """
...@@ -1444,7 +1445,7 @@ class ISVMachine(FactorAnalysisBase): ...@@ -1444,7 +1445,7 @@ class ISVMachine(FactorAnalysisBase):
for xx, yy in zip(X, y) for xx, yy in zip(X, y)
] ]
delayed_em_step = dask.delayed(self.m_step)(e_step_output) delayed_em_step = dask.delayed(self.m_step)(e_step_output)
dask.compute(delayed_em_step) self._U = dask.compute(delayed_em_step)[0]
else: else:
e_step_output = self.e_step( e_step_output = self.e_step(
X=X, X=X,
...@@ -1724,6 +1725,7 @@ class JFAMachine(FactorAnalysisBase): ...@@ -1724,6 +1725,7 @@ class JFAMachine(FactorAnalysisBase):
self._V = V_c.reshape( self._V = V_c.reshape(
(self.ubm.n_gaussians * self.feature_dimension, self.r_V) (self.ubm.n_gaussians * self.feature_dimension, self.r_V)
) )
return self._V
def finalize_v(self, X, y, n_samples_per_class, n_acc, f_acc): def finalize_v(self, X, y, n_samples_per_class, n_acc, f_acc):
""" """
...@@ -1849,7 +1851,7 @@ class JFAMachine(FactorAnalysisBase): ...@@ -1849,7 +1851,7 @@ class JFAMachine(FactorAnalysisBase):
acc_U_A1 = reduce_iadd(acc_U_A1) acc_U_A1 = reduce_iadd(acc_U_A1)
acc_U_A2 = reduce_iadd(acc_U_A2) acc_U_A2 = reduce_iadd(acc_U_A2)
self.update_U(acc_U_A1, acc_U_A2) return self.update_U(acc_U_A1, acc_U_A2)
def finalize_u( def finalize_u(
self, self,
...@@ -1984,6 +1986,7 @@ class JFAMachine(FactorAnalysisBase): ...@@ -1984,6 +1986,7 @@ class JFAMachine(FactorAnalysisBase):
acc_D_A2 = reduce_iadd(acc_D_A2) acc_D_A2 = reduce_iadd(acc_D_A2)
self._D = acc_D_A2 / acc_D_A1 self._D = acc_D_A2 / acc_D_A1
return self._D
def enroll_using_stats(self, X, iterations=1): def enroll_using_stats(self, X, iterations=1):
""" """
...@@ -2119,7 +2122,7 @@ class JFAMachine(FactorAnalysisBase): ...@@ -2119,7 +2122,7 @@ class JFAMachine(FactorAnalysisBase):
for xx, yy in zip(X, y) for xx, yy in zip(X, y)
] ]
delayed_em_step = dask.delayed(self.m_step_v)(e_step_output) delayed_em_step = dask.delayed(self.m_step_v)(e_step_output)
dask.compute(delayed_em_step) self._V = dask.compute(delayed_em_step)[0]
else: else:
e_step_output = self.e_step_v( e_step_output = self.e_step_v(
X=X, X=X,
...@@ -2151,7 +2154,7 @@ class JFAMachine(FactorAnalysisBase): ...@@ -2151,7 +2154,7 @@ class JFAMachine(FactorAnalysisBase):
for xx, yy in zip(X, y) for xx, yy in zip(X, y)
] ]
delayed_em_step = dask.delayed(self.m_step_u)(e_step_output) delayed_em_step = dask.delayed(self.m_step_u)(e_step_output)
dask.compute(delayed_em_step) self._U = dask.compute(delayed_em_step)[0]
else: else:
e_step_output = self.e_step_u( e_step_output = self.e_step_u(
X=X, X=X,
...@@ -2182,7 +2185,7 @@ class JFAMachine(FactorAnalysisBase): ...@@ -2182,7 +2185,7 @@ class JFAMachine(FactorAnalysisBase):
for xx, yy in zip(X, y) for xx, yy in zip(X, y)
] ]
delayed_em_step = dask.delayed(self.m_step_d)(e_step_output) delayed_em_step = dask.delayed(self.m_step_d)(e_step_output)
dask.compute(delayed_em_step) self._D = dask.compute(delayed_em_step)[0]
else: else:
e_step_output = self.e_step_d( e_step_output = self.e_step_d(
X=X, X=X,
......
...@@ -556,8 +556,8 @@ def test_ISV_JFA_fit(): ...@@ -556,8 +556,8 @@ def test_ISV_JFA_fit():
test_attr = "V" test_attr = "V"
err_msg = f"Test failed with prior={prior} and machine_type={machine_type} and transform={transform}" err_msg = f"Test failed with prior={prior} and machine_type={machine_type} and transform={transform}"
# with multiprocess_dask_client(): with multiprocess_dask_client():
machine.fit(data, labels) machine.fit(data, labels)
arr = getattr(machine, test_attr) arr = getattr(machine, test_attr)
np.testing.assert_allclose( np.testing.assert_allclose(
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment