diff --git a/bob/learn/em/factor_analysis.py b/bob/learn/em/factor_analysis.py
index 09d158cb13dfbb9beec1d582dfd88225ea98cc17..6dd609e3f7b0aa71c5999d8752128d1d24c895ba 100644
--- a/bob/learn/em/factor_analysis.py
+++ b/bob/learn/em/factor_analysis.py
@@ -9,30 +9,52 @@ import operator
 import dask
 import numpy as np
 
+from dask.array.core import Array
 from dask.delayed import Delayed
 from sklearn.base import BaseEstimator
 from sklearn.utils.multiclass import unique_labels
 
 from .gmm import GMMMachine
 from .linear_scoring import linear_scoring
-from .utils import array_to_delayed_list, check_and_persist_dask_input
+from .utils import check_and_persist_dask_input
 
 logger = logging.getLogger(__name__)
 
 
-def is_input_delayed(X):
+def is_input_dask_nested(X):
     """
-    Check if the input is a list of dask delayed objects.
+    Check if the input is a dask Delayed or Array object, or a (nested)
+    list or tuple of such objects. Only the first element of nested
+    inputs is inspected.
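+
+    A minimal doctest-style sketch of the expected behavior:
+
+    >>> is_input_dask_nested(dask.delayed(sum)([1, 2]))
+    True
+    >>> is_input_dask_nested([[np.arange(3)]])
+    False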
     """
     if isinstance(X, (list, tuple)):
-        return is_input_delayed(X[0])
+        return is_input_dask_nested(X[0])
 
-    if isinstance(X, Delayed):
+    if isinstance(X, (Delayed, Array)):
         return True
     else:
         return False
 
 
+def check_dask_input_samples_per_class(X, y):
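+    """Returns (input_is_dask, n_classes, n_samples_per_class) for X and y.
+
+    With dask input, ``X`` and ``y`` are expected to be nested lists (one
+    entry per class); otherwise ``y`` is a flat array of class labels.
+    """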
+    input_is_dask = is_input_dask_nested(X)
+
+    if input_is_dask:
+        n_classes = len(y)
+        n_samples_per_class = [len(yy) for yy in y]
+    else:
+        unique_labels_y = unique_labels(y)
+        n_classes = len(unique_labels_y)
+        n_samples_per_class = [
+            sum(np.asarray(y) == class_id) for class_id in unique_labels_y
+        ]
+    return input_is_dask, n_classes, n_samples_per_class
+
+
+def reduce_iadd(a):
+    """Reduces a list by adding all elements into the first element"""
+    return functools.reduce(operator.iadd, a)
+
+
 def mult_along_axis(A, B, axis):
     """
-    Magic function to multiply two arrays along a given axis.
+    Multiplies two arrays along a given axis.
@@ -237,20 +259,19 @@ class FactorAnalysisBase(BaseEstimator):
         # Accumulating 0th and 1st order statistics
         # https://gitlab.idiap.ch/bob/bob.learn.em/-/blob/da92d0e5799d018f311f1bf5cdd5a80e19e142ca/bob/learn/em/cpp/ISVTrainer.cpp#L68
 
-        if is_input_delayed(ubm_projected_X):
+        if is_input_dask_nested(ubm_projected_X):
             n_acc = [
                 dask.delayed(self._sum_n_statistics)(xx, yy, n_classes)
                 for xx, yy in zip(ubm_projected_X, y)
             ]
-            n_acc = dask.compute(*n_acc)
-            n_acc = functools.reduce(operator.iadd, n_acc)
 
             f_acc = [
                 dask.delayed(self._sum_f_statistics)(xx, yy, n_classes)
                 for xx, yy in zip(ubm_projected_X, y)
             ]
-            f_acc = dask.compute(*f_acc)
-            f_acc = functools.reduce(operator.iadd, f_acc)
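+            # A single compute call evaluates both graphs at once, so
+            # shared intermediate results are not computed twice.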
+            n_acc, f_acc = dask.compute(n_acc, f_acc)
+            n_acc = reduce_iadd(n_acc)
+            f_acc = reduce_iadd(f_acc)
         else:
             # 0th order stats
             n_acc = self._sum_n_statistics(ubm_projected_X, y, n_classes)
@@ -374,7 +395,7 @@ class FactorAnalysisBase(BaseEstimator):
                 Class id to return the statistics for
         """
         X = np.array(X)
-        return list(X[np.where(np.array(y) == i)[0]])
+        return list(X[np.array(y) == i])
 
     """
     Estimating U and x
@@ -406,7 +427,7 @@ class FactorAnalysisBase(BaseEstimator):
-        # TODO: make the invertion matrix function as a parameter
+        # TODO: make the matrix inversion function a parameter
         return np.linalg.inv(I + (UProd * n_i[:, None, None]).sum(axis=0))
 
-    def _computefn_x_ih(self, x_i, latent_z_i=None, latent_y_i=None):
+    def _compute_fn_x_ih(self, x_i, latent_z_i=None, latent_y_i=None):
         """
         Computes Fn_x_ih = N_{i,h}*(o_{i,h} - m - D*z_{i} - V*y_{i})
         Check equation (29) in [McCool2013]_
@@ -448,8 +469,8 @@ class FactorAnalysisBase(BaseEstimator):
 
         return fn_x_ih
 
-    def update_x(
-        self, X, y, n_classes, UProd, latent_x, latent_y=None, latent_z=None
+    def compute_latent_x(
+        self, *, X, y, n_classes, UProd, latent_y=None, latent_z=None
     ):
         """
-        Computes a new math:`E[x]` See equation (29) in [McCool2013]_
+        Computes a new :math:`E[x]`. See equation (29) in [McCool2013]_.
@@ -464,15 +485,9 @@ class FactorAnalysisBase(BaseEstimator):
             y: list of ints
                 List of corresponding labels
 
-            n_classes: int
-                Number of classes
-
             UProd: array
-                Matrix containing U_c.T*inv(Sigma_c) @ U_c.T
+                Matrix containing U_c.T @ inv(Sigma_c) @ U_c
 
-            latent_x: array
-                E(x) latent variable
-
             latent_y: array
                 E(y) latent variable
 
@@ -482,27 +497,50 @@ class FactorAnalysisBase(BaseEstimator):
         Returns
         -------
             Returns the new latent_x
-
         """
 
         # U.T @ inv(Sigma) - See Eq(37)
         UTinvSigma = self._U.T / self.variance_supervector
 
-        session_offsets = np.zeros(n_classes)
+        if is_input_dask_nested(X):
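+            # One delayed task per class; X is a nested list of per-class
+            # arrays here.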
+            latent_x = [
+                dask.delayed(self._compute_latent_x_per_class)(
+                    X_i=X_i,
+                    UProd=UProd,
+                    UTinvSigma=UTinvSigma,
+                    latent_y_i=latent_y[y_i] if latent_y is not None else None,
+                    latent_z_i=latent_z[y_i] if latent_z is not None else None,
+                )
+                for y_i, X_i in enumerate(X)
+            ]
+            latent_x = dask.compute(*latent_x)
+        else:
+            latent_x = [None] * n_classes
+            for y_i in unique_labels(y):
+                latent_x[y_i] = self._compute_latent_x_per_class(
+                    X_i=np.array(X)[np.asarray(y) == y_i],
+                    UProd=UProd,
+                    UTinvSigma=UTinvSigma,
+                    latent_y_i=latent_y[y_i] if latent_y is not None else None,
+                    latent_z_i=latent_z[y_i] if latent_z is not None else None,
+                )
+
+        return latent_x
+
+    def _compute_latent_x_per_class(
+        self, *, X_i, UProd, UTinvSigma, latent_y_i, latent_z_i
+    ):
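+        """Computes :math:`E[x]` for all sessions of a single class."""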
         # For each sample
-        for x_i, y_i in zip(X, y):
+        latent_x_i = []
+        for x_i in X_i:
             id_plus_prod_ih = self._compute_id_plus_u_prod_ih(x_i, UProd)
-            latent_z_i = latent_z[y_i] if latent_z is not None else None
-            latent_y_i = latent_y[y_i] if latent_y is not None else None
 
-            fn_x_ih = self._computefn_x_ih(
+            fn_x_ih = self._compute_fn_x_ih(
                 x_i, latent_z_i=latent_z_i, latent_y_i=latent_y_i
             )
-            latent_x[y_i][:, int(session_offsets[y_i])] = id_plus_prod_ih @ (
-                UTinvSigma @ fn_x_ih
-            )
-            session_offsets[y_i] += 1
-        return latent_x
+            latent_x_i.append(id_plus_prod_ih @ (UTinvSigma @ fn_x_ih))
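+        # Stack the per-session results as columns: shape (r_U, n_sessions)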
+        latent_x_i = np.swapaxes(latent_x_i, 0, 1)
+        return latent_x_i
 
     def update_U(self, acc_U_A1, acc_U_A2):
         """
@@ -612,7 +650,7 @@ class FactorAnalysisBase(BaseEstimator):
                 id_plus_prod_ih = self._compute_id_plus_u_prod_ih(x_i, UProd)
                 latent_z_i = latent_z[y_i] if latent_z is not None else None
                 latent_y_i = latent_y[y_i] if latent_y is not None else None
-                fn_x_ih = self._computefn_x_ih(
+                fn_x_ih = self._compute_fn_x_ih(
                     x_i, latent_y_i=latent_y_i, latent_z_i=latent_z_i
                 )
 
@@ -832,7 +870,7 @@ class FactorAnalysisBase(BaseEstimator):
 
         return acc_D_A1, acc_D_A2
 
-    def initialize_XYZ(self, y, n_classes):
+    def initialize_XYZ(self, n_samples_per_class):
         """
         Initialize E[x], E[y], E[z] state variables
 
@@ -841,7 +879,7 @@ class FactorAnalysisBase(BaseEstimator):
 
 
         Eq. (37)
-        latent_y =
+        latent_y = (n_classes, r_V) or None
 
         Eq. (36)
         latent_x = (n_classes, r_U, n_sessions)
@@ -851,16 +889,10 @@ class FactorAnalysisBase(BaseEstimator):
         # x (Eq. 36)
         # (n_classes, r_U,  n_samples )
         latent_x = []
-        for y_i in set(y):
-            latent_x.append(
-                np.zeros(
-                    (
-                        self.r_U,
-                        np.sum(y == y_i),
-                    )
-                )
-            )
+        for n_s in n_samples_per_class:
+            latent_x.append(np.zeros((self.r_U, n_s)))
 
+        n_classes = len(n_samples_per_class)
         latent_y = (
             np.zeros((n_classes, self.r_V))
             if self.r_V and self.r_V > 0
@@ -876,7 +908,17 @@ class FactorAnalysisBase(BaseEstimator):
     """
 
     def update_y(
-        self, X, y, n_classes, VProd, latent_x, latent_y, latent_z, n_acc, f_acc
+        self,
+        *,
+        X,
+        y,
+        n_classes,
+        VProd,
+        latent_x,
+        latent_y,
+        latent_z,
+        n_acc,
+        f_acc,
     ):
         """
-        Computes a new math:`E[y]` See equation (30) in [McCool2013]_
+        Computes a new :math:`E[y]`. See equation (30) in [McCool2013]_.
@@ -915,20 +957,55 @@ class FactorAnalysisBase(BaseEstimator):
         # V.T / sigma
         VTinvSigma = self._V.T / self.variance_supervector
 
-        # Loops over the labels
-        for label in range(n_classes):
-            id_plus_v_prod_i = self._compute_id_plus_vprod_i(
-                n_acc[label], VProd
-            )
-            X_i = self._get_statistics_by_class_id(X, y, label)
-            fn_y_i = self._compute_fn_y_i(
-                X_i,
-                latent_x[label],
-                latent_z[label],
-                n_acc[label],
-                f_acc[label],
-            )
-            latent_y[label] = (VTinvSigma @ fn_y_i) @ id_plus_v_prod_i
+        if is_input_dask_nested(X):
+            latent_y = [
+                dask.delayed(self._latent_y_per_class)(
+                    X_i=X_i,
+                    n_acc_i=n_acc[label],
+                    f_acc_i=f_acc[label],
+                    VProd=VProd,
+                    VTinvSigma=VTinvSigma,
+                    latent_x_i=latent_x[label],
+                    latent_z_i=latent_z[label],
+                )
+                for label, X_i in enumerate(X)
+            ]
+            latent_y = dask.compute(*latent_y)
+        else:
+            # Loops over the labels
+            for label in range(n_classes):
+                X_i = self._get_statistics_by_class_id(X, y, label)
+                latent_y[label] = self._latent_y_per_class(
+                    X_i=X_i,
+                    n_acc_i=n_acc[label],
+                    f_acc_i=f_acc[label],
+                    VProd=VProd,
+                    VTinvSigma=VTinvSigma,
+                    latent_x_i=latent_x[label],
+                    latent_z_i=latent_z[label],
+                )
+        return latent_y
+
+    def _latent_y_per_class(
+        self,
+        *,
+        X_i,
+        n_acc_i,
+        f_acc_i,
+        VProd,
+        VTinvSigma,
+        latent_x_i,
+        latent_z_i,
+    ):
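+        """Computes :math:`E[y]` for a single class."""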
+        id_plus_v_prod_i = self._compute_id_plus_vprod_i(n_acc_i, VProd)
+        fn_y_i = self._compute_fn_y_i(
+            X_i,
+            latent_x_i,
+            latent_z_i,
+            n_acc_i,
+            f_acc_i,
+        )
+        latent_y = (VTinvSigma @ fn_y_i) @ id_plus_v_prod_i
         return latent_y
 
     def _compute_id_plus_vprod_i(self, n_acc_i, VProd):
@@ -1197,32 +1274,27 @@ class FactorAnalysisBase(BaseEstimator):
         return self.score_using_stats(model, self.ubm.transform(data))
 
     def fit(self, X, y):
-        input_is_dask, X = check_and_persist_dask_input(X)
-        self.initialize(X)
+        input_is_dask, X = check_and_persist_dask_input(X, persist=False)
         y = np.squeeze(np.asarray(y))
+
+        self.initialize(X)
+
         if input_is_dask:
-            chunks = X.chunks[0]
-            # chunk the y array similar to X into a list of numpy arrays
-            i, new_y = 0, []
-            for chunk in chunks:
-                new_y.append(y[i : i + chunk])
-                i += chunk
-            y = new_y
-            X = array_to_delayed_list(X, input_is_dask)
+            # Split X and y by class so that each class becomes a single
+            # dask task in the EM steps below
+            X_new, y_new = [], []
+            for class_id in unique_labels(y):
+                class_indices = y == class_id
+                X_new.append(X[class_indices])
+                y_new.append(y[class_indices])
+            X, y = X_new, y_new
+
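+            # Persisting each delayed stats object starts computing it in
+            # the background (with a distributed scheduler) and caches the
+            # result for the EM iterations below.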
             stats = [
-                dask.delayed(self.ubm.transform_one_by_one)(xx).persist()
+                dask.delayed(self.ubm.stats_per_sample)(xx).persist()
                 for xx in X
             ]
-            # stats = dask.compute(*stats)
-            # stats = functools.reduce(operator.iadd, stats)
         else:
-            stats = self.ubm.transform_one_by_one(X)
-        del X  # we don't need to persist X anymore
+            stats = self.ubm.stats_per_sample(X)
 
-        # if input_is_dask:
-        #     dask.compute(dask.delayed(self.fit_using_stats)(stats, y))
-        # else:
-        #     self.fit_using_stats(stats, y)
         self.fit_using_stats(stats, y)
         return self
 
@@ -1278,21 +1350,22 @@ class ISVMachine(FactorAnalysisBase):
             **kwargs,
         )
 
-    def e_step(self, X, y, n_classes, n_acc, f_acc):
+    def e_step(self, X, y, n_samples_per_class, n_acc, f_acc):
         """
         E-step of the EM algorithm
         """
-        # self.initialize_XYZ(y)
         UProd = self._compute_uprod()
-        latent_x, _, latent_z = self.initialize_XYZ(y, n_classes=n_classes)
+        _, _, latent_z = self.initialize_XYZ(
+            n_samples_per_class=n_samples_per_class
+        )
         latent_y = None
 
-        latent_x = self.update_x(
+        latent_x = self.compute_latent_x(
             X=X,
             y=y,
-            n_classes=n_classes,
+            n_classes=len(n_samples_per_class),
             UProd=UProd,
-            latent_x=latent_x,
         )
         latent_z = self.update_z(
             X=X,
@@ -1325,10 +1398,10 @@ class ISVMachine(FactorAnalysisBase):
 
         """
         acc_U_A1 = [acc[0] for acc in acc_U_A1_acc_U_A2_list]
-        acc_U_A1 = functools.reduce(operator.iadd, acc_U_A1)
+        acc_U_A1 = reduce_iadd(acc_U_A1)
 
         acc_U_A2 = [acc[1] for acc in acc_U_A1_acc_U_A2_list]
-        acc_U_A2 = functools.reduce(operator.iadd, acc_U_A2)
+        acc_U_A2 = reduce_iadd(acc_U_A2)
 
         self.update_U(acc_U_A1, acc_U_A2)
 
@@ -1349,38 +1422,38 @@ class ISVMachine(FactorAnalysisBase):
             Returns self.
 
         """
+        (
+            input_is_dask,
+            n_classes,
+            n_samples_per_class,
+        ) = check_dask_input_samples_per_class(X, y)
 
-        y = np.asarray(y)
-        n_classes = self.estimate_number_of_classes(y)
         n_acc, f_acc = self.initialize_using_stats(X, y, n_classes)
 
         for i in range(self.em_iterations):
             logger.info("U Training: Iteration %d", i + 1)
-            # TODO: Point of MAP-REDUCE
-            if is_input_delayed(X):
-                acc_U_A1_acc_U_A2_list = [
+            if input_is_dask:
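+                # Map: one delayed e_step per class chunk; reduce: a single
+                # delayed m_step over all the E-step outputs.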
+                e_step_output = [
                     dask.delayed(self.e_step)(
                         X=xx,
                         y=yy,
-                        n_classes=n_classes,
+                        n_samples_per_class=n_samples_per_class,
                         n_acc=n_acc,
                         f_acc=f_acc,
                     )
                     for xx, yy in zip(X, y)
                 ]
-                delayed_em_step = dask.delayed(self.m_step)(
-                    acc_U_A1_acc_U_A2_list
-                )
+                delayed_em_step = dask.delayed(self.m_step)(e_step_output)
                 dask.compute(delayed_em_step)
             else:
-                acc_U_A1, acc_U_A2 = self.e_step(
+                e_step_output = self.e_step(
                     X=X,
                     y=y,
-                    n_classes=n_classes,
+                    n_samples_per_class=n_samples_per_class,
                     n_acc=n_acc,
                     f_acc=f_acc,
                 )
-                self.m_step([(acc_U_A1, acc_U_A2)])
+                self.m_step([e_step_output])
 
         return self
 
@@ -1415,16 +1488,15 @@ class ISVMachine(FactorAnalysisBase):
         f_acc = self._sum_f_statistics(X, y=y, n_classes=1)
 
         UProd = self._compute_uprod()
-        latent_x, _, latent_z = self.initialize_XYZ(y, n_classes=1)
+        _, _, latent_z = self.initialize_XYZ(n_samples_per_class=[len(X)])
         latent_y = None
         for i in range(iterations):
             logger.info("Enrollment: Iteration %d", i + 1)
-            latent_x = self.update_x(
+            latent_x = self.compute_latent_x(
                 X=X,
                 y=y,
                 n_classes=1,
                 UProd=UProd,
-                latent_x=latent_x,
                 latent_y=latent_y,
                 latent_z=latent_z,
             )
@@ -1550,7 +1622,7 @@ class JFAMachine(FactorAnalysisBase):
             **kwargs,
         )
 
-    def e_step_v(self, X, y, n_classes, n_acc, f_acc):
+    def e_step_v(self, X, y, n_samples_per_class, n_acc, f_acc):
         """
-        ISV E-step for the V matrix.
+        JFA E-step for the V matrix.
 
@@ -1587,11 +1659,12 @@ class JFAMachine(FactorAnalysisBase):
         VProd = self._compute_vprod()
 
         latent_x, latent_y, latent_z = self.initialize_XYZ(
-            y, n_classes=n_classes
+            n_samples_per_class=n_samples_per_class
         )
 
         # UPDATE Y, X AND FINALLY Z
 
+        n_classes = len(n_samples_per_class)
         latent_y = self.update_y(
             X=X,
             y=y,
@@ -1617,7 +1690,7 @@ class JFAMachine(FactorAnalysisBase):
 
         return acc_V_A1, acc_V_A2
 
-    def m_step_v(self, acc_V_A1, acc_V_A2):
+    def m_step_v(self, acc_V_A1_acc_V_A2_list):
         """
         `V` Matrix M-step.
         This updates the `V` matrix
@@ -1632,6 +1705,11 @@ class JFAMachine(FactorAnalysisBase):
                 Accumulated statistics for V_A2(n_gaussians* feature_dimension, r_V)
 
         """
+        acc_V_A1 = [acc[0] for acc in acc_V_A1_acc_V_A2_list]
+        acc_V_A1 = reduce_iadd(acc_V_A1)
+
+        acc_V_A2 = [acc[1] for acc in acc_V_A1_acc_V_A2_list]
+        acc_V_A2 = reduce_iadd(acc_V_A2)
 
         # Inverting A1 over the zero axis
         # https://stackoverflow.com/questions/11972102/is-there-a-way-to-efficiently-invert-an-array-of-matrices-with-numpy
@@ -1647,7 +1725,7 @@ class JFAMachine(FactorAnalysisBase):
             (self.ubm.n_gaussians * self.feature_dimension, self.r_V)
         )
 
-    def finalize_v(self, X, y, n_classes, n_acc, f_acc):
+    def finalize_v(self, X, y, n_samples_per_class, n_acc, f_acc):
         """
         Compute for the last time `E[y]`
 
@@ -1678,11 +1756,12 @@ class JFAMachine(FactorAnalysisBase):
         VProd = self._compute_vprod()
 
         latent_x, latent_y, latent_z = self.initialize_XYZ(
-            y, n_classes=n_classes
+            n_samples_per_class=n_samples_per_class
         )
 
         # UPDATE Y, X AND FINALLY Z
 
+        n_classes = len(n_samples_per_class)
         latent_y = self.update_y(
             X=X,
             y=y,
@@ -1696,7 +1775,7 @@ class JFAMachine(FactorAnalysisBase):
         )
         return latent_y
 
-    def e_step_u(self, X, y, n_classes, latent_y):
+    def e_step_u(self, X, y, n_samples_per_class, latent_y):
         """
         ISV E-step for the U matrix.
 
@@ -1725,14 +1804,16 @@ class JFAMachine(FactorAnalysisBase):
         """
-        # self.initialize_XYZ(y)
         UProd = self._compute_uprod()
-        latent_x, _, latent_z = self.initialize_XYZ(y, n_classes)
+        latent_x, _, latent_z = self.initialize_XYZ(
+            n_samples_per_class=n_samples_per_class
+        )
 
-        latent_x = self.update_x(
+        n_classes = len(n_samples_per_class)
+        latent_x = self.compute_latent_x(
             X=X,
             y=y,
             n_classes=n_classes,
             UProd=UProd,
-            latent_x=latent_x,
             latent_y=latent_y,
         )
 
@@ -1747,7 +1828,7 @@ class JFAMachine(FactorAnalysisBase):
 
         return acc_U_A1, acc_U_A2
 
-    def m_step_u(self, acc_U_A1, acc_U_A2):
+    def m_step_u(self, acc_U_A1_acc_U_A2_list):
         """
         `U` Matrix M-step.
         This updates the `U` matrix
@@ -1762,6 +1843,11 @@ class JFAMachine(FactorAnalysisBase):
                 Accumulated statistics for V_A2(n_gaussians* feature_dimension, r_V)
 
         """
+        acc_U_A1 = [acc[0] for acc in acc_U_A1_acc_U_A2_list]
+        acc_U_A2 = [acc[1] for acc in acc_U_A1_acc_U_A2_list]
+
+        acc_U_A1 = reduce_iadd(acc_U_A1)
+        acc_U_A2 = reduce_iadd(acc_U_A2)
 
         self.update_U(acc_U_A1, acc_U_A2)
 
@@ -1769,7 +1855,7 @@ class JFAMachine(FactorAnalysisBase):
         self,
         X,
         y,
-        n_classes,
+        n_samples_per_class,
         latent_y,
     ):
         """
@@ -1797,20 +1883,24 @@ class JFAMachine(FactorAnalysisBase):
         """
 
         UProd = self._compute_uprod()
-        latent_x, _, _ = self.initialize_XYZ(y, n_classes=n_classes)
+        latent_x, _, _ = self.initialize_XYZ(
+            n_samples_per_class=n_samples_per_class
+        )
 
-        latent_x = self.update_x(
+        n_classes = len(n_samples_per_class)
+        latent_x = self.compute_latent_x(
             X=X,
             y=y,
             n_classes=n_classes,
             UProd=UProd,
-            latent_x=latent_x,
             latent_y=latent_y,
         )
 
         return latent_x
 
-    def e_step_d(self, X, y, n_classes, latent_x, latent_y, n_acc, f_acc):
+    def e_step_d(
+        self, X, y, n_samples_per_class, latent_x, latent_y, n_acc, f_acc
+    ):
         """
-        ISV E-step for the U matrix.
+        JFA E-step for the D matrix.
 
@@ -1852,8 +1942,9 @@ class JFAMachine(FactorAnalysisBase):
                 Accumulated statistics for D_A2(n_gaussians* feature_dimension, )
 
         """
-
-        _, _, latent_z = self.initialize_XYZ(y, n_classes=n_classes)
+        _, _, latent_z = self.initialize_XYZ(
+            n_samples_per_class=n_samples_per_class
+        )
 
         latent_z = self.update_z(
             X,
@@ -1871,7 +1962,7 @@ class JFAMachine(FactorAnalysisBase):
 
         return acc_D_A1, acc_D_A2
 
-    def m_step_d(self, acc_D_A1, acc_D_A2):
+    def m_step_d(self, acc_D_A1_acc_D_A2_list):
         """
         `D` Matrix M-step.
         This updates the `D` matrix
@@ -1886,6 +1977,12 @@ class JFAMachine(FactorAnalysisBase):
                 Accumulated statistics for D_A2(n_gaussians* feature_dimension, )
 
         """
+        acc_D_A1 = [acc[0] for acc in acc_D_A1_acc_D_A2_list]
+        acc_D_A2 = [acc[1] for acc in acc_D_A1_acc_D_A2_list]
+
+        acc_D_A1 = reduce_iadd(acc_D_A1)
+        acc_D_A2 = reduce_iadd(acc_D_A2)
+
         self._D = acc_D_A2 / acc_D_A1
 
     def enroll_using_stats(self, X, iterations=1):
@@ -1915,7 +2012,9 @@ class JFAMachine(FactorAnalysisBase):
 
         UProd = self._compute_uprod()
         VProd = self._compute_vprod()
-        latent_x, latent_y, latent_z = self.initialize_XYZ(y, n_classes=1)
+        latent_x, latent_y, latent_z = self.initialize_XYZ(
+            n_samples_per_class=[len(X)]
+        )
 
         for i in range(iterations):
             logger.info("Enrollment: Iteration %d", i + 1)
@@ -1930,12 +2029,11 @@ class JFAMachine(FactorAnalysisBase):
                 n_acc=n_acc,
                 f_acc=f_acc,
             )
-            latent_x = self.update_x(
+            latent_x = self.compute_latent_x(
                 X=X,
                 y=y,
                 n_classes=1,
                 UProd=UProd,
-                latent_x=latent_x,
                 latent_y=latent_y,
                 latent_z=latent_z,
             )
@@ -1998,58 +2096,104 @@ class JFAMachine(FactorAnalysisBase):
         ):
             self.create_UVD()
 
-        y = np.asarray(y)
-        n_classes = self.estimate_number_of_classes(y)
+        (
+            input_is_dask,
+            n_classes,
+            n_samples_per_class,
+        ) = check_dask_input_samples_per_class(X, y)
 
-        # TODO: Point of MAP-REDUCE
         n_acc, f_acc = self.initialize_using_stats(X, y, n_classes=n_classes)
 
         # Updating V
         for i in range(self.em_iterations):
             logger.info("V Training: Iteration %d", i + 1)
-            # TODO: Point of MAP-REDUCE
-            acc_V_A1, acc_V_A2 = self.e_step_v(
-                X=X,
-                y=y,
-                n_classes=n_classes,
-                n_acc=n_acc,
-                f_acc=f_acc,
-            )
-            self.m_step_v(acc_V_A1, acc_V_A2)
+            if input_is_dask:
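+                # Same map-reduce pattern as in the ISV training loop.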
+                e_step_output = [
+                    dask.delayed(self.e_step_v)(
+                        X=xx,
+                        y=yy,
+                        n_samples_per_class=n_samples_per_class,
+                        n_acc=n_acc,
+                        f_acc=f_acc,
+                    )
+                    for xx, yy in zip(X, y)
+                ]
+                delayed_em_step = dask.delayed(self.m_step_v)(e_step_output)
+                dask.compute(delayed_em_step)
+            else:
+                e_step_output = self.e_step_v(
+                    X=X,
+                    y=y,
+                    n_samples_per_class=n_samples_per_class,
+                    n_acc=n_acc,
+                    f_acc=f_acc,
+                )
+                self.m_step_v([e_step_output])
         latent_y = self.finalize_v(
-            X=X, y=y, n_classes=n_classes, n_acc=n_acc, f_acc=f_acc
+            X=X,
+            y=y,
+            n_samples_per_class=n_samples_per_class,
+            n_acc=n_acc,
+            f_acc=f_acc,
         )
 
         # Updating U
         for i in range(self.em_iterations):
             logger.info("U Training: Iteration %d", i + 1)
-            # TODO: Point of MAP-REDUCE
-            acc_U_A1, acc_U_A2 = self.e_step_u(
-                X=X,
-                y=y,
-                n_classes=n_classes,
-                latent_y=latent_y,
-            )
-            self.m_step_u(acc_U_A1, acc_U_A2)
+            if input_is_dask:
+                e_step_output = [
+                    dask.delayed(self.e_step_u)(
+                        X=xx,
+                        y=yy,
+                        n_samples_per_class=n_samples_per_class,
+                        latent_y=latent_y,
+                    )
+                    for xx, yy in zip(X, y)
+                ]
+                delayed_em_step = dask.delayed(self.m_step_u)(e_step_output)
+                dask.compute(delayed_em_step)
+            else:
+                e_step_output = self.e_step_u(
+                    X=X,
+                    y=y,
+                    n_samples_per_class=n_samples_per_class,
+                    latent_y=latent_y,
+                )
+                self.m_step_u([e_step_output])
 
         latent_x = self.finalize_u(
-            X=X, y=y, n_classes=n_classes, latent_y=latent_y
+            X=X,
+            y=y,
+            n_samples_per_class=n_samples_per_class,
+            latent_y=latent_y,
         )
 
         # Updating D
         for i in range(self.em_iterations):
             logger.info("D Training: Iteration %d", i + 1)
-            # TODO: Point of MAP-REDUCE
-            acc_D_A1, acc_D_A2 = self.e_step_d(
-                X=X,
-                y=y,
-                n_classes=n_classes,
-                latent_x=latent_x,
-                latent_y=latent_y,
-                n_acc=n_acc,
-                f_acc=f_acc,
-            )
-            self.m_step_d(acc_D_A1, acc_D_A2)
+            if input_is_dask:
+                e_step_output = [
+                    dask.delayed(self.e_step_d)(
+                        X=xx,
+                        y=yy,
+                        n_samples_per_class=n_samples_per_class,
+                        latent_x=latent_x,
+                        latent_y=latent_y,
+                        n_acc=n_acc,
+                        f_acc=f_acc,
+                    )
+                    for xx, yy in zip(X, y)
+                ]
+                delayed_em_step = dask.delayed(self.m_step_d)(e_step_output)
+                dask.compute(delayed_em_step)
+            else:
+                e_step_output = self.e_step_d(
+                    X=X,
+                    y=y,
+                    n_samples_per_class=n_samples_per_class,
+                    latent_x=latent_x,
+                    latent_y=latent_y,
+                    n_acc=n_acc,
+                    f_acc=f_acc,
+                )
+                self.m_step_d([e_step_output])
 
         return self
 
diff --git a/bob/learn/em/gmm.py b/bob/learn/em/gmm.py
index 394a93082ef290b48524bd5f1da1338b8d1c1418..25159f17aaa95f68f1795c6306f9ed57a9181b71 100644
--- a/bob/learn/em/gmm.py
+++ b/bob/learn/em/gmm.py
@@ -854,7 +854,7 @@ class GMMMachine(BaseEstimator):
         """Returns the statistics for `X`."""
         return e_step(data=X, machine=self)
 
-    def transform_one_by_one(self, X):
+    def stats_per_sample(self, X):
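+        """Returns the statistics of each sample of `X` as a list."""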
         return [e_step(data=xx, machine=self) for xx in X]
 
     def _more_tags(self):
diff --git a/bob/learn/em/test/test_factor_analysis.py b/bob/learn/em/test/test_factor_analysis.py
index 08ae5ab6f5c328f484de0ed2357954ce16de0ddb..8d82d0cca8380af217f78b66db90ccf7490a6516 100644
--- a/bob/learn/em/test/test_factor_analysis.py
+++ b/bob/learn/em/test/test_factor_analysis.py
@@ -555,13 +555,14 @@ def test_ISV_JFA_fit():
                 machine = JFAMachine(2, 2, **machine_kwargs)
                 test_attr = "V"
 
-            with multiprocess_dask_client():
-                machine.fit(data, labels)
+            err_msg = f"Test failed with prior={prior} and machine_type={machine_type} and transform={transform}"
+            machine.fit(data, labels)
 
             arr = getattr(machine, test_attr)
             np.testing.assert_allclose(
                 arr,
                 ref,
                 atol=1e-7,
-                err_msg=f"Test failed with prior={prior} and machine_type={machine_type} and transform={transform}",
+                err_msg=err_msg,
             )
diff --git a/bob/learn/em/utils.py b/bob/learn/em/utils.py
index 108e79753f5c85997fcbecf834fdbb4db37bd66c..be92e8d41258384e08b06a07d212d1da328e048f 100644
--- a/bob/learn/em/utils.py
+++ b/bob/learn/em/utils.py
@@ -7,11 +7,12 @@ import numpy as np
 logger = logging.getLogger(__name__)
 
 
-def check_and_persist_dask_input(data):
-    # check if input is a dask array. If so, persist and rebalance data
+def check_and_persist_dask_input(data, persist=True):
+    # check if the input is a dask array; if so, optionally persist and
+    # rebalance the data
     input_is_dask = False
     if isinstance(data, da.Array):
-        data: da.Array = data.persist()
+        if persist:
+            data: da.Array = data.persist()
         input_is_dask = True
         # if there is a dask distributed client, rebalance data
         try: