diff --git a/bob/learn/em/factor_analysis.py b/bob/learn/em/factor_analysis.py
index 9b9a8894d56e530e661149e7d7f7e81760cb3926..f9b39cef2a344e552a190dff26929706a5c04b5e 100644
--- a/bob/learn/em/factor_analysis.py
+++ b/bob/learn/em/factor_analysis.py
@@ -9,7 +9,6 @@ import operator
 import dask
 import numpy as np
 
-from dask.array.core import Array
 from dask.delayed import Delayed
 from sklearn.base import BaseEstimator
 from sklearn.utils import check_consistent_length
@@ -23,14 +22,11 @@ logger = logging.getLogger(__name__)
 
 
 def is_input_dask_nested(X):
-    """
-    Check if the input is a dask delayed or array or a (nested) list of dask
-    delayed or array objects.
-    """
+    """Check if the input is a dask delayed or a (nested) list of dask delayed."""
     if isinstance(X, (list, tuple)):
         return is_input_dask_nested(X[0])
 
-    if isinstance(X, (Delayed, Array)):
+    if isinstance(X, Delayed):
         return True
     else:
         return False
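
A minimal sketch of the narrowed behaviour (the delayed object is illustrative): after this change only ``Delayed`` objects count as dask input; plain dask arrays no longer do:

    import dask

    d = dask.delayed(lambda: 1)()                # a single Delayed object
    assert is_input_dask_nested(d)               # True
    assert is_input_dask_nested([[d], [d]])      # nested lists are unwrapped recursively
    assert not is_input_dask_nested([1, 2, 3])   # plain data is not dask input
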
@@ -40,20 +36,22 @@ def check_dask_input_samples_per_class(X, y):
     input_is_dask = is_input_dask_nested(X)
 
     if input_is_dask:
-        n_classes = len(y)
-        n_samples_per_class = [len(yy) for yy in y]
-    else:
-        unique_labels_y = unique_labels(y)
-        n_classes = len(unique_labels_y)
-        n_samples_per_class = [
-            sum(y == class_id) for class_id in unique_labels_y
-        ]
+        y = functools.reduce(lambda x1, x2: list(x1) + list(x2), y)
+    unique_labels_y = unique_labels(y)
+    n_classes = len(unique_labels_y)
+    n_samples_per_class = [sum(y == class_id) for class_id in unique_labels_y]
     return input_is_dask, n_classes, n_samples_per_class
 
 
-def reduce_iadd(a):
-    """Reduces a list by adding all elements into the first element"""
-    return functools.reduce(operator.iadd, a)
+def reduce_iadd(*args):
+    """Reduces one or several lists by adding all elements into the first element"""
+    ret = []
+    for a in args:
+        ret.append(functools.reduce(operator.iadd, a))
+
+    if len(ret) == 1:
+        return ret[0]
+    return ret
 
 
 def mult_along_axis(A, B, axis):
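
A usage sketch of the new variadic ``reduce_iadd`` (values illustrative); note that ``operator.iadd`` mutates the first element of each list in place when the elements are numpy arrays:

    import numpy as np

    a = [np.array([1.0, 2.0]), np.array([3.0, 4.0])]
    b = [np.array([5.0]), np.array([6.0])]

    # several lists are reduced in one call and returned in order
    sum_a, sum_b = reduce_iadd(a, b)   # array([4., 6.]), array([11.])

    # a single list returns the reduced element directly, not a 1-element list
    sum_c = reduce_iadd([np.array([1.0]), np.array([2.0])])   # array([3.])
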
@@ -62,10 +60,6 @@ def mult_along_axis(A, B, axis):
     Taken from https://stackoverflow.com/questions/30031828/multiply-numpy-ndarray-with-1d-array-along-a-given-axis
     """
 
-    # ensure we're working with Numpy arrays
-    A = np.array(A)
-    B = np.array(B)
-
     # shape check
     if axis >= A.ndim:
         raise np.AxisError(axis, A.ndim)
@@ -129,6 +123,7 @@ class FactorAnalysisBase(BaseEstimator):
         relevance_factor=4.0,
         em_iterations=10,
         random_state=0,
+        enroll_iterations=1,
         ubm=None,
         ubm_kwargs=None,
         **kwargs,
@@ -138,6 +133,7 @@ class FactorAnalysisBase(BaseEstimator):
         self.ubm_kwargs = ubm_kwargs
         self.em_iterations = em_iterations
         self.random_state = random_state
+        self.enroll_iterations = enroll_iterations
 
         # axis 1 dimensions of U and V
         self.r_U = r_U
@@ -208,7 +204,7 @@ class FactorAnalysisBase(BaseEstimator):
 
         return len(unique_labels(y))
 
-    def initialize(self, X):
+    def initialize_using_array(self, X):
         """
         Accumulating 0th and 1st order statistics. Trains the UBM if needed.
 
@@ -238,14 +234,16 @@ class FactorAnalysisBase(BaseEstimator):
             logger.info("UBM means are None, training the UBM.")
             self.ubm.fit(X)
 
+    def initialize(self, ubm_projected_X, y, n_classes):
+        # Accumulating 0th and 1st order statistics
+        # https://gitlab.idiap.ch/bob/bob.learn.em/-/blob/da92d0e5799d018f311f1bf5cdd5a80e19e142ca/bob/learn/em/cpp/ISVTrainer.cpp#L68
+
+        logger.debug("Initializing the ISV/JFA using the UBM statistics.")
+
         # Initializing the state matrix
         if not hasattr(self, "_U") or not hasattr(self, "_D"):
             self.create_UVD()
 
-    def initialize_using_stats(self, ubm_projected_X, y, n_classes):
-        # Accumulating 0th and 1st order statistics
-        # https://gitlab.idiap.ch/bob/bob.learn.em/-/blob/da92d0e5799d018f311f1bf5cdd5a80e19e142ca/bob/learn/em/cpp/ISVTrainer.cpp#L68
-
         if is_input_dask_nested(ubm_projected_X):
             n_acc = [
                 dask.delayed(self._sum_n_statistics)(xx, yy, n_classes)
@@ -256,15 +254,14 @@ class FactorAnalysisBase(BaseEstimator):
                 dask.delayed(self._sum_f_statistics)(xx, yy, n_classes)
                 for xx, yy in zip(ubm_projected_X, y)
             ]
-            n_acc, f_acc = dask.compute(n_acc, f_acc)
-            n_acc = reduce_iadd(n_acc)
-            f_acc = reduce_iadd(f_acc)
+            n_acc, f_acc = reduce_iadd(n_acc, f_acc)
         else:
             # 0th order stats
             n_acc = self._sum_n_statistics(ubm_projected_X, y, n_classes)
             # 1st order stats
             f_acc = self._sum_f_statistics(ubm_projected_X, y, n_classes)
 
+        n_acc, f_acc = dask.compute(n_acc, f_acc)
         return n_acc, f_acc
 
     def create_UVD(self):
@@ -322,7 +319,7 @@ class FactorAnalysisBase(BaseEstimator):
 
         """
         # 0th order stats
-        n_acc = np.zeros((n_classes, self.ubm.n_gaussians))
+        n_acc = np.zeros((n_classes, self.ubm.n_gaussians), like=X[0].n)
 
         # Iterate for each client
         for x_i, y_i in zip(X, y):
@@ -358,7 +355,8 @@ class FactorAnalysisBase(BaseEstimator):
                 n_classes,
                 self.ubm.n_gaussians,
                 self.feature_dimension,
-            )
+            ),
+            like=X[0].sum_px,
         )
         # Iterate for each client
         for x_i, y_i in zip(X, y):
@@ -381,8 +379,8 @@ class FactorAnalysisBase(BaseEstimator):
             i: int
                 Class id to return the statistics for
         """
-        X = np.array(X)
-        return list(X[np.array(y) == i])
+        indices = np.where(np.array(y) == i)[0]
+        return [X[idx] for idx in indices]
 
     """
     Estimating U and x
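
The rewrite above avoids ``np.array(X)``, which would eagerly coerce a list of ``GMMStats`` or dask ``Delayed`` objects into a numpy object array. A sketch of the selection logic with stand-in values:

    import numpy as np

    X = ["s0", "s1", "s2", "s3"]   # stand-ins for per-sample statistics
    y = [0, 1, 0, 1]
    indices = np.where(np.array(y) == 1)[0]
    assert [X[idx] for idx in indices] == ["s1", "s3"]
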
@@ -505,7 +503,7 @@ class FactorAnalysisBase(BaseEstimator):
             latent_x = [None] * n_classes
             for y_i in unique_labels(y):
                 latent_x[y_i] = self._compute_latent_x_per_class(
-                    X_i=np.array(X)[y == y_i],
+                    X_i=self._get_statistics_by_class_id(X, y, y_i),
                     UProd=UProd,
                     UTinvSigma=UTinvSigma,
                     latent_y_i=latent_y[y_i] if latent_y is not None else None,
@@ -526,6 +524,9 @@ class FactorAnalysisBase(BaseEstimator):
                 x_i, latent_z_i=latent_z_i, latent_y_i=latent_y_i
             )
             latent_x_i.append(id_plus_prod_ih @ (UTinvSigma @ fn_x_ih))
+        # stack the per-session results; np.vstack dispatches through
+        # __array_function__, so latent_x_i stays a dask array when the
+        # list elements are dask arrays
+        latent_x_i = np.vstack(latent_x_i)
         latent_x_i = np.swapaxes(latent_x_i, 0, 1)
         return latent_x_i
 
@@ -543,7 +544,6 @@ class FactorAnalysisBase(BaseEstimator):
                 Accumulated statistics for U_A2(n_gaussians* feature_dimension, r_U)
 
         """
-
         # Inverting A1 over the zero axis
         # https://stackoverflow.com/questions/11972102/is-there-a-way-to-efficiently-invert-an-array-of-matrices-with-numpy
         inv_A1 = np.linalg.inv(acc_U_A1)
@@ -626,8 +626,10 @@ class FactorAnalysisBase(BaseEstimator):
         """
 
         # U accumulators
-        acc_U_A1 = np.zeros((self.ubm.n_gaussians, self.r_U, self.r_U))
-        acc_U_A2 = np.zeros((self.supervector_dimension, self.r_U))
+        acc_U_A1 = np.zeros(
+            (self.ubm.n_gaussians, self.r_U, self.r_U), like=X[0].n
+        )
+        acc_U_A2 = np.zeros((self.supervector_dimension, self.r_U), like=X[0].n)
 
         # Loops over all people
         for y_i in set(y):
@@ -764,6 +766,8 @@ class FactorAnalysisBase(BaseEstimator):
 
         # m_cache_Fn_z_i = Fi - m_tmp_CD * (m + m_tmp_CD_b); // Fn_yi = sum_{sessions h}(N_{i,h}*(o_{i,h} - m - V*y_{i})
         fn_z_i = f_acc_i.flatten() - tmp_CD * (m + V_dot_v)
+        # convert fn_z_i to a dask array here if required, so that the
+        # in-place updates (fn_z_i -= ...) in the loop below work on dask inputs
+        fn_z_i = np.array(fn_z_i, like=X_i[0].n)
 
         # Looping over the sessions
         for session_id in range(len(X_i)):
@@ -826,8 +830,8 @@ class FactorAnalysisBase(BaseEstimator):
 
         """
 
-        acc_D_A1 = np.zeros((self.supervector_dimension,))
-        acc_D_A2 = np.zeros((self.supervector_dimension,))
+        acc_D_A1 = np.zeros((self.supervector_dimension,), like=X[0].n)
+        acc_D_A2 = np.zeros((self.supervector_dimension,), like=X[0].n)
 
         # Precomputing
         # self._D.T / sigma
@@ -858,7 +862,7 @@ class FactorAnalysisBase(BaseEstimator):
 
         return acc_D_A1, acc_D_A2
 
-    def initialize_XYZ(self, n_samples_per_class):
+    def initialize_XYZ(self, n_samples_per_class, like=None):
         """
         Initialize E[x], E[y], E[z] state variables
 
@@ -873,21 +877,22 @@ class FactorAnalysisBase(BaseEstimator):
         latent_x = (n_classes, r_U, n_sessions)
 
         """
+        # ``import dask`` does not import ``dask.array``; import it explicitly
+        # so the isinstance check below cannot fail with an AttributeError
+        import dask.array
+
+        kw = dict(like=like) if isinstance(like, dask.array.core.Array) else {}
 
         # x (Eq. 36)
         # (n_classes, r_U,  n_samples )
         latent_x = []
         for n_s in n_samples_per_class:
-            latent_x.append(np.zeros((self.r_U, n_s)))
+            latent_x.append(np.zeros((self.r_U, n_s), **kw))
 
         n_classes = len(n_samples_per_class)
         latent_y = (
-            np.zeros((n_classes, self.r_V))
+            np.zeros((n_classes, self.r_V), **kw)
             if self.r_V and self.r_V > 0
             else None
         )
 
-        latent_z = np.zeros((n_classes, self.supervector_dimension))
+        latent_z = np.zeros((n_classes, self.supervector_dimension), **kw)
 
         return latent_x, latent_y, latent_z
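
The ``like=`` keyword used for these accumulators is numpy's NEP 35 dispatch: array creation is delegated to the library of the reference array, so the buffers come out as dask arrays when the statistics are dask-backed. A minimal sketch of the mechanism (assumes numpy >= 1.20 and dask[array] installed):

    import dask.array as da
    import numpy as np

    ref = da.zeros((4,), chunks=2)
    acc = np.zeros((3, 2), like=ref)   # a lazy dask array, not numpy
    acc_np = np.zeros((3, 2))          # a plain numpy array, as before
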
 
@@ -1089,8 +1094,10 @@ class FactorAnalysisBase(BaseEstimator):
         """
 
         # U accumulators
-        acc_V_A1 = np.zeros((self.ubm.n_gaussians, self.r_V, self.r_V))
-        acc_V_A2 = np.zeros((self.supervector_dimension, self.r_V))
+        acc_V_A1 = np.zeros(
+            (self.ubm.n_gaussians, self.r_V, self.r_V), like=X[0].n
+        )
+        acc_V_A2 = np.zeros((self.supervector_dimension, self.r_V), like=X[0].n)
 
         # Loops over all people
         for i in set(y):
@@ -1240,7 +1247,7 @@ class FactorAnalysisBase(BaseEstimator):
 
         return fn_x.flatten()
 
-    def score(self, model, data):
+    def score_using_array(self, model, data):
         """
         Computes the ISV score using a numpy array as input
 
@@ -1259,18 +1266,23 @@ class FactorAnalysisBase(BaseEstimator):
 
         """
 
-        return self.score_using_stats(model, self.ubm.acc_stats(data))
+        return self.score(model, self.ubm.acc_stats(data))
 
-    def fit(self, X, y):
+    def fit_using_array(self, X, y):
+        """Fits the model using a numpy array or a dask array as input
+        The y matrix is computed to a numpy array immediately.
+        """
 
-        input_is_dask, X = check_and_persist_dask_input(X, persist=False)
+        input_is_dask, X = check_and_persist_dask_input(X, persist=True)
+        y = dask.compute(y)[0]
         y = np.squeeze(np.asarray(y))
         check_consistent_length(X, y)
 
-        self.initialize(X)
+        self.initialize_using_array(X)
 
         if input_is_dask:
             # split the X array based on the classes
+            # since the EM algorithm is only parallelizable per class
             X_new, y_new = [], []
             for class_id in unique_labels(y):
                 class_indices = y == class_id
@@ -1279,17 +1291,35 @@ class FactorAnalysisBase(BaseEstimator):
             X, y = X_new, y_new
             del X_new, y_new
 
-            stats = [
-                dask.delayed(self.ubm.stats_per_sample)(xx).persist()
-                for xx in X
-            ]
+            stats = [dask.delayed(self.ubm.transform)(xx).persist() for xx in X]
         else:
-            stats = self.ubm.stats_per_sample(X)
+            stats = self.ubm.transform(X)
 
+        logger.info("Computing statistics per sample")
         del X
-        self.fit_using_stats(stats, y)
+        self.fit(stats, y)
         return self
 
+    def enroll_using_array(self, X):
+        """
+        Enrolls a new client using a numpy array as input.
+
+        The number of iterations is taken from ``self.enroll_iterations``.
+
+        Parameters
+        ----------
+        X : array
+            features to be enrolled
+
+        Returns
+        -------
+        z : object
+            the enrolled latent variables
+
+        """
+        return self.enroll([self.ubm.acc_stats(X)])
+
 
 class ISVMachine(FactorAnalysisBase):
     """
@@ -1349,7 +1379,7 @@ class ISVMachine(FactorAnalysisBase):
         # self.initialize_XYZ(y)
         UProd = self._compute_uprod()
         _, _, latent_z = self.initialize_XYZ(
-            n_samples_per_class=n_samples_per_class
+            n_samples_per_class=n_samples_per_class, like=X[0].n
         )
         latent_y = None
 
@@ -1390,14 +1420,13 @@ class ISVMachine(FactorAnalysisBase):
 
         """
         acc_U_A1 = [acc[0] for acc in acc_U_A1_acc_U_A2_list]
-        acc_U_A1 = reduce_iadd(acc_U_A1)
-
         acc_U_A2 = [acc[1] for acc in acc_U_A1_acc_U_A2_list]
-        acc_U_A2 = reduce_iadd(acc_U_A2)
+
+        acc_U_A1, acc_U_A2 = reduce_iadd(acc_U_A1, acc_U_A2)
 
         return self.update_U(acc_U_A1, acc_U_A2)
 
-    def fit_using_stats(self, X, y):
+    def fit(self, X, y):
         """
         Trains the U matrix (session variability matrix)
 
@@ -1420,10 +1449,10 @@ class ISVMachine(FactorAnalysisBase):
             n_samples_per_class,
         ) = check_dask_input_samples_per_class(X, y)
 
-        n_acc, f_acc = self.initialize_using_stats(X, y, n_classes)
+        n_acc, f_acc = self.initialize(X, y, n_classes)
 
         for i in range(self.em_iterations):
-            logger.info("U Training: Iteration %d", i + 1)
+            logger.info("ISV U Training: Iteration %d", i + 1)
             if input_is_dask:
                 e_step_output = [
                     dask.delayed(self.e_step)(
@@ -1453,7 +1482,7 @@ class ISVMachine(FactorAnalysisBase):
         ubm_projected_X = self.ubm.acc_stats(X)
         return self.estimate_ux(ubm_projected_X)
 
-    def enroll_using_stats(self, X, iterations=1):
+    def enroll(self, X):
         """
         Enrolls a new client
         In ISV, the enrolment is defined as: :math:`m + Dz` with the latent variables `z`
@@ -1464,16 +1493,13 @@ class ISVMachine(FactorAnalysisBase):
         X : list of :py:class:`bob.learn.em.GMMStats`
             List of statistics to be enrolled
 
-
-        iterations : int
-            Number of iterations to perform
-
         Returns
         -------
         self : object
             z
 
         """
+        iterations = self.enroll_iterations
         # We have only one class for enrollment
         y = list(np.zeros(len(X), dtype=np.int32))
         n_acc = self._sum_n_statistics(X, y=y, n_classes=1)
@@ -1504,7 +1530,7 @@ class ISVMachine(FactorAnalysisBase):
 
         return latent_z
 
-    def enroll(self, X, iterations=1):
+    def enroll_using_array(self, X):
         """
         Enrolls a new client using a numpy array as input
 
@@ -1522,9 +1548,9 @@ class ISVMachine(FactorAnalysisBase):
             z
 
         """
-        return self.enroll_using_stats([self.ubm.acc_stats(X)], iterations)
+        return self.enroll([self.ubm.acc_stats(X)])
 
-    def score_using_stats(self, latent_z, data):
+    def score(self, latent_z, data):
         """
         Computes the ISV score
 
@@ -1651,7 +1677,7 @@ class JFAMachine(FactorAnalysisBase):
         VProd = self._compute_vprod()
 
         latent_x, latent_y, latent_z = self.initialize_XYZ(
-            n_samples_per_class=n_samples_per_class
+            n_samples_per_class=n_samples_per_class, like=X[0].n
         )
 
         # UPDATE Y, X AND FINALLY Z
@@ -1698,10 +1724,9 @@ class JFAMachine(FactorAnalysisBase):
 
         """
         acc_V_A1 = [acc[0] for acc in acc_V_A1_acc_V_A2_list]
-        acc_V_A1 = reduce_iadd(acc_V_A1)
-
         acc_V_A2 = [acc[1] for acc in acc_V_A1_acc_V_A2_list]
-        acc_V_A2 = reduce_iadd(acc_V_A2)
+
+        acc_V_A1, acc_V_A2 = reduce_iadd(acc_V_A1, acc_V_A2)
 
         # Inverting A1 over the zero axis
         # https://stackoverflow.com/questions/11972102/is-there-a-way-to-efficiently-invert-an-array-of-matrices-with-numpy
@@ -1749,7 +1774,7 @@ class JFAMachine(FactorAnalysisBase):
         VProd = self._compute_vprod()
 
         latent_x, latent_y, latent_z = self.initialize_XYZ(
-            n_samples_per_class=n_samples_per_class
+            n_samples_per_class=n_samples_per_class, like=X[0].n
         )
 
         # UPDATE Y, X AND FINALLY Z
@@ -1839,8 +1864,7 @@ class JFAMachine(FactorAnalysisBase):
         acc_U_A1 = [acc[0] for acc in acc_U_A1_acc_U_A2_list]
         acc_U_A2 = [acc[1] for acc in acc_U_A1_acc_U_A2_list]
 
-        acc_U_A1 = reduce_iadd(acc_U_A1)
-        acc_U_A2 = reduce_iadd(acc_U_A2)
+        acc_U_A1, acc_U_A2 = reduce_iadd(acc_U_A1, acc_U_A2)
 
         return self.update_U(acc_U_A1, acc_U_A2)
 
@@ -1973,13 +1997,12 @@ class JFAMachine(FactorAnalysisBase):
         acc_D_A1 = [acc[0] for acc in acc_D_A1_acc_D_A2_list]
         acc_D_A2 = [acc[1] for acc in acc_D_A1_acc_D_A2_list]
 
-        acc_D_A1 = reduce_iadd(acc_D_A1)
-        acc_D_A2 = reduce_iadd(acc_D_A2)
+        acc_D_A1, acc_D_A2 = reduce_iadd(acc_D_A1, acc_D_A2)
 
         self._D = acc_D_A2 / acc_D_A1
         return self._D
 
-    def enroll_using_stats(self, X, iterations=1):
+    def enroll(self, X):
         """
         Enrolls a new client.
         In JFA the enrolment is defined as: :math:`m + Vy + Dz` with the latent variables `y` and `z`
@@ -1990,15 +2013,13 @@ class JFAMachine(FactorAnalysisBase):
         X : list of :py:class:`bob.learn.em.GMMStats`
             List of statistics
 
-        iterations : int
-            Number of iterations to perform
-
         Returns
         -------
         self : array
             z, y latent variables
 
         """
+        iterations = self.enroll_iterations
         # We have only one class for enrollment
         y = list(np.zeros(len(X), dtype=np.int32))
         n_acc = self._sum_n_statistics(X, y=y, n_classes=1)
@@ -2044,27 +2065,7 @@ class JFAMachine(FactorAnalysisBase):
         # The latent variables are wrapped in to 2axis arrays
         return latent_y[0], latent_z[0]
 
-    def enroll(self, X, iterations=1):
-        """
-        Enrolls a new client using a numpy array as input
-
-        Parameters
-        ----------
-        X : array
-            features to be enrolled
-
-        iterations : int
-            Number of iterations to perform
-
-        Returns
-        -------
-        self : object
-            z
-
-        """
-        return self.enroll_using_stats([self.ubm.acc_stats(X)], iterations)
-
-    def fit_using_stats(self, X, y):
+    def fit(self, X, y):
         """
         Trains the U matrix (session variability matrix)
 
@@ -2082,21 +2083,13 @@ class JFAMachine(FactorAnalysisBase):
 
         """
 
-        # In case those variables are already set
-        if (
-            not hasattr(self, "_U")
-            or not hasattr(self, "_V")
-            or not hasattr(self, "_D")
-        ):
-            self.create_UVD()
-
         (
             input_is_dask,
             n_classes,
             n_samples_per_class,
         ) = check_dask_input_samples_per_class(X, y)
 
-        n_acc, f_acc = self.initialize_using_stats(X, y, n_classes=n_classes)
+        n_acc, f_acc = self.initialize(X, y, n_classes=n_classes)
 
         # Updating V
         for i in range(self.em_iterations):
@@ -2191,7 +2184,7 @@ class JFAMachine(FactorAnalysisBase):
 
         return self
 
-    def score_using_stats(self, model, data):
+    def score(self, model, data):
         """
         Computes the JFA score
 
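Net effect of the renames in this file: the short-named methods (``fit``, ``enroll``, ``score``) now consume ``GMMStats``, while the raw-feature entry points carry a ``_using_array`` suffix. A hedged sketch of the resulting API, mirroring the updated guide (data illustrative):

    import numpy as np
    from bob.learn.em import ISVMachine

    X = np.random.normal(size=(20, 3))
    y = np.hstack([np.zeros(10, dtype=int), np.ones(10, dtype=int)])

    isv = ISVMachine(r_U=2, ubm_kwargs=dict(n_gaussians=2))
    isv.fit_using_array(X, y)               # raw features: trains the UBM first
    model = isv.enroll_using_array(X[:2])   # stats are accumulated internally
    score = isv.score_using_array(model, X[2:4])

    # the short-named counterparts take GMMStats directly
    score2 = isv.score(model, isv.ubm.acc_stats(X[2:4]))
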
diff --git a/bob/learn/em/gmm.py b/bob/learn/em/gmm.py
index d9ae87b4111e9bac008a95ee8d8869d003242fc8..6c5e3853f796ac7c2a635b7ad8c42b9011b16e6d 100644
--- a/bob/learn/em/gmm.py
+++ b/bob/learn/em/gmm.py
@@ -112,7 +112,7 @@ def e_step(data, machine):
     n_gaussians = len(machine.weights)
 
     # Allow the absence of previous statistics
-    statistics = GMMStats(n_gaussians, data.shape[-1])
+    statistics = GMMStats(n_gaussians, data.shape[-1], like=data)
 
     # Log weighted Gaussian likelihoods [array of shape (n_gaussians,n_samples)]
     log_weighted_likelihoods = log_weighted_likelihood(
@@ -128,18 +128,24 @@ def e_step(data, machine):
     # Accumulate
 
     # Total likelihood [float]
-    statistics.log_likelihood += log_likelihood.sum()
+    statistics.log_likelihood = log_likelihood.sum()
     # Count of samples [int]
-    statistics.t += data.shape[0]
+    statistics.t = data.shape[0]
     # Responsibilities [array of shape (n_gaussians,)]
-    statistics.n += responsibility.sum(axis=-1)
+    statistics.n = responsibility.sum(axis=-1)
+    sum_px, sum_pxx = [], []
     for i in range(n_gaussians):
         # p * x [array of shape (n_gaussians, n_samples, n_features)]
         px = responsibility[i, :, None] * data
         # First order stats [array of shape (n_gaussians, n_features)]
-        statistics.sum_px[i] += np.sum(px, axis=0)
+        sum_px.append(np.sum(px, axis=0))
         # Second order stats [array of shape (n_gaussians, n_features)]
-        statistics.sum_pxx[i] += np.sum(px * data, axis=0)
+        sum_pxx.append(np.sum(px * data, axis=0))
+
+    statistics.sum_px = np.vstack(sum_px)
+    statistics.sum_pxx = np.vstack(sum_pxx)
 
     return statistics
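
With the ``+=`` accumulations replaced by plain assignments and a final ``np.vstack``, ``e_step`` no longer mutates preallocated buffers, which keeps the computation graph intact for dask inputs. A sketch (assumes a fitted ``GMMMachine`` named ``machine``):

    import dask.array as da

    data = da.random.normal(size=(100, 3), chunks=(50, 3))
    stats = e_step(data=data, machine=machine)
    # stats.n, stats.sum_px and stats.sum_pxx are lazy dask arrays here;
    # nothing is evaluated until dask.compute(...) is called on them
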
 
@@ -183,19 +189,22 @@ class GMMStats:
         Second order statistic
     """
 
-    def __init__(self, n_gaussians: int, n_features: int, **kwargs) -> None:
+    def __init__(
+        self, n_gaussians: int, n_features: int, like=None, **kwargs
+    ) -> None:
         super().__init__(**kwargs)
-
         self.n_gaussians = n_gaussians
         self.n_features = n_features
         self.log_likelihood = 0
         self.t = 0
-        self.n = np.zeros(shape=(self.n_gaussians,), dtype=float)
+        # create dask arrays if required
+        kw = dict(like=like) if like is not None else {}
+        self.n = np.zeros(shape=(self.n_gaussians,), dtype=float, **kw)
         self.sum_px = np.zeros(
-            shape=(self.n_gaussians, self.n_features), dtype=float
+            shape=(self.n_gaussians, self.n_features), dtype=float, **kw
         )
         self.sum_pxx = np.zeros(
-            shape=(self.n_gaussians, self.n_features), dtype=float
+            shape=(self.n_gaussians, self.n_features), dtype=float, **kw
         )
 
     def init_fields(
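
A short sketch of the new ``like=`` constructor argument: passing a dask array makes the statistics buffers dask arrays from the start (the reference array is illustrative):

    import dask.array as da

    ref = da.zeros((10, 3), chunks=(5, 3))
    stats = GMMStats(n_gaussians=2, n_features=3, like=ref)   # dask buffers
    stats_np = GMMStats(n_gaussians=2, n_features=3)          # numpy, as before
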
@@ -357,11 +366,10 @@ class GMMStats:
         """The number of gaussians and their dimensionality."""
         return (self.n_gaussians, self.n_features)
 
-    def compute(self):
-        for name in ("log_likelihood", "t"):
-            setattr(self, name, float(getattr(self, name)))
-        for name in ("n", "sum_px", "sum_pxx"):
-            setattr(self, name, np.asarray(getattr(self, name)))
+    @property
+    def nbytes(self):
+        """The number of bytes used by the statistics n, sum_px, sum_pxx."""
+        return self.n.nbytes + self.sum_px.nbytes + self.sum_pxx.nbytes
 
 
 class GMMMachine(BaseEstimator):
@@ -673,6 +681,9 @@ class GMMMachine(BaseEstimator):
         gaussians_group["variance_thresholds"] = self.variance_thresholds
 
     def __eq__(self, other):
+        if self._means is None or getattr(other, "_means", None) is None:
+            return False
+
         return (
             np.allclose(self.means, other.means)
             and np.allclose(self.variances, other.variances)
@@ -862,17 +873,11 @@ class GMMMachine(BaseEstimator):
 
     def transform(self, X):
         """Returns the statistics for `X`."""
-        return self.acc_stats(X)
+        return self.stats_per_sample(X)
 
     def stats_per_sample(self, X):
         return [e_step(data=xx, machine=self) for xx in X]
 
-    def _more_tags(self):
-        return {
-            "stateless": False,
-            "requires_fit": True,
-        }
-
 
 def ml_gmm_m_step(
     machine: GMMMachine,
diff --git a/bob/learn/em/test/test_factor_analysis.py b/bob/learn/em/test/test_factor_analysis.py
index 596a6a9fe2b75fa52b062e6286d4589e5ccff36d..8c9f8a343a1d1e4c71cc2a46b36fbeabc2fe280e 100644
--- a/bob/learn/em/test/test_factor_analysis.py
+++ b/bob/learn/em/test/test_factor_analysis.py
@@ -121,12 +121,12 @@ def test_JFATrainAndEnrol():
     ubm = GMMMachine(2, 3)
     ubm.means = UBM_MEAN.reshape((2, 3))
     ubm.variances = UBM_VAR.reshape((2, 3))
-    it = JFAMachine(2, 2, em_iterations=10, ubm=ubm)
+    it = JFAMachine(2, 2, em_iterations=10, enroll_iterations=5, ubm=ubm)
 
     it.U = copy.deepcopy(M_u)
     it.V = copy.deepcopy(M_v)
     it.D = copy.deepcopy(M_d)
-    it.fit_using_stats(TRAINING_STATS_X, TRAINING_STATS_y)
+    it.fit(TRAINING_STATS_X, TRAINING_STATS_y)
 
     v_ref = np.array(
         [
@@ -194,7 +194,7 @@ def test_JFATrainAndEnrol():
     gse2.sum_px = Fe[:, 1].reshape(2, 3)
 
     gse = [gse1, gse2]
-    latent_y, latent_z = it.enroll_using_stats(gse, 5)
+    latent_y, latent_z = it.enroll(gse)
 
     y_ref = np.array([0.555991469319657, 0.002773650670010], "float64")
     z_ref = np.array(
@@ -265,10 +265,11 @@ def test_ISVTrainAndEnrol():
         r_U=2,
         relevance_factor=4.0,
         em_iterations=10,
+        enroll_iterations=5,
     )
 
     it.U = copy.deepcopy(M_u)
-    it = it.fit_using_stats(TRAINING_STATS_X, TRAINING_STATS_y)
+    it = it.fit(TRAINING_STATS_X, TRAINING_STATS_y)
 
     np.testing.assert_allclose(it.D, d_ref, rtol=eps, atol=1e-8)
     np.testing.assert_allclose(it.U, u_ref, rtol=eps, atol=1e-8)
@@ -302,7 +303,7 @@ def test_ISVTrainAndEnrol():
     gse2.sum_px = Fe[:, 1].reshape(2, 3)
 
     gse = [gse1, gse2]
-    latent_z = it.enroll_using_stats(gse, 5)
+    latent_z = it.enroll(gse)
     np.testing.assert_allclose(latent_z, z_ref, rtol=eps, atol=1e-8)
 
 
@@ -321,13 +322,13 @@ def test_JFATrainInitialize():
     # first round
 
     n_classes = it.estimate_number_of_classes(TRAINING_STATS_y)
-    it.initialize_using_stats(TRAINING_STATS_X, TRAINING_STATS_y, n_classes)
+    it.initialize(TRAINING_STATS_X, TRAINING_STATS_y, n_classes)
     u1 = it.U
     v1 = it.V
     d1 = it.D
 
     # second round
-    it.initialize_using_stats(TRAINING_STATS_X, TRAINING_STATS_y, n_classes)
+    it.initialize(TRAINING_STATS_X, TRAINING_STATS_y, n_classes)
     u2 = it.U
     v2 = it.V
     d2 = it.D
@@ -352,12 +353,12 @@ def test_ISVTrainInitialize():
     # it.rng = rng
 
     n_classes = it.estimate_number_of_classes(TRAINING_STATS_y)
-    it.initialize_using_stats(TRAINING_STATS_X, TRAINING_STATS_y, n_classes)
+    it.initialize(TRAINING_STATS_X, TRAINING_STATS_y, n_classes)
     u1 = copy.deepcopy(it.U)
     d1 = copy.deepcopy(it.D)
 
     # second round
-    it.initialize_using_stats(TRAINING_STATS_X, TRAINING_STATS_y, n_classes)
+    it.initialize(TRAINING_STATS_X, TRAINING_STATS_y, n_classes)
     u2 = it.U
     d2 = it.D
 
@@ -397,14 +398,14 @@ def test_JFAMachine():
     model = [y, z]
 
     score_ref = -2.111577181208289
-    score = m.score_using_stats(model, gs)
+    score = m.score(model, gs)
     np.testing.assert_allclose(score, score_ref, atol=eps)
 
     # Scoring with numpy array
     np.random.seed(0)
     X = np.random.normal(loc=0.0, scale=1.0, size=(50, 3))
     score_ref = 2.028009315286946
-    score = m.score(model, X)
+    score = m.score_using_array(model, X)
     np.testing.assert_allclose(score, score_ref, atol=eps)
 
 
@@ -436,7 +437,7 @@ def test_ISVMachine():
 
     # Enrolled model
     latent_z = np.array([3, 4, 1, 2, 0, 1], "float64")
-    score = isv_machine.score_using_stats(latent_z, gs)
+    score = isv_machine.score(latent_z, gs)
     score_ref = -3.280498193082100
     np.testing.assert_allclose(score, score_ref, atol=eps)
 
@@ -444,7 +445,7 @@ def test_ISVMachine():
     np.random.seed(0)
     X = np.random.normal(loc=0.0, scale=1.0, size=(50, 3))
     score_ref = -1.2343813195374242
-    score = isv_machine.score(latent_z, X)
+    score = isv_machine.score_using_array(latent_z, X)
     np.testing.assert_allclose(score, score_ref, atol=eps)
 
 
@@ -573,7 +574,10 @@ def test_ISV_JFA_fit():
 
             err_msg = f"Test failed with prior={prior} and machine_type={machine_type} and transform={transform}"
             with multiprocess_dask_client():
-                machine.fit(data, labels)
+                machine.fit_using_array(data, labels)
+            print(
+                f"\nFinished training machine={machine_type} with prior={prior} and transform={transform}"
+            )
 
             arr = getattr(machine, test_attr)
             np.testing.assert_allclose(
diff --git a/bob/learn/em/test/test_gmm.py b/bob/learn/em/test/test_gmm.py
index a8962637b7ff113f6df29abd83f4f82cb98c6c72..81793ef8bbd76a30c75e3074a784c7589781abd9 100644
--- a/bob/learn/em/test/test_gmm.py
+++ b/bob/learn/em/test/test_gmm.py
@@ -732,7 +732,7 @@ def test_ml_transformer():
         )
         np.testing.assert_almost_equal(machine.variances, expected_variances)
 
-        stats = machine.transform(test_data)
+        stats = machine.acc_stats(test_data)
 
         expected_stats = GMMStats(n_gaussians, n_features)
         expected_stats.init_fields(
@@ -780,7 +780,7 @@ def test_map_transformer():
         expected_weights = np.array([0.46226415, 0.53773585])
         np.testing.assert_almost_equal(machine.weights, expected_weights)
 
-        stats = machine.transform(test_data)
+        stats = machine.acc_stats(test_data)
 
         expected_stats = GMMStats(n_gaussians, n_features)
         expected_stats.init_fields(
diff --git a/doc/guide.rst b/doc/guide.rst
index c445a32d44d5ef38d5384d4bd0063caa1584c42c..d9ae724ab7e67cc0b7489bd2a4e96f9db05987bb 100644
--- a/doc/guide.rst
+++ b/doc/guide.rst
@@ -273,7 +273,7 @@ prior GMM.
     >>> # Training a GMM with 2 Gaussians of dimension 3
     >>> prior_gmm = bob.learn.em.GMMMachine(2).fit(data)
     >>> # Creating the container
-    >>> gmm_stats = prior_gmm.transform(data)
+    >>> gmm_stats = prior_gmm.acc_stats(data)
     >>> # Printing the responsibilities
     >>> print(gmm_stats.n/gmm_stats.t)
      [0.6  0.4]
@@ -331,19 +331,21 @@ The snippet bellow shows how to:
    >>> y = np.hstack((np.zeros(10, dtype=int), np.ones(10, dtype=int)))
    >>> # Create an ISV machine with a UBM of 2 gaussians
    >>> isv_machine = bob.learn.em.ISVMachine(r_U=2, ubm_kwargs=dict(n_gaussians=2))
-   >>> _ = isv_machine.fit(X, y)  # DOCTEST: +SKIP_
+   >>> _ = isv_machine.fit_using_array(X, y)  # DOCTEST: +SKIP_
+   >>> # Alternatively, you can create a pipeline of a GMMMachine and an
+   >>> # ISVMachine and call pipeline.fit(X, y) instead of calling
+   >>> # isv_machine.fit_using_array(X, y)
    >>> isv_machine.U
      array(...)
 
    >>> # Enrolling a subject
    >>> enroll_data = np.array([[1.2, 0.1, 1.4], [0.5, 0.2, 0.3]])
-   >>> model = isv_machine.enroll(enroll_data)
+   >>> model = isv_machine.enroll_using_array(enroll_data)
    >>> print(model)
      [[ 0.54   0.246  0.505  1.617 -0.791  0.746]]
 
    >>> # Probing
    >>> probe_data = np.array([[1.2, 0.1, 1.4], [0.5, 0.2, 0.3]])
-   >>> score = isv_machine.score(model, probe_data)
+   >>> score = isv_machine.score_using_array(model, probe_data)
    >>> print(score)
      [2.754]
 
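The pipeline alternative mentioned in the snippet above could look roughly like this; a sketch, assuming the UBM instance is shared with the ISV step so that ``GMMMachine.transform`` (per-sample statistics) feeds ``ISVMachine.fit`` as wired up in this patch:

    from sklearn.pipeline import Pipeline

    import bob.learn.em

    ubm = bob.learn.em.GMMMachine(n_gaussians=2)
    pipeline = Pipeline(
        [
            ("ubm", ubm),
            ("isv", bob.learn.em.ISVMachine(r_U=2, ubm=ubm)),
        ]
    )
    _ = pipeline.fit(X, y)   # fits the GMM, transforms to stats, fits the ISV
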
@@ -394,17 +396,17 @@ such session variability model.
    >>> y = np.hstack((np.zeros(10, dtype=int), np.ones(10, dtype=int)))
    >>> # Create a JFA machine with a UBM of 2 gaussians
    >>> jfa_machine = bob.learn.em.JFAMachine(r_U=2, r_V=2, ubm_kwargs=dict(n_gaussians=2))
-   >>> _ = jfa_machine.fit(X, y)
+   >>> _ = jfa_machine.fit_using_array(X, y)
    >>> jfa_machine.U
      array(...)
 
    >>> enroll_data = np.array([[1.2, 0.1, 1.4], [0.5, 0.2, 0.3]])
-   >>> model = jfa_machine.enroll(enroll_data)
+   >>> model = jfa_machine.enroll_using_array(enroll_data)
    >>> print(model)
      (array([0.634, 0.165]), array([ 0.,  0.,  0.,  0., -0.,  0.]))
 
    >>> probe_data = np.array([[1.2, 0.1, 1.4], [0.5, 0.2, 0.3]])
-   >>> score = jfa_machine.score(model, probe_data)
+   >>> score = jfa_machine.score_using_array(model, probe_data)
    >>> print(score)
      [0.471]
 
diff --git a/doc/plot/plot_ISV.py b/doc/plot/plot_ISV.py
index f3a9c39617ddfbce6970078fc45dae99e66b750a..55bac550003ff5f33255f6bf288f82ea6a3d3190 100644
--- a/doc/plot/plot_ISV.py
+++ b/doc/plot/plot_ISV.py
@@ -46,7 +46,7 @@ isv_machine.U = np.array(
     [[-0.150035, -0.44441, -1.67812, 2.47621, -0.52885, 0.659141]]
 ).T
 
-isv_machine = isv_machine.fit(X, y)
+isv_machine = isv_machine.fit_using_array(X, y)
 
 # Variability direction
 u0 = isv_machine.U[0:2, 0] / np.linalg.norm(isv_machine.U[0:2, 0])
diff --git a/doc/plot/plot_JFA.py b/doc/plot/plot_JFA.py
index a56ab32f584509ea1a50e9852df8a5386b4f1fea..799001a4e8be99de39a22f76775585c888922c44 100644
--- a/doc/plot/plot_JFA.py
+++ b/doc/plot/plot_JFA.py
@@ -56,7 +56,7 @@ jfa_machine.Y = np.array(
 jfa_machine.D = np.array(
     [0.732467, 0.281321, 0.543212, -0.512974, 1.04108, 0.835224]
 )
-jfa_machine = jfa_machine.fit(X, y)
+jfa_machine = jfa_machine.fit_using_array(X, y)
 
 
 # Variability direction U