diff --git a/bob/learn/em/factor_analysis.py b/bob/learn/em/factor_analysis.py
index 1f2e8c38e68178ad3613a167b5125c656d101bca..cb5680af7674382d7b7718c339f90e29445c948c 100644
--- a/bob/learn/em/factor_analysis.py
+++ b/bob/learn/em/factor_analysis.py
@@ -4,12 +4,14 @@
 
 import logging
 
+import dask
 import numpy as np
 
 from sklearn.base import BaseEstimator
 from sklearn.utils.multiclass import unique_labels
 
 from .gmm import GMMMachine
+from .kmeans import check_and_persist_dask_input
 from .linear_scoring import linear_scoring
 
 logger = logging.getLogger(__name__)
@@ -204,12 +206,8 @@ class FactorAnalysisBase(BaseEstimator):
         """
 
         if self.ubm is None:
-            logger.info("FA: Creating a new GMMMachine.")
+            logger.info("FA: Creating a new GMMMachine and training it.")
             self.ubm = GMMMachine(**self.ubm_kwargs)
-
-        # Train the UBM if not already trained
-        if self.ubm._means is None:
-            logger.info(f"FA: Training the UBM with {self.ubm}.")
             self.ubm.fit(X)  # GMMMachine.fit takes non-labeled data
 
         # Initializing the state matrix
@@ -1153,8 +1151,14 @@ class FactorAnalysisBase(BaseEstimator):
         return self.score_using_stats(model, self.ubm.transform(data))
 
     def fit(self, X, y):
+        input_is_dask, X = check_and_persist_dask_input(X)
         self.initialize(X)
-        stats = [self.ubm.transform(xx) for xx in X]
+        if input_is_dask:
+            stats = [dask.delayed(self.ubm.transform)(xx) for xx in X]
+            stats = dask.compute(*stats)
+        else:
+            stats = [self.ubm.transform(xx) for xx in X]
+        del X  # we don't need to persist X anymore
         return self.fit_using_stats(stats, y)
 
 
@@ -1269,7 +1273,7 @@ class ISVMachine(FactorAnalysisBase):
         # TODO: Point of MAP-REDUCE
         n_acc, f_acc = self.initialize_using_stats(X, y)
         for i in range(self.em_iterations):
-            logger.info("U Training: Iteration %d", i)
+            logger.info("U Training: Iteration %d", i + 1)
             # TODO: Point of MAP-REDUCE
             acc_U_A1, acc_U_A2 = self.e_step(X, y, n_acc, f_acc)
             self.m_step(acc_U_A1, acc_U_A2)
@@ -1309,7 +1313,7 @@ class ISVMachine(FactorAnalysisBase):
         latent_x, _, latent_z = self.initialize_XYZ(y)
         latent_y = None
         for i in range(iterations):
-            logger.info("Enrollment: Iteration %d", i)
+            logger.info("Enrollment: Iteration %d", i + 1)
             latent_x = self.update_x(X, y, UProd, latent_x, latent_y, latent_z)
             latent_z = self.update_z(
                 X, y, latent_x, latent_y, latent_z, n_acc, f_acc
@@ -1738,7 +1742,7 @@ class JFAMachine(FactorAnalysisBase):
         latent_x, latent_y, latent_z = self.initialize_XYZ(y)
 
         for i in range(iterations):
-            logger.info("Enrollment: Iteration %d", i)
+            logger.info("Enrollment: Iteration %d", i + 1)
             latent_y = self.update_y(
                 X, y, VProd, latent_x, latent_y, latent_z, n_acc, f_acc
             )
@@ -1803,7 +1807,7 @@ class JFAMachine(FactorAnalysisBase):
 
         # Updating V
         for i in range(self.em_iterations):
-            logger.info("V Training: Iteration %d", i)
+            logger.info("V Training: Iteration %d", i + 1)
             # TODO: Point of MAP-REDUCE
             acc_V_A1, acc_V_A2 = self.e_step_v(X, y, n_acc, f_acc)
             self.m_step_v(acc_V_A1, acc_V_A2)
@@ -1811,7 +1815,7 @@ class JFAMachine(FactorAnalysisBase):
 
         # Updating U
         for i in range(self.em_iterations):
-            logger.info("U Training: Iteration %d", i)
+            logger.info("U Training: Iteration %d", i + 1)
             # TODO: Point of MAP-REDUCE
             acc_U_A1, acc_U_A2 = self.e_step_u(X, y, latent_y)
             self.m_step_u(acc_U_A1, acc_U_A2)
@@ -1820,7 +1824,7 @@ class JFAMachine(FactorAnalysisBase):
 
         # Updating D
         for i in range(self.em_iterations):
-            logger.info("D Training: Iteration %d", i)
+            logger.info("D Training: Iteration %d", i + 1)
             # TODO: Point of MAP-REDUCE
             acc_D_A1, acc_D_A2 = self.e_step_d(
                 X, y, latent_x, latent_y, n_acc, f_acc
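
Note on the `fit` rework above: the dask dispatch that used to live inside `GMMMachine.transform` (removed in the gmm.py hunk below) now happens here, with one delayed `ubm.transform` task per sample and a single `dask.compute` to materialize them. A minimal usage sketch of the new entry point; the toy data, chunking, and parameter values are illustrative only, not taken from this patch:

    import dask.array as da
    import numpy as np

    from bob.learn.em import ISVMachine

    rng = np.random.default_rng(10)
    X = rng.normal(size=(20, 3))  # 20 samples, 3 features (made-up data)
    y = [0] * 10 + [1] * 10

    # A dask array input takes the delayed branch inside fit();
    # a plain numpy array takes the eager list comprehension.
    X_dask = da.from_array(X, chunks=(10, 3))

    machine = ISVMachine(2, ubm_kwargs=dict(n_gaussians=2), em_iterations=5, seed=10)
    machine.fit(X_dask, y)  # UBM is created and trained here, then per-sample stats
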
diff --git a/bob/learn/em/gmm.py b/bob/learn/em/gmm.py
index 831d3fc64a0aa0ebbc44cc2c1ab59a3b757ffb3d..ce8581a1380da416936452b21dd6c3cfc9175f27 100644
--- a/bob/learn/em/gmm.py
+++ b/bob/learn/em/gmm.py
@@ -855,24 +855,7 @@ class GMMMachine(BaseEstimator):
 
     def transform(self, X, **kwargs):
         """Returns the statistics for `X`."""
-        input_is_dask, X = check_and_persist_dask_input(X)
-
-        if input_is_dask:
-            stats = [
-                dask.delayed(e_step)(
-                    data=xx,
-                    machine=self,
-                )
-                for xx in X
-            ]
-            stats = functools.reduce(operator.iadd, stats)
-            stats = stats.compute()
-        else:
-            stats = e_step(
-                data=X,
-                machine=self,
-            )
-        return stats
+        return e_step(data=X, machine=self)
 
     def _more_tags(self):
         return {
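
With this hunk, `transform` is a thin single-batch wrapper over `e_step`: it no longer inspects its input for dask collections, nor reduces the partial statistics with `functools.reduce(operator.iadd, ...)`. Parallelism is the caller's job, as in `FactorAnalysisBase.fit` above. A minimal sketch of that caller-side pattern, assuming a fitted `GMMMachine` named `ubm` and an iterable of sample arrays `chunks`:

    import dask

    # One delayed e_step per chunk, materialized in a single compute();
    # the result is a tuple of per-chunk GMMStats, not a reduced sum.
    stats = dask.compute(*[dask.delayed(ubm.transform)(xx) for xx in chunks])
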
diff --git a/bob/learn/em/test/test_factor_analysis.py b/bob/learn/em/test/test_factor_analysis.py
index ca1bd2a167684cbf8786cc0a10845770157a2118..3b9ca21dc871a81477b80644777f94b0a327f8af 100644
--- a/bob/learn/em/test/test_factor_analysis.py
+++ b/bob/learn/em/test/test_factor_analysis.py
@@ -9,6 +9,9 @@ import numpy as np
 
 from bob.learn.em import GMMMachine, GMMStats, ISVMachine, JFAMachine
 
+from .test_gmm import multiprocess_dask_client
+from .test_kmeans import to_dask_array, to_numpy
+
 # Define Training set and initial values for tests
 F1 = np.array(
     [
@@ -443,37 +446,34 @@ def test_ISVMachine():
     np.testing.assert_allclose(score, score_ref, atol=eps)
 
 
+def _create_ubm_prior(means):
+    # Creating a fake prior with 2 gaussians
+    prior_gmm = GMMMachine(2)
+    prior_gmm.means = means.copy()
+    # All nice and round diagonal covariance
+    prior_gmm.variances = np.ones((2, 3)) * 0.5
+    prior_gmm.weights = np.array([0.3, 0.7])
+    return prior_gmm
+
+
 def test_ISV_JFA_fit():
     np.random.seed(10)
     data_class1 = np.random.normal(0, 0.5, (10, 3))
     data_class2 = np.random.normal(-0.2, 0.2, (10, 3))
     data = np.concatenate([data_class1, data_class2], axis=0)
     labels = [0] * 10 + [1] * 10
-
-    # Creating a fake prior with 2 gaussians
-    prior_gmm = GMMMachine(2)
-    prior_gmm.means = np.vstack(
+    means = np.vstack(
         (np.random.normal(0, 0.5, (1, 3)), np.random.normal(1, 0.5, (1, 3)))
     )
-    # All nice and round diagonal covariance
-    prior_gmm.variances = np.ones((2, 3)) * 0.5
-    prior_gmm.weights = np.array([0.3, 0.7])
 
     for prior, machine_type, ref in [
         (
             None,
             "isv",
-            [
-                [0.02619036, 0.07607595],
-                [-0.02570657, -0.07451667],
-                [-0.0430513, -0.12514552],
-                [-0.09729266, -0.28582205],
-                [-0.01035388, -0.03041718],
-                [0.0733034, 0.21534741],
-            ],
+            0.0,
         ),
         (
-            prior_gmm,
+            True,
             "isv",
             [
                 [-0.02361267, 0.0157274],
@@ -497,16 +497,16 @@ def test_ISV_JFA_fit():
             None,
             "jfa",
             [
-                [-1.72285693e-01, 1.47171193e-01],
-                [-1.08402014e-01, 9.25999920e-02],
-                [1.55349449e-02, -1.32703786e-02],
-                [2.13389657e-04, -1.82283334e-04],
-                [1.84127661e-05, -1.57286929e-05],
-                [-1.90492196e-04, 1.62723691e-04],
+                [-0.04687046, -0.06302095],
+                [-0.04380423, -0.05889816],
+                [-0.02083793, -0.0280182],
+                [-0.04728452, -0.06357768],
+                [-0.04371283, -0.05877527],
+                [-0.0203464, -0.0273573],
             ],
         ),
         (
-            prior_gmm,
+            True,
             "jfa",
             [
                 [6.54547662e-03, 1.98699266e-04],
@@ -526,35 +526,40 @@ def test_ISV_JFA_fit():
         ),
     ]:
         ref = np.asarray(ref)
-        ubm_kwargs = dict(n_gaussians=2) if prior is None else None
 
         # Doing the training
-        if machine_type == "isv":
-            machine = ISVMachine(
-                2,
-                ubm=prior,
+        for transform in (to_numpy, to_dask_array):
+            data, labels = transform(data, labels)
+
+            if prior is None:
+                ubm = None
+                ubm_kwargs = dict(n_gaussians=2, ubm=_create_ubm_prior(means))
+            else:
+                ubm = _create_ubm_prior(means)
+                ubm_kwargs = None
+
+            machine_kwargs = dict(
+                ubm=ubm,
                 relevance_factor=4,
                 em_iterations=50,
                 ubm_kwargs=ubm_kwargs,
                 seed=10,
             )
-            test_attr = "U"
-        else:
-            machine = JFAMachine(
-                2,
-                2,
-                ubm=prior,
-                relevance_factor=4,
-                em_iterations=50,
-                ubm_kwargs=ubm_kwargs,
-                seed=10,
+
+            if machine_type == "isv":
+                machine = ISVMachine(2, **machine_kwargs)
+                test_attr = "U"
+            else:
+                machine = JFAMachine(2, 2, **machine_kwargs)
+                test_attr = "V"
+
+            with multiprocess_dask_client():
+                machine.fit(data, labels)
+
+            arr = getattr(machine, test_attr)
+            np.testing.assert_allclose(
+                arr,
+                ref,
+                atol=1e-7,
+                err_msg=f"Test failed with prior={prior} and machine_type={machine_type} and transform={transform}",
             )
-            test_attr = "V"
-        machine.fit(data, labels)
-        arr = getattr(machine, test_attr)
-        np.testing.assert_allclose(
-            arr,
-            ref,
-            atol=1e-7,
-            err_msg=f"Test failed with prior={prior} and machine_type={machine_type}",
-        )
diff --git a/doc/guide.rst b/doc/guide.rst
index a94857f099b9235e5e946b2510ae13c9141f13f5..537d6be46d7bf8e290ce543a2096b8a1f94153c0 100644
--- a/doc/guide.rst
+++ b/doc/guide.rst
@@ -273,7 +273,7 @@ prior GMM.
     >>> # Training a GMM with 2 Gaussians of dimension 3
     >>> prior_gmm = bob.learn.em.GMMMachine(2).fit(data)
     >>> # Creating the container
-    >>> gmm_stats = prior_gmm.acc_statistics(data)
+    >>> gmm_stats = prior_gmm.transform(data)
     >>> # Printing the responsibilities
     >>> print(gmm_stats.n/gmm_stats.t)
      [0.6  0.4]
@@ -335,7 +335,7 @@ The snippet bellow shows how to:
    >>> # The input the the ISV Training is the statistics of the GMM
    >>> # Here we are creating a GMMStats for each datapoints, which is NOT usual,
    >>> # but it is done for testing purposes
-   >>> gmm_stats = [ubm.acc_statistics(x[np.newaxis]) for x in X]
+   >>> gmm_stats = [ubm.transform(x[np.newaxis]) for x in X]
 
    >>> # Finally doing the ISV training with U subspace with dimension of 2
    >>> isv_machine = bob.learn.em.ISVMachine(ubm, r_U=2).fit(gmm_stats, y)
@@ -410,7 +410,7 @@ such session variability model.
    >>> # The input the the JFA Training is the statistics of the GMM
    >>> # Here we are creating a GMMStats for each datapoints, which is NOT usual,
    >>> # but it is done for testing purposes
-   >>> gmm_stats = [ubm.acc_statistics(x[np.newaxis]) for x in X]
+   >>> gmm_stats = [ubm.transform(x[np.newaxis]) for x in X]
 
    >>> # Finally doing the JFA training with U and V subspaces with dimension of 2
    >>> jfa_machine = bob.learn.em.JFAMachine(ubm, r_U=2, r_V=2).fit(gmm_stats, y)
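
With the `fit` rework in factor_analysis.py above, the precomputed `gmm_stats` shown in these guide snippets are optional: `fit` also accepts raw features and calls `ubm.transform` per sample internally (plot_ISV.py below now relies on this). A hedged sketch, reusing the `ubm`, `X`, and `y` names from the guide:

    # Equivalent to building gmm_stats by hand and calling
    # fit_using_stats(gmm_stats, y); fit() transforms X itself.
    isv_machine = bob.learn.em.ISVMachine(2, ubm=ubm).fit(X, y)
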
diff --git a/doc/plot/plot_ISV.py b/doc/plot/plot_ISV.py
index cc7b5593057db9d6802e2b47598eab32c9749fc2..f3a9c39617ddfbce6970078fc45dae99e66b750a 100644
--- a/doc/plot/plot_ISV.py
+++ b/doc/plot/plot_ISV.py
@@ -8,37 +8,6 @@ import bob.learn.em
 np.random.seed(2)  # FIXING A SEED
 
 
-def isv_train(features, ubm):
-    """
-    Train U matrix
-
-    **Parameters**
-      features: List of :py:class:`bob.learn.em.GMMStats` organized by class
-
-      n_gaussians: UBM (:py:class:`bob.learn.em.GMMMachine`)
-
-    """
-
-    stats = []
-    for user in features:
-        user_stats = []
-        for f in user:
-            s = bob.learn.em.GMMStats(ubm.shape[0], ubm.shape[1])
-            ubm.acc_statistics(f, s)
-            user_stats.append(s)
-        stats.append(user_stats)
-
-    relevance_factor = 4
-    subspace_dimension_of_u = 1
-
-    isvbase = bob.learn.em.ISVBase(ubm, subspace_dimension_of_u)
-    trainer = bob.learn.em.ISVTrainer(relevance_factor)
-    # trainer.rng = bob.core.random.mt19937(int(self.init_seed))
-    bob.learn.em.train(trainer, isvbase, stats, max_iterations=50)
-
-    return isvbase
-
-
 # GENERATING DATA
 iris_data = load_iris()
 X = np.column_stack((iris_data.data[:, 0], iris_data.data[:, 3]))
@@ -72,19 +41,12 @@ ubm.variances = np.array(
 
 ubm.weights = np.array([0.36, 0.36, 0.28])
 
-gmm_stats = [ubm.acc_statistics(x[np.newaxis]) for x in X]
-isv_machine = bob.learn.em.ISVMachine(ubm, r_U, em_iterations=50)
+isv_machine = bob.learn.em.ISVMachine(r_U, em_iterations=50, ubm=ubm)
 isv_machine.U = np.array(
     [[-0.150035, -0.44441, -1.67812, 2.47621, -0.52885, 0.659141]]
 ).T
 
-isv_machine = isv_machine.fit(gmm_stats, y)
-
-# gmm_stats = [ubm.acc_statistics(x) for x in [setosa, versicolor, virginica]]
-# isv_machine = bob.learn.em.ISVMachine(ubm, r_U).fit(gmm_stats, [0, 1, 2])
-
-
-# isvbase = isv_train([setosa, versicolor, virginica], ubm)
+isv_machine = isv_machine.fit(X, y)
 
 # Variability direction
 u0 = isv_machine.U[0:2, 0] / np.linalg.norm(isv_machine.U[0:2, 0])
diff --git a/doc/plot/plot_JFA.py b/doc/plot/plot_JFA.py
index 5c19474a49322178bd6f9ffb8bc3f1788f684355..6cfa4bc1051ce917448db37b9d44d7721a41385e 100644
--- a/doc/plot/plot_JFA.py
+++ b/doc/plot/plot_JFA.py
@@ -24,7 +24,7 @@ def isv_train(features, ubm):
         user_stats = []
         for f in user:
             s = bob.learn.em.GMMStats(ubm.shape[0], ubm.shape[1])
-            ubm.acc_statistics(f, s)
+            s = ubm.transform(f)
             user_stats.append(s)
         stats.append(user_stats)
 
@@ -75,7 +75,7 @@ ubm.variances = np.array(
 ubm.weights = np.array([0.36, 0.36, 0.28])
 # .fit(X)
 
-gmm_stats = [ubm.acc_statistics(x[np.newaxis]) for x in X]
+gmm_stats = [ubm.transform(x[np.newaxis]) for x in X]
 jfa_machine = bob.learn.em.JFAMachine(ubm, r_U, r_V, em_iterations=50)
 
 # Initializing with old bob initialization
@@ -94,7 +94,7 @@ jfa_machine = jfa_machine.fit(gmm_stats, y)
 
 # .fit(gmm_stats, y)
 
-# gmm_stats = [ubm.acc_statistics(x) for x in [setosa, versicolor, virginica]]
+# gmm_stats = [ubm.transform(x) for x in [setosa, versicolor, virginica]]
 # jfa_machine = bob.learn.em.JFAMachine(ubm, r_U, r_V).fit(gmm_stats, [0, 1, 2])
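
One migration caveat for the accumulate-into-container form still used by plot_JFA.py's `isv_train` helper: the old `acc_statistics(f, s)` filled a pre-allocated `GMMStats` in place, while the new `transform` only takes the data and returns the statistics, so the result has to be bound (as corrected in the hunk above). A minimal sketch, assuming a fitted `GMMMachine` named `ubm` and a feature array `f`:

    # Old API:
    #   s = bob.learn.em.GMMStats(ubm.shape[0], ubm.shape[1])
    #   ubm.acc_statistics(f, s)
    # New API: no output container; bind the returned GMMStats.
    s = ubm.transform(f)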