diff --git a/bob/learn/em/gmm.py b/bob/learn/em/gmm.py
index ddf985321c3bffbe298218e23512d962a5d2d98a..762939981b58a90379ceac1d7b0e4073ba8d3f5d 100644
--- a/bob/learn/em/gmm.py
+++ b/bob/learn/em/gmm.py
@@ -35,21 +35,15 @@ def logaddexp_reduce(array, axis=0, keepdims=False):
     )
 
 
-def log_weighted_likelihood(data, means, variances, g_norms, log_weights):
+def log_weighted_likelihood(data, machine):
     """Returns the weighted log likelihood for each Gaussian for a set of data.
 
     Parameters
     ----------
     data
         Data to compute the log likelihood on.
-    means
-        Means of the Gaussians.
-    variances
-        Variances of the Gaussians.
-    g_norms
-        Normalization factors of the Gaussians.
-    log_weights
-        Log weights of the Gaussians.
+    machine
+        The GMM machine.
 
     Returns
     -------
@@ -57,12 +51,14 @@ def log_weighted_likelihood(data, means, variances, g_norms, log_weights):
         The weighted log likelihood of each sample of each Gaussian.
     """
     # Compute the likelihood for each data point on each Gaussian
-    n_gaussians, n_samples = len(means), len(data)
+    n_gaussians, n_samples = len(machine.means), len(data)
     z = np.empty(shape=(n_gaussians, n_samples), like=data)
     for i in range(n_gaussians):
-        z[i] = np.sum((data - means[i]) ** 2 / variances[i], axis=-1)
-    ll = -0.5 * (g_norms[:, None] + z)
-    log_weighted_likelihoods = log_weights[:, None] + ll
+        z[i] = np.sum(
+            (data - machine.means[i]) ** 2 / machine.variances[i], axis=-1
+        )
+    ll = -0.5 * (machine.g_norms[:, None] + z)
+    log_weighted_likelihoods = machine.log_weights[:, None] + ll
     return log_weighted_likelihoods
 
 
@@ -82,13 +78,15 @@ def reduce_loglikelihood(log_weighted_likelihoods):
     return log_likelihood
 
 
