diff --git a/bob/learn/em/gmm.py b/bob/learn/em/gmm.py
index ac850dd88623271d556ea70764837ed3d080c7e7..ddf985321c3bffbe298218e23512d962a5d2d98a 100644
--- a/bob/learn/em/gmm.py
+++ b/bob/learn/em/gmm.py
@@ -35,23 +35,38 @@ def logaddexp_reduce(array, axis=0, keepdims=False):
     )
 
 
-def e_step(data, weights, means, variances, g_norms, log_weights):
-    # Ensure data is a series of samples (2D array)
-    data = np.atleast_2d(data)
+def log_weighted_likelihood(data, means, variances, g_norms, log_weights):
+    """Returns the weighted log likelihood for each Gaussian for a set of data.
 
-    n_gaussians = len(weights)
-
-    # Allow the absence of previous statistics
-    statistics = GMMStats(n_gaussians, data.shape[-1])
-
-    # Log weighted Gaussian likelihoods [array of shape (n_gaussians,n_samples)]
-    z = np.empty_like(data, shape=(n_gaussians, len(data)))
+    Parameters
+    ----------
+    data
+        Data to compute the log likelihood on.
+    means
+        Means of the Gaussians.
+    variances
+        Variances of the Gaussians.
+    g_norms
+        Normalization factors of the Gaussians.
+    log_weights
+        Log weights of the Gaussians.
+
+    Returns
+    -------
+    array of shape (n_gaussians, n_samples)
+        The weighted log likelihood of each sample under each Gaussian.
+    """
+    # Compute the likelihood for each data point on each Gaussian
+    n_gaussians, n_samples = len(means), len(data)
+    z = np.empty(shape=(n_gaussians, n_samples), like=data)
     for i in range(n_gaussians):
         z[i] = np.sum((data - means[i]) ** 2 / variances[i], axis=-1)
     ll = -0.5 * (g_norms[:, None] + z)
     log_weighted_likelihoods = log_weights[:, None] + ll
+    return log_weighted_likelihoods
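+
+# Worked form of the computation above, for reference: for a sample x and
+# Gaussian i (diagonal covariance), the returned value is
+#   log_weights[i] - 0.5 * (g_norms[i] + sum_d (x[d] - means[i][d]) ** 2 / variances[i][d])
+# with g_norms[i] the precomputed normalization factor of Gaussian i.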
 
-    # Log likelihood [array of shape (n_samples,)]
+
+def reduce_loglikelihood(log_weighted_likelihoods):
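+    """Returns the log likelihood of each sample, reduced over the Gaussian axis.
+
+    The reduction uses logaddexp to prevent underflow and supports both numpy
+    and dask arrays.
+    """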
     if isinstance(log_weighted_likelihoods, np.ndarray):
         log_likelihood = logaddexp_reduce(log_weighted_likelihoods)
     else:
@@ -64,6 +79,54 @@ def e_step(data, weights, means, variances, g_norms, log_weights):
             dtype=float,
             keepdims=False,
         )
+    return log_likelihood
+
+
+def log_likelihood(data, means, variances, g_norms, log_weights):
+    """Returns the current log likelihood for a set of data in this Machine.
+
+    Parameters
+    ----------
+    data
+        Data to compute the log likelihood on.
+
+    Returns
+    -------
+    array of shape (n_samples)
+        The log likelihood of each sample.
+    """
+    data = np.atleast_2d(data)
+
+    # All likelihoods [array of shape (n_gaussians, n_samples)]
+    log_weighted_likelihoods = log_weighted_likelihood(
+        data=data,
+        means=means,
+        variances=variances,
+        g_norms=g_norms,
+        log_weights=log_weights,
+    )
+    # Log likelihood of each sample over all Gaussians. [array of shape (n_samples,)]
+    ll_reduced = reduce_loglikelihood(log_weighted_likelihoods)
+    return ll_reduced
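+
+# Minimal usage sketch, assuming a configured GMMMachine instance named `machine`
+# (this is the stateless form that GMMMachine.log_likelihood delegates to):
+#
+#   ll = log_likelihood(
+#       data, machine.means, machine.variances, machine.g_norms, machine.log_weights
+#   )  # array of shape (n_samples,)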
+
+
+def e_step(data, weights, means, variances, g_norms, log_weights):
+    """Expectation step of the e-m algorithm."""
+    # Ensure data is a series of samples (2D array)
+    data = np.atleast_2d(data)
+
+    n_gaussians = len(weights)
+
+    # Allow the absence of previous statistics
+    statistics = GMMStats(n_gaussians, data.shape[-1])
+
+    # Log weighted Gaussian likelihoods [array of shape (n_gaussians,n_samples)]
+    log_weighted_likelihoods = log_weighted_likelihood(
+        data, means, variances, g_norms, log_weights
+    )
+
+    # Log likelihood [array of shape (n_samples,)]
+    log_likelihood = reduce_loglikelihood(log_weighted_likelihoods)
 
     # Responsibility P [array of shape (n_gaussians, n_samples)]
     responsibility = np.exp(log_weighted_likelihoods - log_likelihood[None, :])
@@ -86,11 +149,6 @@ def e_step(data, weights, means, variances, g_norms, log_weights):
             px * data, axis=0
         )
 
-    # px = np.multiply(responsibility[:, :, None], data[None, :, :])
-    # statistics.sum_px = statistics.sum_px + px.sum(axis=1)
-    # pxx = np.multiply(px[:, :, :], data[None, :, :])
-    # statistics.sum_pxx = statistics.sum_pxx + pxx.sum(axis=1)
-
     return statistics
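+
+# Sketch of the parallel training pattern (the `chunks` name is illustrative):
+# each worker runs e_step on its part of the data with the same GMM parameters,
+# and m_step sums the resulting GMMStats before updating the machine.
+#
+#   stats = [
+#       e_step(chunk, machine.weights, machine.means, machine.variances,
+#              machine.g_norms, machine.log_weights)
+#       for chunk in chunks
+#   ]
+#   m_step(machine, stats, machine.update_means, machine.update_variances,
+#          machine.update_weights, machine.mean_var_update_threshold,
+#          machine.map_relevance_factor, machine.map_alpha, machine.trainer)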
 
 
@@ -105,10 +163,11 @@ def m_step(
     map_alpha,
     trainer,
 ):
+    """Maximization step of the e-m algorithm."""
     m_step_func = map_gmm_m_step if trainer == "map" else ml_gmm_m_step
     statistics = functools.reduce(operator.iadd, statistics)
     m_step_func(
-        machine,
+        machine=machine,
         statistics=statistics,
         update_means=update_means,
         update_variances=update_variances,
@@ -701,14 +760,13 @@ class GMMMachine(BaseEstimator):
         array of shape (n_gaussians, n_samples)
             The weighted log likelihood of each sample of each Gaussian.
         """
-        # Compute the likelihood for each data point on each Gaussian
-        z = (
-            (data[None, ..., :] - self.means[..., None, :]) ** 2
-            / self.variances[..., None, :]
-        ).sum(axis=-1)
-        ll = -0.5 * (self.g_norms[:, None] + z)
-        log_weighted_likelihood = self.log_weights[:, None] + ll
-        return log_weighted_likelihood
+        return log_weighted_likelihood(
+            data=data,
+            means=self.means,
+            variances=self.variances,
+            g_norms=self.g_norms,
+            log_weights=self.log_weights,
+        )
 
     def log_likelihood(
         self,
@@ -726,112 +784,12 @@ class GMMMachine(BaseEstimator):
         array of shape (n_samples)
             The log likelihood of each sample.
         """
-        if data.ndim == 1:
-            data = data.reshape((1, -1))
-
-        # All likelihoods [array of shape (n_gaussians, n_samples)]
-        log_weighted_likelihood = self.log_weighted_likelihood(data)
-
-        def logaddexp_reduce(array, axis=0, keepdims=False):
-            return np.logaddexp.reduce(
-                array, axis=axis, keepdims=keepdims, initial=-np.inf
-            )
-
-        if isinstance(log_weighted_likelihood, np.ndarray):
-            ll_reduced = logaddexp_reduce(log_weighted_likelihood)
-        else:
-            # Sum along gaussians axis (using logAddExp to prevent underflow)
-            ll_reduced = da.reduction(
-                x=log_weighted_likelihood,
-                chunk=logaddexp_reduce,
-                aggregate=logaddexp_reduce,
-                axis=0,
-                dtype=float,
-                keepdims=False,
-            )
-        return ll_reduced
-
-        # Likelihoods of each sample on this machine. [array of shape (n_samples,)]
-
-    def acc_statistics(
-        self,
-        data: "np.ndarray[('n_samples', 'n_features'), float]",  # noqa: F821
-        statistics: Union[GMMStats, None] = None,
-    ):
-        """Accumulates the statistics of GMMStats for a set of data.
-
-        This can be used to compute a GMM step in parallel: each worker/thread applies
-        the e-step of a copy of the same GMM on part of the training data, and the
-        resulting `GMMStats` object of each worker is summed before applying the m-step.
-
-        Parameters
-        ----------
-        data:
-            The data to extract the statistics on the GMM.
-        statistics:
-            A GMMStats object that will accumulate the previous and current stats.
-            Values are modified in-place AND returned. (or only returned if
-            `statistics` is None)
-        """
-        # Ensure data is a series of samples (2D array)
-        if data.ndim == 1:
-            data = data.reshape(shape=(1, -1))
-
-        # Allow the absence of previous statistics
-        if statistics is None:
-            statistics = GMMStats(self.n_gaussians, data.shape[-1])
-
-        # Log weighted Gaussian likelihoods [array of shape (n_gaussians,n_samples)]
-        log_weighted_likelihoods = self.log_weighted_likelihood(data)
-        # Log likelihood [array of shape (n_samples,)]
-        log_likelihood = self.log_likelihood(data)
-        # Responsibility P [array of shape (n_gaussians, n_samples)]
-        responsibility = np.exp(
-            log_weighted_likelihoods - log_likelihood[None, :]
-        )
-
-        # Accumulate
-
-        # Total likelihood [float]
-        statistics.log_likelihood += log_likelihood.sum()
-        # Count of samples [int]
-        statistics.t += data.shape[0]
-        # Responsibilities [array of shape (n_gaussians,)]
-        statistics.n = statistics.n + responsibility.sum(axis=-1)
-        # p * x [array of shape (n_gaussians, n_samples, n_features)]
-        px = np.multiply(responsibility[:, :, None], data[None, :, :])
-        # First order stats [array of shape (n_gaussians, n_features)]
-        statistics.sum_px = statistics.sum_px + px.sum(axis=1)
-        # Second order stats [array of shape (n_gaussians, n_features)]
-        pxx = np.multiply(px[:, :, :], data[None, :, :])
-        statistics.sum_pxx = statistics.sum_pxx + pxx.sum(axis=1)
-
-        return statistics
-
-    def e_step(
-        self,
-        data: "np.ndarray[('n_samples', 'n_features'), float]",  # noqa: F821
-    ):  # noqa: F821
-        """Expectation step of the e-m algorithm."""
-        return self.acc_statistics(data)
-
-    def m_step(
-        self,
-        stats: GMMStats,
-        **kwargs,
-    ):
-        """Maximization step of the e-m algorithm."""
-        self.m_step_func(
-            self,
-            statistics=stats,
-            update_means=self.update_means,
-            update_variances=self.update_variances,
-            update_weights=self.update_weights,
-            mean_var_update_threshold=self.mean_var_update_threshold,
-            reynolds_adaptation=self.map_relevance_factor is not None,
-            alpha=self.map_alpha,
-            relevance_factor=self.map_relevance_factor,
-            **kwargs,
+        return log_likelihood(
+            data=data,
+            means=self.means,
+            variances=self.variances,
+            g_norms=self.g_norms,
+            log_weights=self.log_weights,
         )
 
     def fit(self, X, y=None):
@@ -934,21 +892,18 @@ class GMMMachine(BaseEstimator):
             logger.info(
                 "Reached maximum step. Training stopped without convergence."
             )
-        self.compute()
-        return self
-
-    def fit_partial(self, X, y=None, **kwargs):
-        """Applies one iteration of GMM training."""
-        if self._means is None:
-            self.initialize_gaussians(X)
-
-        stats = self.e_step(X)
-        self.m_step(stats=stats)
         return self
 
     def transform(self, X, **kwargs):
         """Returns the statistics for `X`."""
-        return self.e_step(X)
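+        # Delegate to the stateless module-level e_step to compute the GMMStats of X.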
+        return e_step(
+            X,
+            self.weights,
+            self.means,
+            self.variances,
+            self.g_norms,
+            self.log_weights,
+        )
 
     def _more_tags(self):
         return {
@@ -956,10 +911,6 @@ class GMMMachine(BaseEstimator):
             "requires_fit": True,
         }
 
-    def compute(self, *args, **kwargs):
-        for name in ("weights", "means", "variances"):
-            setattr(self, name, np.asarray(getattr(self, name)))
-
 
 def ml_gmm_m_step(
     machine: GMMMachine,
diff --git a/bob/learn/em/test/test_gmm.py b/bob/learn/em/test/test_gmm.py
index b62da1abce52ea7cb96e1aa26f55ccee35a2d373..65743ce3b95ff9bfa6a5712800a023942316c22f 100644
--- a/bob/learn/em/test/test_gmm.py
+++ b/bob/learn/em/test/test_gmm.py
@@ -22,6 +22,7 @@ from h5py import File as HDF5File
 from pkg_resources import resource_filename
 
 from bob.learn.em import GMMMachine, GMMStats, KMeansMachine
+from bob.learn.em import gmm as gmm_module
 
 from .test_kmeans import to_dask_array, to_numpy
 
@@ -283,7 +284,14 @@ def test_GMMMachine_stats():
     gmm.variances = np.array([[1, 10], [2, 5]], "float64")
     gmm.variance_thresholds = np.array([[0, 0], [0, 0]], "float64")
 
-    stats = gmm.acc_statistics(arrayset)
+    stats = gmm_module.e_step(
+        arrayset,
+        gmm.weights,
+        gmm.means,
+        gmm.variances,
+        gmm.g_norms,
+        gmm.log_weights,
+    )
 
     stats_ref = GMMStats(n_gaussians=2, n_features=2)
     stats_ref.load(
@@ -355,7 +363,14 @@ def test_GMMStats_operations():
     machine.variances = np.ones_like(machine.means)
 
     # Populate the GMMStats
-    stats = machine.acc_statistics(data)
+    stats = gmm_module.e_step(
+        data,
+        machine.weights,
+        machine.means,
+        machine.variances,
+        machine.g_norms,
+        machine.log_weights,
+    )
 
     # Check shapes
     assert stats.n.shape == (n_gaussians,), stats.n.shape
@@ -601,8 +616,25 @@ def test_ml_em():
     machine.means = np.repeat([[2], [8]], n_features, 1)
     machine.variances = np.ones_like(machine.means)
 
-    stats = machine.e_step(data)
-    machine.m_step(stats)
+    stats = gmm_module.e_step(
+        data,
+        machine.weights,
+        machine.means,
+        machine.variances,
+        machine.g_norms,
+        machine.log_weights,
+    )
+    gmm_module.m_step(
+        machine,
+        [stats],
+        machine.update_means,
+        machine.update_variances,
+        machine.update_weights,
+        machine.mean_var_update_threshold,
+        machine.map_relevance_factor,
+        machine.map_alpha,
+        machine.trainer,
+    )
 
     expected_means = np.array([[1.5, 1.5, 2.0], [7.0, 8.0, 8.0]])
     np.testing.assert_almost_equal(machine.means, expected_means)
@@ -640,8 +672,25 @@ def test_map_em():
     np.testing.assert_equal(machine.variances, prior_machine.variances)
     np.testing.assert_equal(machine.weights, prior_machine.weights)
 
-    stats = machine.e_step(post_data)
-    machine.m_step(stats)
+    stats = gmm_module.e_step(
+        post_data,
+        machine.weights,
+        machine.means,
+        machine.variances,
+        machine.g_norms,
+        machine.log_weights,
+    )
+    gmm_module.m_step(
+        machine,
+        [stats],
+        machine.update_means,
+        machine.update_variances,
+        machine.update_weights,
+        machine.mean_var_update_threshold,
+        machine.map_relevance_factor,
+        machine.map_alpha,
+        machine.trainer,
+    )
 
     expected_means = np.array(
         [[1.83333333, 1.83333333, 2.0], [7.57142857, 8, 8]]