diff --git a/bob/learn/em/factor_analysis.py b/bob/learn/em/factor_analysis.py
index 6dd609e3f7b0aa71c5999d8752128d1d24c895ba..6e8fdefbec11a2f6fcf9176a80bfa2a95901c9ad 100644
--- a/bob/learn/em/factor_analysis.py
+++ b/bob/learn/em/factor_analysis.py
@@ -571,6 +571,7 @@ class FactorAnalysisBase(BaseEstimator):
         self._U = U_c.reshape(
             self.ubm.n_gaussians * self.feature_dimension, self.r_U
         )
+        return self._U
 
     def _compute_uprod(self):
         """
@@ -1403,7 +1404,7 @@ class ISVMachine(FactorAnalysisBase):
         acc_U_A2 = [acc[1] for acc in acc_U_A1_acc_U_A2_list]
         acc_U_A2 = reduce_iadd(acc_U_A2)
 
-        self.update_U(acc_U_A1, acc_U_A2)
+        return self.update_U(acc_U_A1, acc_U_A2)
 
     def fit_using_stats(self, X, y):
         """
@@ -1444,7 +1445,7 @@ class ISVMachine(FactorAnalysisBase):
                     for xx, yy in zip(X, y)
                 ]
                 delayed_em_step = dask.delayed(self.m_step)(e_step_output)
-                dask.compute(delayed_em_step)
+                self._U = dask.compute(delayed_em_step)[0]
             else:
                 e_step_output = self.e_step(
                     X=X,
@@ -1724,6 +1725,7 @@ class JFAMachine(FactorAnalysisBase):
         self._V = V_c.reshape(
             (self.ubm.n_gaussians * self.feature_dimension, self.r_V)
         )
+        return self._V
 
     def finalize_v(self, X, y, n_samples_per_class, n_acc, f_acc):
         """
@@ -1849,7 +1851,7 @@ class JFAMachine(FactorAnalysisBase):
         acc_U_A1 = reduce_iadd(acc_U_A1)
         acc_U_A2 = reduce_iadd(acc_U_A2)
 
-        self.update_U(acc_U_A1, acc_U_A2)
+        return self.update_U(acc_U_A1, acc_U_A2)
 
     def finalize_u(
         self,
@@ -1984,6 +1986,7 @@ class JFAMachine(FactorAnalysisBase):
         acc_D_A2 = reduce_iadd(acc_D_A2)
 
         self._D = acc_D_A2 / acc_D_A1
+        return self._D
 
     def enroll_using_stats(self, X, iterations=1):
         """
@@ -2119,7 +2122,7 @@ class JFAMachine(FactorAnalysisBase):
                     for xx, yy in zip(X, y)
                 ]
                 delayed_em_step = dask.delayed(self.m_step_v)(e_step_output)
-                dask.compute(delayed_em_step)
+                self._V = dask.compute(delayed_em_step)[0]
             else:
                 e_step_output = self.e_step_v(
                     X=X,
@@ -2151,7 +2154,7 @@ class JFAMachine(FactorAnalysisBase):
                     for xx, yy in zip(X, y)
                 ]
                 delayed_em_step = dask.delayed(self.m_step_u)(e_step_output)
-                dask.compute(delayed_em_step)
+                self._U = dask.compute(delayed_em_step)[0]
             else:
                 e_step_output = self.e_step_u(
                     X=X,
@@ -2182,7 +2185,7 @@ class JFAMachine(FactorAnalysisBase):
                     for xx, yy in zip(X, y)
                 ]
                 delayed_em_step = dask.delayed(self.m_step_d)(e_step_output)
-                dask.compute(delayed_em_step)
+                self._D = dask.compute(delayed_em_step)[0]
             else:
                 e_step_output = self.e_step_d(
                     X=X,
diff --git a/bob/learn/em/test/test_factor_analysis.py b/bob/learn/em/test/test_factor_analysis.py
index 8d82d0cca8380af217f78b66db90ccf7490a6516..524ad18e42c504767cbe0ca80793352624f56a9c 100644
--- a/bob/learn/em/test/test_factor_analysis.py
+++ b/bob/learn/em/test/test_factor_analysis.py
@@ -556,8 +556,8 @@ def test_ISV_JFA_fit():
                 test_attr = "V"
 
             err_msg = f"Test failed with prior={prior} and machine_type={machine_type} and transform={transform}"
-            # with multiprocess_dask_client():
-            machine.fit(data, labels)
+            with multiprocess_dask_client():
+                machine.fit(data, labels)
 
             arr = getattr(machine, test_attr)
             np.testing.assert_allclose(