-def log_likelihood(data, means, variances, g_norms, log_weights):
+def log_likelihood(data, machine):
     """Returns the current log likelihood for a set of data in this Machine.
 
     Parameters
     ----------
     data
         Data to compute the log likelihood on.
+    machine
+        The GMM machine.
 
     Returns
     -------
@@ -100,29 +98,26 @@ def log_likelihood(data, means, variances, g_norms, log_weights):
     # All likelihoods [array of shape (n_gaussians, n_samples)]
     log_weighted_likelihoods = log_weighted_likelihood(
         data=data,
-        means=means,
-        variances=variances,
-        g_norms=g_norms,
-        log_weights=log_weights,
+        machine=machine,
     )
     # Likelihoods of each sample on this machine. [array of shape (n_samples,)]
     ll_reduced = reduce_loglikelihood(log_weighted_likelihoods)
     return ll_reduced
 
 
-def e_step(data, weights, means, variances, g_norms, log_weights):
+def e_step(data, machine):
     """Expectation step of the e-m algorithm."""
     # Ensure data is a series of samples (2D array)
     data = np.atleast_2d(data)
 
-    n_gaussians = len(weights)
+    n_gaussians = len(machine.weights)
 
     # Allow the absence of previous statistics
     statistics = GMMStats(n_gaussians, data.shape[-1])
 
     # Log weighted Gaussian likelihoods [array of shape (n_gaussians,n_samples)]
     log_weighted_likelihoods = log_weighted_likelihood(
-        data, means, variances, g_norms, log_weights
+        data=data, machine=machine
     )
 
     # Log likelihood [array of shape (n_samples,)]
@@ -153,29 +148,22 @@ def e_step(data, weights, means, variances, g_norms, log_weights):
 
 
 def m_step(
-    machine,
     statistics,
-    update_means,
-    update_variances,
-    update_weights,
-    mean_var_update_threshold,
-    map_relevance_factor,
-    map_alpha,
-    trainer,
+    machine,
 ):
     """Maximization step of the e-m algorithm."""
-    m_step_func = map_gmm_m_step if trainer == "map" else ml_gmm_m_step
+    m_step_func = map_gmm_m_step if machine.trainer == "map" else ml_gmm_m_step
     statistics = functools.reduce(operator.iadd, statistics)
     m_step_func(
         machine=machine,
         statistics=statistics,
-        update_means=update_means,
-        update_variances=update_variances,
-        update_weights=update_weights,
-        mean_var_update_threshold=mean_var_update_threshold,
-        reynolds_adaptation=map_relevance_factor is not None,
-        alpha=map_alpha,
-        relevance_factor=map_relevance_factor,
+        update_means=machine.update_means,
+        update_variances=machine.update_variances,
+        update_weights=machine.update_weights,
+        mean_var_update_threshold=machine.mean_var_update_threshold,
+        reynolds_adaptation=machine.map_relevance_factor is not None,
+        alpha=machine.map_alpha,
+        relevance_factor=machine.map_relevance_factor,
     )
     average_output = float(statistics.log_likelihood / statistics.t)
     return machine, average_output
@@ -762,10 +750,7 @@ class GMMMachine(BaseEstimator):
         """
         return log_weighted_likelihood(
             data=data,
-            means=self.means,
-            variances=self.variances,
-            g_norms=self.g_norms,
-            log_weights=self.log_weights,
+            machine=self,
         )
 
     def log_likelihood(
@@ -786,10 +771,7 @@ class GMMMachine(BaseEstimator):
         """
         return log_likelihood(
             data=data,
-            means=self.means,
-            variances=self.variances,
-            g_norms=self.g_norms,
-            log_weights=self.log_weights,
+            machine=self,
         )
 
     def fit(self, X, y=None):
@@ -808,17 +790,6 @@ class GMMMachine(BaseEstimator):
             )
             self.variances = np.ones_like(self.means)
 
-        m_step_func = functools.partial(
-            m_step,
-            update_means=self.update_means,
-            update_variances=self.update_variances,
-            update_weights=self.update_weights,
-            mean_var_update_threshold=self.mean_var_update_threshold,
-            map_relevance_factor=self.map_relevance_factor,
-            map_alpha=self.map_alpha,
-            trainer=self.trainer,
-        )
-
         X = array_to_delayed_list(X, input_is_dask)
 
         average_output = 0
@@ -842,16 +813,12 @@ class GMMMachine(BaseEstimator):
                 stats = [
                     dask.delayed(e_step)(
                         data=xx,
-                        weights=self.weights,
-                        means=self.means,
-                        variances=self.variances,
-                        g_norms=self.g_norms,
-                        log_weights=self.log_weights,
+                        machine=self,
                     )
                     for xx in X
                 ]
                 new_machine, average_output = dask.compute(
-                    dask.delayed(m_step_func)(self, stats)
+                    dask.delayed(m_step)(stats, self)
                 )[0]
                 for attr in ["weights", "means", "variances"]:
                     setattr(self, attr, getattr(new_machine, attr))
@@ -859,14 +826,10 @@ class GMMMachine(BaseEstimator):
                 stats = [
                     e_step(
                         data=X,
-                        weights=self.weights,
-                        means=self.means,
-                        variances=self.variances,
-                        g_norms=self.g_norms,
-                        log_weights=self.log_weights,
+                        machine=self,
                     )
                 ]
-                _, average_output = m_step_func(self, stats)
+                _, average_output = m_step(stats, self)
 
             logger.debug(f"log likelihood = {average_output}")
             if step > 1:
@@ -897,12 +860,8 @@ class GMMMachine(BaseEstimator):
     def transform(self, X, **kwargs):
         """Returns the statistics for `X`."""
         return e_step(
-            X,
-            self.weights,
-            self.means,
-            self.variances,
-            self.g_norms,
-            self.log_weights,
+            data=X,
+            machine=self,
         )
 
     def _more_tags(self):
diff --git a/bob/learn/em/test/test_gmm.py b/bob/learn/em/test/test_gmm.py
index 65743ce3b95ff9bfa6a5712800a023942316c22f..6e4b5d0aea309140a081fbf91a2b92b9bc732935 100644
--- a/bob/learn/em/test/test_gmm.py
+++ b/bob/learn/em/test/test_gmm.py
@@ -286,11 +286,7 @@ def test_GMMMachine_stats():
 
     stats = gmm_module.e_step(
         arrayset,
-        gmm.weights,
-        gmm.means,
-        gmm.variances,
-        gmm.g_norms,
-        gmm.log_weights,
+        gmm,
     )
 
     stats_ref = GMMStats(n_gaussians=2, n_features=2)
@@ -365,11 +361,7 @@ def test_GMMStats_operations():
     # Populate the GMMStats
     stats = gmm_module.e_step(
         data,
-        machine.weights,
-        machine.means,
-        machine.variances,
-        machine.g_norms,
-        machine.log_weights,
+        machine,
     )
 
     # Check shapes
@@ -618,22 +610,11 @@ def test_ml_em():
 
     stats = gmm_module.e_step(
         data,
-        machine.weights,
-        machine.means,
-        machine.variances,
-        machine.g_norms,
-        machine.log_weights,
+        machine,
     )
     gmm_module.m_step(
-        machine,
         [stats],
-        machine.update_means,
-        machine.update_variances,
-        machine.update_weights,
-        machine.mean_var_update_threshold,
-        machine.map_relevance_factor,
-        machine.map_alpha,
-        machine.trainer,
+        machine,
     )
 
     expected_means = np.array([[1.5, 1.5, 2.0], [7.0, 8.0, 8.0]])
@@ -674,22 +655,11 @@ def test_map_em():
 
     stats = gmm_module.e_step(
         post_data,
-        machine.weights,
-        machine.means,
-        machine.variances,
-        machine.g_norms,
-        machine.log_weights,
+        machine,
     )
     gmm_module.m_step(
-        machine,
         [stats],
-        machine.update_means,
-        machine.update_variances,
-        machine.update_weights,
-        machine.mean_var_update_threshold,
-        machine.map_relevance_factor,
-        machine.map_alpha,
-        machine.trainer,
+        machine,
     )
 
     expected_means = np.array(