diff --git a/bob/learn/em/gmm.py b/bob/learn/em/gmm.py
index ddf985321c3bffbe298218e23512d962a5d2d98a..38908cc7feee4ab3fa9d6f28e853534cd0c59357 100644
--- a/bob/learn/em/gmm.py
+++ b/bob/learn/em/gmm.py
@@ -35,21 +35,15 @@ def logaddexp_reduce(array, axis=0, keepdims=False):
     )
 
 
-def log_weighted_likelihood(data, means, variances, g_norms, log_weights):
+def log_weighted_likelihood(data, machine):
     """Returns the weighted log likelihood for each Gaussian for a set of data.
 
     Parameters
     ----------
     data
         Data to compute the log likelihood on.
-    means
-        Means of the Gaussians.
-    variances
-        Variances of the Gaussians.
-    g_norms
-        Normalization factors of the Gaussians.
-    log_weights
-        Log weights of the Gaussians.
+    machine
+        The GMM machine.
 
     Returns
     -------
@@ -57,12 +51,14 @@ def log_weighted_likelihood(data, means, variances, g_norms, log_weights):
         The weighted log likelihood of each sample of each Gaussian.
     """
     # Compute the likelihood for each data point on each Gaussian
-    n_gaussians, n_samples = len(means), len(data)
+    n_gaussians, n_samples = len(machine.means), len(data)
     z = np.empty(shape=(n_gaussians, n_samples), like=data)
     for i in range(n_gaussians):
-        z[i] = np.sum((data - means[i]) ** 2 / variances[i], axis=-1)
-    ll = -0.5 * (g_norms[:, None] + z)
-    log_weighted_likelihoods = log_weights[:, None] + ll
+        z[i] = np.sum(
+            (data - machine.means[i]) ** 2 / machine.variances[i], axis=-1
+        )
+    ll = -0.5 * (machine.g_norms[:, None] + z)
+    log_weighted_likelihoods = machine.log_weights[:, None] + ll
     return log_weighted_likelihoods
 
 
@@ -82,13 +78,15 @@ def reduce_loglikelihood(log_weighted_likelihoods):
     return log_likelihood
 
 
-def log_likelihood(data, means, variances, g_norms, log_weights):
+def log_likelihood(data, machine):
     """Returns the current log likelihood for a set of data in this Machine.
 
     Parameters
     ----------
     data
         Data to compute the log likelihood on.
+    machine
+        The GMM machine.
 
     Returns
     -------
@@ -100,29 +98,26 @@ def log_likelihood(data, means, variances, g_norms, log_weights):
     # All likelihoods [array of shape (n_gaussians, n_samples)]
     log_weighted_likelihoods = log_weighted_likelihood(
         data=data,
-        means=means,
-        variances=variances,
-        g_norms=g_norms,
-        log_weights=log_weights,
+        machine=machine,
     )
     # Likelihoods of each sample on this machine. [array of shape (n_samples,)]
     ll_reduced = reduce_loglikelihood(log_weighted_likelihoods)
     return ll_reduced
 
 
-def e_step(data, weights, means, variances, g_norms, log_weights):
+def e_step(data, machine):
     """Expectation step of the e-m algorithm."""
     # Ensure data is a series of samples (2D array)
     data = np.atleast_2d(data)
 
-    n_gaussians = len(weights)
+    n_gaussians = len(machine.weights)
     # Allow the absence of previous statistics
     statistics = GMMStats(n_gaussians, data.shape[-1])
 
     # Log weighted Gaussian likelihoods [array of shape (n_gaussians,n_samples)]
     log_weighted_likelihoods = log_weighted_likelihood(
-        data, means, variances, g_norms, log_weights
+        data=data, machine=machine
    )
 
     # Log likelihood [array of shape (n_samples,)]
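With the machine passed in directly, callers no longer thread five parameter arrays through every call. For reference, the quantity computed in the hunks above is the diagonal-covariance Gaussian log-density plus the log mixture weight. A minimal stand-alone NumPy sketch of the same computation — not the library code, and assuming `g_norms[i]` equals the precomputed normalizer `D*log(2*pi) + sum(log(variances[i]))`, which is what the `-0.5 * (g_norms + z)` expression implies:

```python
import numpy as np


def log_weighted_likelihood_sketch(data, weights, means, variances):
    """Stand-alone illustration of the computation above (not the library code)."""
    # Per-Gaussian normalizer; assumed to match the machine's precomputed g_norms.
    g_norms = means.shape[1] * np.log(2 * np.pi) + np.log(variances).sum(axis=-1)
    # Squared Mahalanobis distances, shape (n_gaussians, n_samples).
    z = np.sum(
        (data[None, :, :] - means[:, None, :]) ** 2 / variances[:, None, :],
        axis=-1,
    )
    # log w_i + log N(x | mu_i, diag(sigma_i^2)) for every sample and Gaussian.
    return np.log(weights)[:, None] - 0.5 * (g_norms[:, None] + z)
```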
@@ -153,29 +148,22 @@ def e_step(data, weights, means, variances, g_norms, log_weights):
 
 
 def m_step(
-    machine,
     statistics,
-    update_means,
-    update_variances,
-    update_weights,
-    mean_var_update_threshold,
-    map_relevance_factor,
-    map_alpha,
-    trainer,
+    machine,
 ):
     """Maximization step of the e-m algorithm."""
-    m_step_func = map_gmm_m_step if trainer == "map" else ml_gmm_m_step
+    m_step_func = map_gmm_m_step if machine.trainer == "map" else ml_gmm_m_step
     statistics = functools.reduce(operator.iadd, statistics)
     m_step_func(
         machine=machine,
         statistics=statistics,
-        update_means=update_means,
-        update_variances=update_variances,
-        update_weights=update_weights,
-        mean_var_update_threshold=mean_var_update_threshold,
-        reynolds_adaptation=map_relevance_factor is not None,
-        alpha=map_alpha,
-        relevance_factor=map_relevance_factor,
+        update_means=machine.update_means,
+        update_variances=machine.update_variances,
+        update_weights=machine.update_weights,
+        mean_var_update_threshold=machine.mean_var_update_threshold,
+        reynolds_adaptation=machine.map_relevance_factor is not None,
+        alpha=machine.map_alpha,
+        relevance_factor=machine.map_relevance_factor,
     )
     average_output = float(statistics.log_likelihood / statistics.t)
     return machine, average_output
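Since the e- and m-steps now read every hyperparameter off the machine, one EM iteration can be driven by hand with just data and a machine; note the new `m_step` argument order, statistics first. A sketch, assuming the module-level functions are importable from `bob.learn.em.gmm` alongside `GMMMachine`:

```python
import numpy as np
from bob.learn.em.gmm import GMMMachine, e_step, m_step

X = np.array([[1.0, 2.0, 2.0], [2.0, 1.0, 2.0], [7.0, 8.0, 9.0]])  # toy data

machine = GMMMachine(n_gaussians=2)
machine.means = np.array([[2.0, 2.0, 2.0], [8.0, 8.0, 8.0]])
machine.variances = np.ones_like(machine.means)

stats = e_step(data=X, machine=machine)     # accumulate sufficient statistics
machine, avg_ll = m_step([stats], machine)  # statistics first, then the machine
```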
@@ -684,12 +672,10 @@ class GMMMachine(BaseEstimator):
 
     def __eq__(self, other):
         return (
-            np.array_equal(self.means, other.means)
-            and np.array_equal(self.variances, other.variances)
-            and np.array_equal(
-                self.variance_thresholds, other.variance_thresholds
-            )
-            and np.array_equal(self.weights, other.weights)
+            np.allclose(self.means, other.means)
+            and np.allclose(self.variances, other.variances)
+            and np.allclose(self.variance_thresholds, other.variance_thresholds)
+            and np.allclose(self.weights, other.weights)
         )
 
     def is_similar_to(self, other, rtol=1e-5, atol=1e-8):
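Note the semantic change in `__eq__`: `np.array_equal` demands exact element-wise equality, while `np.allclose` tolerates floating-point noise (`|a - b| <= atol + rtol * |b|`, defaults `rtol=1e-5`, `atol=1e-8`) and broadcasts its operands. Machines produced through slightly different execution paths can therefore still compare equal:

```python
import numpy as np

a = np.array([1.0, 2.0])
b = a + 1e-9  # round-off noise, e.g. from a different BLAS or chunked reduction

print(np.array_equal(a, b))  # False: exact comparison
print(np.allclose(a, b))     # True: 1e-9 is within atol + rtol * |b|
```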
@@ -762,10 +748,7 @@ class GMMMachine(BaseEstimator):
         """
         return log_weighted_likelihood(
             data=data,
-            means=self.means,
-            variances=self.variances,
-            g_norms=self.g_norms,
-            log_weights=self.log_weights,
+            machine=self,
         )
 
     def log_likelihood(
@@ -786,10 +769,7 @@ class GMMMachine(BaseEstimator):
         """
         return log_likelihood(
             data=data,
-            means=self.means,
-            variances=self.variances,
-            g_norms=self.g_norms,
-            log_weights=self.log_weights,
+            machine=self,
         )
 
     def fit(self, X, y=None):
@@ -808,17 +788,6 @@ class GMMMachine(BaseEstimator):
             )
             self.variances = np.ones_like(self.means)
 
-        m_step_func = functools.partial(
-            m_step,
-            update_means=self.update_means,
-            update_variances=self.update_variances,
-            update_weights=self.update_weights,
-            mean_var_update_threshold=self.mean_var_update_threshold,
-            map_relevance_factor=self.map_relevance_factor,
-            map_alpha=self.map_alpha,
-            trainer=self.trainer,
-        )
-
         X = array_to_delayed_list(X, input_is_dask)
         average_output = 0
@@ -842,16 +811,12 @@ class GMMMachine(BaseEstimator):
                 stats = [
                     dask.delayed(e_step)(
                         data=xx,
-                        weights=self.weights,
-                        means=self.means,
-                        variances=self.variances,
-                        g_norms=self.g_norms,
-                        log_weights=self.log_weights,
+                        machine=self,
                     )
                     for xx in X
                 ]
                 new_machine, average_output = dask.compute(
-                    dask.delayed(m_step_func)(self, stats)
+                    dask.delayed(m_step)(stats, self)
                 )[0]
                 for attr in ["weights", "means", "variances"]:
                     setattr(self, attr, getattr(new_machine, attr))
@@ -859,14 +824,10 @@ class GMMMachine(BaseEstimator):
                 stats = [
                     e_step(
                         data=X,
-                        weights=self.weights,
-                        means=self.means,
-                        variances=self.variances,
-                        g_norms=self.g_norms,
-                        log_weights=self.log_weights,
+                        machine=self,
                     )
                 ]
-                _, average_output = m_step_func(self, stats)
+                _, average_output = m_step(stats, self)
 
             logger.debug(f"log likelihood = {average_output}")
             if step > 1:
@@ -897,12 +858,8 @@ class GMMMachine(BaseEstimator):
     def transform(self, X, **kwargs):
         """Returns the statistics for `X`."""
         return e_step(
-            X,
-            self.weights,
-            self.means,
-            self.variances,
-            self.g_norms,
-            self.log_weights,
+            data=X,
+            machine=self,
         )
 
     def _more_tags(self):
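With the parameters owned by the machine, `fit` no longer rebuilds a `functools.partial` on each call, and one code path serves eager and lazy input: dask chunks get one delayed `e_step` each, followed by a single `m_step`. A usage sketch — the top-level `GMMMachine` export is assumed from the package name used throughout the tests:

```python
import dask.array as da
import numpy as np
from bob.learn.em import GMMMachine  # assumed public export

rng = np.random.default_rng(0)
data = rng.normal(size=(100, 3))

machine = GMMMachine(n_gaussians=2)
machine.fit(data)  # eager path: one e_step over the whole array

machine = GMMMachine(n_gaussians=2)
machine.fit(da.from_array(data, chunks=(50, 3)))  # lazy path: delayed e_step per chunk
```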
diff --git a/bob/learn/em/test/test_gmm.py b/bob/learn/em/test/test_gmm.py
index 65743ce3b95ff9bfa6a5712800a023942316c22f..a8962637b7ff113f6df29abd83f4f82cb98c6c72 100644
--- a/bob/learn/em/test/test_gmm.py
+++ b/bob/learn/em/test/test_gmm.py
@@ -14,7 +14,6 @@ import tempfile
 
 from copy import deepcopy
 
-import dask.array as da
 import numpy as np
 
 from dask.distributed import Client
@@ -44,6 +43,33 @@ def multiprocess_dask_client():
     client.close()
 
 
+def loadGMM():
+    gmm = GMMMachine(n_gaussians=2)
+
+    gmm.weights = load_array(
+        resource_filename("bob.learn.em", "data/gmm.init_weights.hdf5")
+    )
+    gmm.means = load_array(
+        resource_filename("bob.learn.em", "data/gmm.init_means.hdf5")
+    )
+    gmm.variances = load_array(
+        resource_filename("bob.learn.em", "data/gmm.init_variances.hdf5")
+    )
+
+    return gmm
+
+
+def assert_gmm_equal(gmm1, gmm2):
+    """Asserts that two GMMs are equal"""
+    np.testing.assert_almost_equal(gmm1.weights, gmm2.weights)
+    np.testing.assert_almost_equal(gmm1.means, gmm2.means)
+    np.testing.assert_almost_equal(gmm1.variances, gmm2.variances)
+    np.testing.assert_almost_equal(
+        gmm1.variance_thresholds, gmm2.variance_thresholds
+    )
+    assert gmm1 == gmm2
+
+
 def test_GMMStats():
     # Test a GMMStats
     # Initializes a GMMStats
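Moving `loadGMM` up and introducing `assert_gmm_equal` replaces opaque `gmm == gmm2` assertions with `np.testing` checks that name the offending parameter and print the mismatching values on failure, while the trailing `assert gmm1 == gmm2` keeps `__eq__` itself under test. Usage is direct:

```python
gmm_a = loadGMM()
gmm_b = loadGMM()
assert_gmm_equal(gmm_a, gmm_b)  # same parameters loaded from disk, so this passes
```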
@@ -197,20 +223,16 @@ def test_GMMMachine():
     gmm6.variances = variances
     gmm6.variance_thresholds = varianceThresholds2
 
-    assert gmm == gmm2
+    assert_gmm_equal(gmm, gmm2)
     assert (gmm != gmm2) is False
     assert gmm.is_similar_to(gmm2)
     assert gmm != gmm3
-    assert (gmm == gmm3) is False
     assert gmm.is_similar_to(gmm3) is False
     assert gmm != gmm4
-    assert (gmm == gmm4) is False
     assert gmm.is_similar_to(gmm4) is False
     assert gmm != gmm5
-    assert (gmm == gmm5) is False
     assert gmm.is_similar_to(gmm5) is False
     assert gmm != gmm6
-    assert (gmm == gmm6) is False
     assert gmm.is_similar_to(gmm6) is False
 
     # Saving and loading
@@ -225,7 +247,7 @@ def test_GMMMachine():
     assert type(gmm1.update_weights) is np.bool_
     assert type(gmm1.trainer) is str
     assert gmm1.ubm is None
-    assert gmm == gmm1
+    assert_gmm_equal(gmm, gmm1)
     # Using load
     gmm1 = GMMMachine(n_gaussians=gmm.n_gaussians)
     gmm1.load(HDF5File(filename, "r"))
@@ -235,13 +257,13 @@ def test_GMMMachine():
     assert type(gmm1.update_weights) is np.bool_
     assert type(gmm1.trainer) is str
     assert gmm1.ubm is None
-    assert gmm == gmm1
+    assert_gmm_equal(gmm, gmm1)
 
     with tempfile.NamedTemporaryFile(suffix=".hdf5") as f:
         filename = f.name
         gmm.save(filename)
         gmm1 = GMMMachine.from_hdf5(filename)
-        assert gmm == gmm1
+        assert_gmm_equal(gmm, gmm1)
 
     # Weights
     n_gaussians = 5
@@ -286,11 +308,7 @@ def test_GMMMachine_stats():
 
     stats = gmm_module.e_step(
         arrayset,
-        gmm.weights,
-        gmm.means,
-        gmm.variances,
-        gmm.g_norms,
-        gmm.log_weights,
+        gmm,
     )
 
     stats_ref = GMMStats(n_gaussians=2, n_features=2)
@@ -365,11 +383,7 @@ def test_GMMStats_operations():
     # Populate the GMMStats
     stats = gmm_module.e_step(
         data,
-        machine.weights,
-        machine.means,
-        machine.variances,
-        machine.g_norms,
-        machine.log_weights,
+        machine,
     )
 
     # Check shapes
@@ -473,13 +487,15 @@ def test_gmm_kmeans_plusplus_init():
     data = np.array(
         [[1.5, 1], [1, 1.5], [-1, 0.5], [-1.5, 0], [2, 2], [2.5, 2.5]]
     )
-    machine = machine.fit(data)
-    expected_means = np.array([[2.25, 2.25], [-1.25, 0.25], [1.25, 1.25]])
-    expected_variances = np.array(
-        [[1 / 16, 1 / 16], [1 / 16, 1 / 16], [1 / 16, 1 / 16]]
-    )
-    np.testing.assert_almost_equal(machine.means, expected_means, decimal=3)
-    np.testing.assert_almost_equal(machine.variances, expected_variances)
+    for transform in (to_numpy, to_dask_array):
+        data = transform(data)
+        machine = machine.fit(data)
+        expected_means = np.array([[2.25, 2.25], [-1.25, 0.25], [1.25, 1.25]])
+        expected_variances = np.array(
+            [[1 / 16, 1 / 16], [1 / 16, 1 / 16], [1 / 16, 1 / 16]]
+        )
+        np.testing.assert_almost_equal(machine.means, expected_means, decimal=3)
+        np.testing.assert_almost_equal(machine.variances, expected_variances)
 
 
 def test_gmm_kmeans_parallel_init():
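This loop is the pattern applied through the rest of the file: the dedicated `*_dask` test duplicates removed further down are folded into the original tests, so every assertion now runs once on NumPy input and once on dask input. `to_dask_array` is the k-means test helper patched at the end of this diff; its `to_numpy` counterpart is not shown in any hunk, but plausibly mirrors it, for example:

```python
import numpy as np


def to_numpy(*args):
    """Hypothetical sketch: convert each argument to an eager NumPy array."""
    result = [np.asarray(x) for x in args]
    return result[0] if len(result) == 1 else result
```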
@@ -517,16 +533,18 @@ def test_likelihood():
     machine = GMMMachine(n_gaussians)
     machine.means = np.repeat([[0], [1], [-1]], 3, 1)
     machine.variances = np.ones_like(machine.means)
-    log_likelihood = machine.log_likelihood(data)
-    expected_ll = np.array(
-        [
-            -3.6519900964986527,
-            -3.83151883210222,
-            -3.83151883210222,
-            -5.344374066745753,
-        ]
-    )
-    np.testing.assert_almost_equal(log_likelihood, expected_ll)
+    for transform in (to_numpy, to_dask_array):
+        data = transform(data)
+        log_likelihood = machine.log_likelihood(data)
+        expected_ll = np.array(
+            [
+                -3.6519900964986527,
+                -3.83151883210222,
+                -3.83151883210222,
+                -5.344374066745753,
+            ]
+        )
+        np.testing.assert_almost_equal(log_likelihood, expected_ll)
 
 
 def test_likelihood_variance():
@@ -541,16 +559,18 @@ def test_likelihood_variance():
             [1, 1, 1],
         ]
     )
-    log_likelihood = machine.log_likelihood(data)
-    expected_ll = np.array(
-        [
-            -2.202846959440514,
-            -3.8699524542323793,
-            -4.229029034375473,
-            -6.940892214952679,
-        ]
-    )
-    np.testing.assert_almost_equal(log_likelihood, expected_ll)
+    for transform in (to_numpy, to_dask_array):
+        data = transform(data)
+        log_likelihood = machine.log_likelihood(data)
+        expected_ll = np.array(
+            [
+                -2.202846959440514,
+                -3.8699524542323793,
+                -4.229029034375473,
+                -6.940892214952679,
+            ]
+        )
+        np.testing.assert_almost_equal(log_likelihood, expected_ll)
 
 
 def test_likelihood_weight():
@@ -560,16 +580,18 @@ def test_likelihood_weight():
     machine.means = np.repeat([[0], [1], [-1]], 3, 1)
     machine.variances = np.ones_like(machine.means)
     machine.weights = [0.6, 0.1, 0.3]
-    log_likelihood = machine.log_likelihood(data)
-    expected_ll = np.array(
-        [
-            -4.206596356117164,
-            -3.492325679996329,
-            -3.634745457950943,
-            -6.49485678536014,
-        ]
-    )
-    np.testing.assert_almost_equal(log_likelihood, expected_ll)
+    for transform in (to_numpy, to_dask_array):
+        data = transform(data)
+        log_likelihood = machine.log_likelihood(data)
+        expected_ll = np.array(
+            [
+                -4.206596356117164,
+                -3.492325679996329,
+                -3.634745457950943,
+                -6.49485678536014,
+            ]
+        )
+        np.testing.assert_almost_equal(log_likelihood, expected_ll)
 
 
 def test_GMMMachine_object():
@@ -618,22 +640,11 @@ def test_ml_em():
 
     stats = gmm_module.e_step(
         data,
-        machine.weights,
-        machine.means,
-        machine.variances,
-        machine.g_norms,
-        machine.log_weights,
+        machine,
     )
     gmm_module.m_step(
-        machine,
         [stats],
-        machine.update_means,
-        machine.update_variances,
-        machine.update_weights,
-        machine.mean_var_update_threshold,
-        machine.map_relevance_factor,
-        machine.map_alpha,
-        machine.trainer,
+        machine,
     )
 
     expected_means = np.array([[1.5, 1.5, 2.0], [7.0, 8.0, 8.0]])
@@ -674,22 +685,11 @@ def test_map_em():
 
     stats = gmm_module.e_step(
         post_data,
-        machine.weights,
-        machine.means,
-        machine.variances,
-        machine.g_norms,
-        machine.log_weights,
+        machine,
     )
     gmm_module.m_step(
-        machine,
         [stats],
-        machine.update_means,
-        machine.update_variances,
-        machine.update_weights,
-        machine.mean_var_update_threshold,
-        machine.map_relevance_factor,
-        machine.map_alpha,
-        machine.trainer,
+        machine,
     )
 
     expected_means = np.array(
@@ -718,27 +718,31 @@ def test_ml_transformer():
     machine.means = np.array([[2, 2, 2], [8, 8, 8]])
     machine.variances = np.ones_like(machine.means)
 
-    machine = machine.fit(data)
-
-    expected_means = np.array([[1.5, 1.5, 2.0], [7.0, 8.0, 8.0]])
-    np.testing.assert_almost_equal(machine.means, expected_means)
-    expected_weights = np.array([2 / 5, 3 / 5])
-    np.testing.assert_almost_equal(machine.weights, expected_weights)
-    eps = np.finfo(float).eps
-    expected_variances = np.array([[1 / 4, 1 / 4, eps], [eps, 2 / 3, 2 / 3]])
-    np.testing.assert_almost_equal(machine.variances, expected_variances)
+    for transform in (to_numpy, to_dask_array):
+        data = transform(data)
+        machine = machine.fit(data)
+
+        expected_means = np.array([[1.5, 1.5, 2.0], [7.0, 8.0, 8.0]])
+        np.testing.assert_almost_equal(machine.means, expected_means)
+        expected_weights = np.array([2 / 5, 3 / 5])
+        np.testing.assert_almost_equal(machine.weights, expected_weights)
+        eps = np.finfo(float).eps
+        expected_variances = np.array(
+            [[1 / 4, 1 / 4, eps], [eps, 2 / 3, 2 / 3]]
+        )
+        np.testing.assert_almost_equal(machine.variances, expected_variances)
 
-    stats = machine.transform(test_data)
+        stats = machine.transform(test_data)
 
-    expected_stats = GMMStats(n_gaussians, n_features)
-    expected_stats.init_fields(
-        log_likelihood=-6755399441055685.0,
-        t=test_data.shape[0],
-        n=np.array([2, 2], dtype=float),
-        sum_px=np.array([[2, 2, 3], [16, 17, 17]], dtype=float),
-        sum_pxx=np.array([[2, 2, 5], [128, 145, 145]], dtype=float),
-    )
-    assert stats.is_similar_to(expected_stats)
+        expected_stats = GMMStats(n_gaussians, n_features)
+        expected_stats.init_fields(
+            log_likelihood=-6755399441055685.0,
+            t=test_data.shape[0],
+            n=np.array([2, 2], dtype=float),
+            sum_px=np.array([[2, 2, 3], [16, 17, 17]], dtype=float),
+            sum_pxx=np.array([[2, 2, 5], [128, 145, 145]], dtype=float),
+        )
+        assert stats.is_similar_to(expected_stats)
 
 
 def test_map_transformer():
@@ -762,79 +766,69 @@ def test_map_transformer():
         update_weights=True,
     )
 
-    machine = machine.fit(post_data)
+    for transform in (to_numpy, to_dask_array):
+        post_data = transform(post_data)
+        machine = machine.fit(post_data)
 
-    expected_means = np.array(
-        [[1.83333333, 1.83333333, 2.0], [7.57142857, 8, 8]]
-    )
-    np.testing.assert_almost_equal(machine.means, expected_means)
-    eps = np.finfo(float).eps
-    expected_vars = np.array([[eps, eps, eps], [eps, eps, eps]])
-    np.testing.assert_almost_equal(machine.variances, expected_vars)
-    expected_weights = np.array([0.46226415, 0.53773585])
-    np.testing.assert_almost_equal(machine.weights, expected_weights)
-
-    stats = machine.transform(test_data)
-
-    expected_stats = GMMStats(n_gaussians, n_features)
-    expected_stats.init_fields(
-        log_likelihood=-1.3837590691807108e16,
-        t=test_data.shape[0],
-        n=np.array([2, 2], dtype=float),
-        sum_px=np.array([[2, 2, 3], [16, 17, 17]], dtype=float),
-        sum_pxx=np.array([[2, 2, 5], [128, 145, 145]], dtype=float),
-    )
-    assert stats.is_similar_to(expected_stats)
+        expected_means = np.array(
+            [[1.83333333, 1.83333333, 2.0], [7.57142857, 8, 8]]
+        )
+        np.testing.assert_almost_equal(machine.means, expected_means)
+        eps = np.finfo(float).eps
+        expected_vars = np.array([[eps, eps, eps], [eps, eps, eps]])
+        np.testing.assert_almost_equal(machine.variances, expected_vars)
+        expected_weights = np.array([0.46226415, 0.53773585])
+        np.testing.assert_almost_equal(machine.weights, expected_weights)
+
+        stats = machine.transform(test_data)
+
+        expected_stats = GMMStats(n_gaussians, n_features)
+        expected_stats.init_fields(
+            log_likelihood=-1.3837590691807108e16,
+            t=test_data.shape[0],
+            n=np.array([2, 2], dtype=float),
+            sum_px=np.array([[2, 2, 3], [16, 17, 17]], dtype=float),
+            sum_pxx=np.array([[2, 2, 5], [128, 145, 145]], dtype=float),
+        )
+        assert stats.is_similar_to(expected_stats)
 
 
 # Tests from `test_em.py`
 
 
-def loadGMM():
-    gmm = GMMMachine(n_gaussians=2)
-
-    gmm.weights = load_array(
-        resource_filename("bob.learn.em", "data/gmm.init_weights.hdf5")
-    )
-    gmm.means = load_array(
-        resource_filename("bob.learn.em", "data/gmm.init_means.hdf5")
-    )
-    gmm.variances = load_array(
-        resource_filename("bob.learn.em", "data/gmm.init_variances.hdf5")
-    )
-
-    return gmm
-
-
 def test_gmm_ML_1():
     """Trains a GMMMachine with ML_GMMTrainer"""
     ar = load_array(
         resource_filename("bob.learn.em", "data/faithful.torch3_f64.hdf5")
     )
 
-    gmm = loadGMM()
+    gmm_ref = GMMMachine.from_hdf5(
+        HDF5File(resource_filename("bob.learn.em", "data/gmm_ML.hdf5"), "r")
+    )
 
-    # test rng handling
-    gmm.convergence_threshold = 0.001
-    gmm.update_means = True
-    gmm.update_variances = True
-    gmm.update_weights = True
-    gmm.random_state = np.random.RandomState(seed=12345)
-    gmm = gmm.fit(ar)
+    for transform in (to_numpy, to_dask_array):
+        ar = transform(ar)
 
-    gmm = loadGMM()
-    gmm.convergence_threshold = 0.001
-    gmm.update_means = True
-    gmm.update_variances = True
-    gmm.update_weights = True
-    gmm = gmm.fit(ar)
+        gmm = loadGMM()
 
-    # Generate reference
-    # gmm.save(HDF5File(resource_filename("bob.learn.em", "data/gmm_ML.hdf5"), "w"))
+        # test rng handling
+        gmm.convergence_threshold = 0.001
+        gmm.update_means = True
+        gmm.update_variances = True
+        gmm.update_weights = True
+        gmm.random_state = np.random.RandomState(seed=12345)
+        gmm = gmm.fit(ar)
 
-    gmm_ref = GMMMachine.from_hdf5(
-        HDF5File(resource_filename("bob.learn.em", "data/gmm_ML.hdf5"), "r")
-    )
-    assert gmm == gmm_ref
+        gmm = loadGMM()
+        gmm.convergence_threshold = 0.001
+        gmm.update_means = True
+        gmm.update_variances = True
+        gmm.update_weights = True
+        # Generate reference
+        # gmm.save(HDF5File(resource_filename("bob.learn.em", "data/gmm_ML.hdf5"), "w"))
+
+        gmm = gmm.fit(ar)
+
+        assert_gmm_equal(gmm, gmm_ref)
 
 
 def test_gmm_ML_2():
@@ -843,33 +837,6 @@ def test_gmm_ML_2():
         resource_filename("bob.learn.em", "data/dataNormalized.hdf5")
     )
 
-    # Initialize GMMMachine
-    gmm = GMMMachine(n_gaussians=5)
-    gmm.means = load_array(
-        resource_filename("bob.learn.em", "data/meansAfterKMeans.hdf5")
-    ).astype("float64")
-    gmm.variances = load_array(
-        resource_filename("bob.learn.em", "data/variancesAfterKMeans.hdf5")
-    ).astype("float64")
-    gmm.weights = np.exp(
-        load_array(
-            resource_filename("bob.learn.em", "data/weightsAfterKMeans.hdf5")
-        ).astype("float64")
-    )
-
-    threshold = 0.001
-    gmm.variance_thresholds = threshold
-
-    # Initialize ML Trainer
-    gmm.mean_var_update_threshold = 0.001
-    gmm.max_fitting_steps = 25
-    gmm.convergence_threshold = 0.000001
-    gmm.update_means = True
-    gmm.update_variances = True
-    gmm.update_weights = True
-
-    # Run ML
-    gmm = gmm.fit(ar)
 
     # Test results
     # Load torch3vision reference
     meansML_ref = load_array(
@@ -882,10 +849,42 @@ def test_gmm_ML_2():
         resource_filename("bob.learn.em", "data/meansAfterML.hdf5")
     )
     variancesML_ref = load_array(
         resource_filename("bob.learn.em", "data/variancesAfterML.hdf5")
     )
     weightsML_ref = load_array(
         resource_filename("bob.learn.em", "data/weightsAfterML.hdf5")
     )
 
-    # Compare to current results
-    np.testing.assert_allclose(gmm.means, meansML_ref, atol=3e-3)
-    np.testing.assert_allclose(gmm.variances, variancesML_ref, atol=3e-3)
-    np.testing.assert_allclose(gmm.weights, weightsML_ref, atol=1e-4)
+    for transform in (to_numpy, to_dask_array):
+        ar = transform(ar)
+        # Initialize GMMMachine
+        gmm = GMMMachine(n_gaussians=5)
+        gmm.means = load_array(
+            resource_filename("bob.learn.em", "data/meansAfterKMeans.hdf5")
+        ).astype("float64")
+        gmm.variances = load_array(
+            resource_filename("bob.learn.em", "data/variancesAfterKMeans.hdf5")
+        ).astype("float64")
+        gmm.weights = np.exp(
+            load_array(
+                resource_filename(
+                    "bob.learn.em", "data/weightsAfterKMeans.hdf5"
+                )
+            ).astype("float64")
+        )
+
+        threshold = 0.001
+        gmm.variance_thresholds = threshold
+
+        # Initialize ML Trainer
+        gmm.mean_var_update_threshold = 0.001
+        gmm.max_fitting_steps = 25
+        gmm.convergence_threshold = 0.000001
+        gmm.update_means = True
+        gmm.update_variances = True
+        gmm.update_weights = True
+
+        # Run ML
+        gmm = gmm.fit(ar)
+
+        # Compare to current results
+        np.testing.assert_allclose(gmm.means, meansML_ref, atol=3e-3)
+        np.testing.assert_allclose(gmm.variances, variancesML_ref, atol=3e-3)
+        np.testing.assert_allclose(gmm.weights, weightsML_ref, atol=1e-4)
 
 
 def test_gmm_MAP_1():
@@ -920,8 +919,6 @@ def test_gmm_MAP_1():
     gmm.update_variances = False
     gmm.update_weights = False
 
-    gmm = gmm.fit(ar)
-
     # Generate reference
     # gmm.save(HDF5File(resource_filename("bob.learn.em", "data/gmm_MAP.hdf5"), "w"))
 
@@ -929,9 +926,15 @@ def test_gmm_MAP_1():
         HDF5File(resource_filename("bob.learn.em", "data/gmm_MAP.hdf5"), "r")
     )
 
-    np.testing.assert_almost_equal(gmm.means, gmm_ref.means, decimal=3)
-    np.testing.assert_almost_equal(gmm.variances, gmm_ref.variances, decimal=3)
-    np.testing.assert_almost_equal(gmm.weights, gmm_ref.weights, decimal=3)
+    for transform in (to_numpy, to_dask_array):
+        ar = transform(ar)
+        gmm = gmm.fit(ar)
+
+        np.testing.assert_almost_equal(gmm.means, gmm_ref.means, decimal=3)
+        np.testing.assert_almost_equal(
+            gmm.variances, gmm_ref.variances, decimal=3
+        )
+        np.testing.assert_almost_equal(gmm.weights, gmm_ref.weights, decimal=3)
 
 
 def test_gmm_MAP_2():
@@ -964,14 +967,16 @@ def test_gmm_MAP_2():
     gmm_adapted.variances = variances
     gmm_adapted.weights = weights
 
-    gmm_adapted = gmm_adapted.fit(data)
-
     new_means = load_array(
         resource_filename("bob.learn.em", "data/new_adapted_mean.hdf5")
     )
 
-    # Compare to matlab reference
-    np.testing.assert_allclose(new_means.T, gmm_adapted.means, rtol=1e-4)
+    for transform in (to_numpy, to_dask_array):
+        data = transform(data)
+        gmm_adapted = gmm_adapted.fit(data)
+
+        # Compare to matlab reference
+        np.testing.assert_allclose(new_means.T, gmm_adapted.means, rtol=1e-4)
 
 
 def test_gmm_MAP_3():
@@ -1011,9 +1016,6 @@ def test_gmm_MAP_3():
     )
     gmm.variance_thresholds = threshold
 
-    # Train
-    gmm = gmm.fit(ar)
-
     # Test results
     # Load torch3vision reference
     meansMAP_ref = load_array(
@@ -1026,13 +1028,18 @@ def test_gmm_MAP_3():
         resource_filename("bob.learn.em", "data/weightsAfterMAP.hdf5")
     )
 
-    # Compare to current results
-    # Gaps are quite large. This might be explained by the fact that there is no
-    # adaptation of a given Gaussian in torch3 when the corresponding responsibilities
-    # are below the responsibilities threshold
-    np.testing.assert_allclose(gmm.means, meansMAP_ref, atol=2e-1)
-    np.testing.assert_allclose(gmm.variances, variancesMAP_ref, atol=1e-4)
-    np.testing.assert_allclose(gmm.weights, weightsMAP_ref, atol=1e-4)
+    for transform in (to_numpy, to_dask_array):
+        ar = transform(ar)
+        # Train
+        gmm = gmm.fit(ar)
+
+        # Compare to current results
+        # Gaps are quite large. This might be explained by the fact that there is no
+        # adaptation of a given Gaussian in torch3 when the corresponding responsibilities
+        # are below the responsibilities threshold
+        np.testing.assert_allclose(gmm.means, meansMAP_ref, atol=2e-1)
+        np.testing.assert_allclose(gmm.variances, variancesMAP_ref, atol=1e-4)
+        np.testing.assert_allclose(gmm.weights, weightsMAP_ref, atol=1e-4)
 
 
 def test_gmm_test():
@@ -1058,126 +1065,10 @@ def test_gmm_test():
 
     # Test against the model
     score_mean_ref = -1.50379e06
 
-    score = gmm.log_likelihood(ar).sum()
-    score /= len(ar)
-
-    # Compare current results to torch3vision
-    assert abs(score - score_mean_ref) / score_mean_ref < 1e-4
-
-
-def test_gmm_ML_dask():
-    # Trains a GMMMachine with dask array data; compares to a reference
-
-    ar = da.array(
-        load_array(
-            resource_filename("bob.learn.em", "data/dataNormalized.hdf5")
-        )
-    )
-
-    # Initialize GMMMachine
-    gmm = GMMMachine(n_gaussians=5)
-    gmm.means = load_array(
-        resource_filename("bob.learn.em", "data/meansAfterKMeans.hdf5")
-    ).astype("float64")
-    gmm.variances = load_array(
-        resource_filename("bob.learn.em", "data/variancesAfterKMeans.hdf5")
-    ).astype("float64")
-    gmm.weights = np.exp(
-        load_array(
-            resource_filename("bob.learn.em", "data/weightsAfterKMeans.hdf5")
-        ).astype("float64")
-    )
-
-    threshold = 0.001
-    gmm.variance_thresholds = threshold
-
-    # Initialize ML Trainer
-    gmm.mean_var_update_threshold = 0.001
-    gmm.max_fitting_steps = 25
-    gmm.convergence_threshold = 0.00001
-    gmm.update_means = True
-    gmm.update_variances = True
-    gmm.update_weights = True
-
-    # Run ML
-    gmm.fit(ar)
-
-    # Test results
-    # Load torch3vision reference
-    meansML_ref = load_array(
-        resource_filename("bob.learn.em", "data/meansAfterML.hdf5")
-    )
-    variancesML_ref = load_array(
-        resource_filename("bob.learn.em", "data/variancesAfterML.hdf5")
-    )
-    weightsML_ref = load_array(
-        resource_filename("bob.learn.em", "data/weightsAfterML.hdf5")
-    )
-
-    # Compare to current results
-    np.testing.assert_allclose(gmm.means, meansML_ref, atol=3e-3)
-    np.testing.assert_allclose(gmm.variances, variancesML_ref, atol=3e-3)
-    np.testing.assert_allclose(gmm.weights, weightsML_ref, atol=1e-4)
-
-
-def test_gmm_MAP_dask():
-    # Test a GMMMachine for MAP with a dask array as data.
-    ar = da.array(
-        load_array(resource_filename("bob.learn.em", "data/dataforMAP.hdf5"))
-    )
-
-    # Initialize GMMMachine
-    n_gaussians = 5
-    prior_gmm = GMMMachine(n_gaussians)
-    prior_gmm.means = load_array(
-        resource_filename("bob.learn.em", "data/meansAfterML.hdf5")
-    )
-    prior_gmm.variances = load_array(
-        resource_filename("bob.learn.em", "data/variancesAfterML.hdf5")
-    )
-    prior_gmm.weights = load_array(
-        resource_filename("bob.learn.em", "data/weightsAfterML.hdf5")
-    )
-
-    threshold = 0.001
-    prior_gmm.variance_thresholds = threshold
-
-    # Initialize MAP Trainer
-    prior = 0.001
-    accuracy = 0.00001
-    gmm = GMMMachine(
-        n_gaussians,
-        trainer="map",
-        ubm=prior_gmm,
-        convergence_threshold=prior,
-        max_fitting_steps=1,
-        update_means=True,
-        update_variances=False,
-        update_weights=False,
-        mean_var_update_threshold=accuracy,
-        map_relevance_factor=None,
-    )
-    gmm.variance_thresholds = threshold
-
-    # Train
-    gmm = gmm.fit(ar)
-
-    # Test results
-    # Load torch3vision reference
-    meansMAP_ref = load_array(
-        resource_filename("bob.learn.em", "data/meansAfterMAP.hdf5")
-    )
-    variancesMAP_ref = load_array(
-        resource_filename("bob.learn.em", "data/variancesAfterMAP.hdf5")
-    )
-    weightsMAP_ref = load_array(
-        resource_filename("bob.learn.em", "data/weightsAfterMAP.hdf5")
-    )
+    for transform in (to_numpy, to_dask_array):
+        ar = transform(ar)
+        score = gmm.log_likelihood(ar).sum()
+        score /= len(ar)
 
-    # Compare to current results
-    # Gaps are quite large. This might be explained by the fact that there is no
-    # adaptation of a given Gaussian in torch3 when the corresponding responsibilities
-    # are below the responsibilities threshold
-    np.testing.assert_allclose(gmm.means, meansMAP_ref, atol=2e-1)
-    np.testing.assert_allclose(gmm.variances, variancesMAP_ref, atol=1e-4)
-    np.testing.assert_allclose(gmm.weights, weightsMAP_ref, atol=1e-4)
+        # Compare current results to torch3vision
+        assert abs(score - score_mean_ref) / score_mean_ref < 1e-4
diff --git a/bob/learn/em/test/test_kmeans.py b/bob/learn/em/test/test_kmeans.py
index 8edc205ab9ecf2f6f15a8c9e4213f57cb219f804..613a562bf864b4d66b31ae5a410fc6ad6298e164 100644
--- a/bob/learn/em/test/test_kmeans.py
+++ b/bob/learn/em/test/test_kmeans.py
@@ -33,7 +33,7 @@ def to_dask_array(*args):
     for x in args:
         x = np.asarray(x)
         chunks = list(x.shape)
-        chunks[0] //= 2
+        chunks[0] = int(np.ceil(chunks[0] / 2))
         result.append(da.from_array(x, chunks=chunks))
     if len(result) == 1:
         return result[0]
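The `to_dask_array` fix is subtler than it looks: with floor division, a single-row input yields a chunk size of 0, which cannot tile a non-empty array, and odd lengths split into three chunks rather than the intended two. Ceiling division gives exactly two chunks for any input with more than one row:

```python
import numpy as np

for n in (1, 5, 6):
    floor_chunk = n // 2              # old behaviour
    ceil_chunk = int(np.ceil(n / 2))  # new behaviour
    print(n, floor_chunk, ceil_chunk)

# n=1: floor -> 0 (invalid chunk size), ceil -> 1 (one chunk)
# n=5: floor -> blocks (2, 2, 1), three chunks; ceil -> (3, 2), two chunks
# n=6: both -> (3, 3)
```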