From fe763de3f5f3d035a3c3f58a39be0332221b5e5c Mon Sep 17 00:00:00 2001
From: Amir MOHAMMADI <amir.mohammadi@idiap.ch>
Date: Thu, 24 Mar 2022 19:03:59 +0100
Subject: [PATCH] [gmm] add more dask tests, fixes #38

---
 bob/learn/em/gmm.py              |  12 +-
 bob/learn/em/test/test_gmm.py    | 511 +++++++++++++------------------
 bob/learn/em/test/test_kmeans.py |   2 +-
 3 files changed, 223 insertions(+), 302 deletions(-)

diff --git a/bob/learn/em/gmm.py b/bob/learn/em/gmm.py
index 7629399..38908cc 100644
--- a/bob/learn/em/gmm.py
+++ b/bob/learn/em/gmm.py
@@ -672,12 +672,12 @@ class GMMMachine(BaseEstimator):
 
     def __eq__(self, other):
         return (
-            np.array_equal(self.means, other.means)
-            and np.array_equal(self.variances, other.variances)
-            and np.array_equal(
-                self.variance_thresholds, other.variance_thresholds
-            )
-            and np.array_equal(self.weights, other.weights)
+            # allclose broadcasts; require equal shapes first, as array_equal did
+            np.shape(self.means) == np.shape(other.means)
+            and np.allclose(self.means, other.means)
+            and np.allclose(self.variances, other.variances)
+            and np.allclose(self.variance_thresholds, other.variance_thresholds)
+            and np.allclose(self.weights, other.weights)
         )
 
     def is_similar_to(self, other, rtol=1e-5, atol=1e-8):
diff --git a/bob/learn/em/test/test_gmm.py b/bob/learn/em/test/test_gmm.py
index 6e4b5d0..a896263 100644
--- a/bob/learn/em/test/test_gmm.py
+++ b/bob/learn/em/test/test_gmm.py
@@ -14,7 +14,6 @@ import tempfile
 
 from copy import deepcopy
 
-import dask.array as da
 import numpy as np
 
 from dask.distributed import Client
@@ -44,6 +43,33 @@ def multiprocess_dask_client():
         client.close()
 
 
+def loadGMM():
+    gmm = GMMMachine(n_gaussians=2)
+    # Initial parameters are read from HDF5 files packaged in bob.learn.em
+    gmm.weights = load_array(
+        resource_filename("bob.learn.em", "data/gmm.init_weights.hdf5")
+    )
+    gmm.means = load_array(
+        resource_filename("bob.learn.em", "data/gmm.init_means.hdf5")
+    )
+    gmm.variances = load_array(
+        resource_filename("bob.learn.em", "data/gmm.init_variances.hdf5")
+    )
+
+    return gmm
+
+
+def assert_gmm_equal(gmm1, gmm2):
+    """Assert both GMMs have almost-equal parameters and compare equal."""
+    np.testing.assert_almost_equal(gmm1.weights, gmm2.weights)
+    np.testing.assert_almost_equal(gmm1.means, gmm2.means)
+    np.testing.assert_almost_equal(gmm1.variances, gmm2.variances)
+    np.testing.assert_almost_equal(
+        gmm1.variance_thresholds, gmm2.variance_thresholds
+    )
+    assert gmm1 == gmm2  # also exercises GMMMachine.__eq__
+
+
 def test_GMMStats():
     # Test a GMMStats
     # Initializes a GMMStats
@@ -197,20 +223,16 @@ def test_GMMMachine():
     gmm6.variances = variances
     gmm6.variance_thresholds = varianceThresholds2
 
-    assert gmm == gmm2
+    assert_gmm_equal(gmm, gmm2)
     assert (gmm != gmm2) is False
     assert gmm.is_similar_to(gmm2)
     assert gmm != gmm3
-    assert (gmm == gmm3) is False
     assert gmm.is_similar_to(gmm3) is False
     assert gmm != gmm4
-    assert (gmm == gmm4) is False
     assert gmm.is_similar_to(gmm4) is False
     assert gmm != gmm5
-    assert (gmm == gmm5) is False
     assert gmm.is_similar_to(gmm5) is False
     assert gmm != gmm6
-    assert (gmm == gmm6) is False
     assert gmm.is_similar_to(gmm6) is False
 
     # Saving and loading
@@ -225,7 +247,7 @@ def test_GMMMachine():
         assert type(gmm1.update_weights) is np.bool_
         assert type(gmm1.trainer) is str
         assert gmm1.ubm is None
-        assert gmm == gmm1
+        assert_gmm_equal(gmm, gmm1)
         # Using load
         gmm1 = GMMMachine(n_gaussians=gmm.n_gaussians)
         gmm1.load(HDF5File(filename, "r"))
@@ -235,13 +257,13 @@ def test_GMMMachine():
         assert type(gmm1.update_weights) is np.bool_
         assert type(gmm1.trainer) is str
         assert gmm1.ubm is None
-        assert gmm == gmm1
+        assert_gmm_equal(gmm, gmm1)
 
     with tempfile.NamedTemporaryFile(suffix=".hdf5") as f:
         filename = f.name
         gmm.save(filename)
         gmm1 = GMMMachine.from_hdf5(filename)
-        assert gmm == gmm1
+        assert_gmm_equal(gmm, gmm1)
 
     # Weights
     n_gaussians = 5
@@ -465,13 +487,15 @@ def test_gmm_kmeans_plusplus_init():
     data = np.array(
         [[1.5, 1], [1, 1.5], [-1, 0.5], [-1.5, 0], [2, 2], [2.5, 2.5]]
     )
-    machine = machine.fit(data)
-    expected_means = np.array([[2.25, 2.25], [-1.25, 0.25], [1.25, 1.25]])
-    expected_variances = np.array(
-        [[1 / 16, 1 / 16], [1 / 16, 1 / 16], [1 / 16, 1 / 16]]
-    )
-    np.testing.assert_almost_equal(machine.means, expected_means, decimal=3)
-    np.testing.assert_almost_equal(machine.variances, expected_variances)
+    for transform in (to_numpy, to_dask_array):
+        data = transform(data)
+        machine = machine.fit(data)
+        expected_means = np.array([[2.25, 2.25], [-1.25, 0.25], [1.25, 1.25]])
+        expected_variances = np.array(
+            [[1 / 16, 1 / 16], [1 / 16, 1 / 16], [1 / 16, 1 / 16]]
+        )
+        np.testing.assert_almost_equal(machine.means, expected_means, decimal=3)
+        np.testing.assert_almost_equal(machine.variances, expected_variances)
 
 
 def test_gmm_kmeans_parallel_init():
@@ -509,16 +533,18 @@ def test_likelihood():
     machine = GMMMachine(n_gaussians)
     machine.means = np.repeat([[0], [1], [-1]], 3, 1)
     machine.variances = np.ones_like(machine.means)
-    log_likelihood = machine.log_likelihood(data)
-    expected_ll = np.array(
-        [
-            -3.6519900964986527,
-            -3.83151883210222,
-            -3.83151883210222,
-            -5.344374066745753,
-        ]
-    )
-    np.testing.assert_almost_equal(log_likelihood, expected_ll)
+    for transform in (to_numpy, to_dask_array):
+        data = transform(data)
+        log_likelihood = machine.log_likelihood(data)
+        expected_ll = np.array(
+            [
+                -3.6519900964986527,
+                -3.83151883210222,
+                -3.83151883210222,
+                -5.344374066745753,
+            ]
+        )
+        np.testing.assert_almost_equal(log_likelihood, expected_ll)
 
 
 def test_likelihood_variance():
@@ -533,16 +559,18 @@ def test_likelihood_variance():
             [1, 1, 1],
         ]
     )
-    log_likelihood = machine.log_likelihood(data)
-    expected_ll = np.array(
-        [
-            -2.202846959440514,
-            -3.8699524542323793,
-            -4.229029034375473,
-            -6.940892214952679,
-        ]
-    )
-    np.testing.assert_almost_equal(log_likelihood, expected_ll)
+    for transform in (to_numpy, to_dask_array):
+        data = transform(data)
+        log_likelihood = machine.log_likelihood(data)
+        expected_ll = np.array(
+            [
+                -2.202846959440514,
+                -3.8699524542323793,
+                -4.229029034375473,
+                -6.940892214952679,
+            ]
+        )
+        np.testing.assert_almost_equal(log_likelihood, expected_ll)
 
 
 def test_likelihood_weight():
@@ -552,16 +580,18 @@ def test_likelihood_weight():
     machine.means = np.repeat([[0], [1], [-1]], 3, 1)
     machine.variances = np.ones_like(machine.means)
     machine.weights = [0.6, 0.1, 0.3]
-    log_likelihood = machine.log_likelihood(data)
-    expected_ll = np.array(
-        [
-            -4.206596356117164,
-            -3.492325679996329,
-            -3.634745457950943,
-            -6.49485678536014,
-        ]
-    )
-    np.testing.assert_almost_equal(log_likelihood, expected_ll)
+    for transform in (to_numpy, to_dask_array):
+        data = transform(data)
+        log_likelihood = machine.log_likelihood(data)
+        expected_ll = np.array(
+            [
+                -4.206596356117164,
+                -3.492325679996329,
+                -3.634745457950943,
+                -6.49485678536014,
+            ]
+        )
+        np.testing.assert_almost_equal(log_likelihood, expected_ll)
 
 
 def test_GMMMachine_object():
@@ -688,27 +718,31 @@ def test_ml_transformer():
     machine.means = np.array([[2, 2, 2], [8, 8, 8]])
     machine.variances = np.ones_like(machine.means)
 
-    machine = machine.fit(data)
-
-    expected_means = np.array([[1.5, 1.5, 2.0], [7.0, 8.0, 8.0]])
-    np.testing.assert_almost_equal(machine.means, expected_means)
-    expected_weights = np.array([2 / 5, 3 / 5])
-    np.testing.assert_almost_equal(machine.weights, expected_weights)
-    eps = np.finfo(float).eps
-    expected_variances = np.array([[1 / 4, 1 / 4, eps], [eps, 2 / 3, 2 / 3]])
-    np.testing.assert_almost_equal(machine.variances, expected_variances)
+    for transform in (to_numpy, to_dask_array):
+        data = transform(data)
+        machine = machine.fit(data)
+
+        expected_means = np.array([[1.5, 1.5, 2.0], [7.0, 8.0, 8.0]])
+        np.testing.assert_almost_equal(machine.means, expected_means)
+        expected_weights = np.array([2 / 5, 3 / 5])
+        np.testing.assert_almost_equal(machine.weights, expected_weights)
+        eps = np.finfo(float).eps
+        expected_variances = np.array(
+            [[1 / 4, 1 / 4, eps], [eps, 2 / 3, 2 / 3]]
+        )
+        np.testing.assert_almost_equal(machine.variances, expected_variances)
 
-    stats = machine.transform(test_data)
+        stats = machine.transform(test_data)
 
-    expected_stats = GMMStats(n_gaussians, n_features)
-    expected_stats.init_fields(
-        log_likelihood=-6755399441055685.0,
-        t=test_data.shape[0],
-        n=np.array([2, 2], dtype=float),
-        sum_px=np.array([[2, 2, 3], [16, 17, 17]], dtype=float),
-        sum_pxx=np.array([[2, 2, 5], [128, 145, 145]], dtype=float),
-    )
-    assert stats.is_similar_to(expected_stats)
+        expected_stats = GMMStats(n_gaussians, n_features)
+        expected_stats.init_fields(
+            log_likelihood=-6755399441055685.0,
+            t=test_data.shape[0],
+            n=np.array([2, 2], dtype=float),
+            sum_px=np.array([[2, 2, 3], [16, 17, 17]], dtype=float),
+            sum_pxx=np.array([[2, 2, 5], [128, 145, 145]], dtype=float),
+        )
+        assert stats.is_similar_to(expected_stats)
 
 
 def test_map_transformer():
@@ -732,79 +766,69 @@ def test_map_transformer():
         update_weights=True,
     )
 
-    machine = machine.fit(post_data)
+    for transform in (to_numpy, to_dask_array):
+        post_data = transform(post_data)
+        machine = machine.fit(post_data)
 
-    expected_means = np.array(
-        [[1.83333333, 1.83333333, 2.0], [7.57142857, 8, 8]]
-    )
-    np.testing.assert_almost_equal(machine.means, expected_means)
-    eps = np.finfo(float).eps
-    expected_vars = np.array([[eps, eps, eps], [eps, eps, eps]])
-    np.testing.assert_almost_equal(machine.variances, expected_vars)
-    expected_weights = np.array([0.46226415, 0.53773585])
-    np.testing.assert_almost_equal(machine.weights, expected_weights)
-
-    stats = machine.transform(test_data)
-
-    expected_stats = GMMStats(n_gaussians, n_features)
-    expected_stats.init_fields(
-        log_likelihood=-1.3837590691807108e16,
-        t=test_data.shape[0],
-        n=np.array([2, 2], dtype=float),
-        sum_px=np.array([[2, 2, 3], [16, 17, 17]], dtype=float),
-        sum_pxx=np.array([[2, 2, 5], [128, 145, 145]], dtype=float),
-    )
-    assert stats.is_similar_to(expected_stats)
+        expected_means = np.array(
+            [[1.83333333, 1.83333333, 2.0], [7.57142857, 8, 8]]
+        )
+        np.testing.assert_almost_equal(machine.means, expected_means)
+        eps = np.finfo(float).eps
+        expected_vars = np.array([[eps, eps, eps], [eps, eps, eps]])
+        np.testing.assert_almost_equal(machine.variances, expected_vars)
+        expected_weights = np.array([0.46226415, 0.53773585])
+        np.testing.assert_almost_equal(machine.weights, expected_weights)
+
+        stats = machine.transform(test_data)
+
+        expected_stats = GMMStats(n_gaussians, n_features)
+        expected_stats.init_fields(
+            log_likelihood=-1.3837590691807108e16,
+            t=test_data.shape[0],
+            n=np.array([2, 2], dtype=float),
+            sum_px=np.array([[2, 2, 3], [16, 17, 17]], dtype=float),
+            sum_pxx=np.array([[2, 2, 5], [128, 145, 145]], dtype=float),
+        )
+        assert stats.is_similar_to(expected_stats)
 
 
 # Tests from `test_em.py`
 
 
-def loadGMM():
-    gmm = GMMMachine(n_gaussians=2)
-
-    gmm.weights = load_array(
-        resource_filename("bob.learn.em", "data/gmm.init_weights.hdf5")
-    )
-    gmm.means = load_array(
-        resource_filename("bob.learn.em", "data/gmm.init_means.hdf5")
-    )
-    gmm.variances = load_array(
-        resource_filename("bob.learn.em", "data/gmm.init_variances.hdf5")
-    )
-
-    return gmm
-
-
 def test_gmm_ML_1():
     """Trains a GMMMachine with ML_GMMTrainer"""
     ar = load_array(
         resource_filename("bob.learn.em", "data/faithful.torch3_f64.hdf5")
     )
-    gmm = loadGMM()
+    gmm_ref = GMMMachine.from_hdf5(
+        HDF5File(resource_filename("bob.learn.em", "data/gmm_ML.hdf5"), "r")
+    )
 
-    # test rng handling
-    gmm.convergence_threshold = 0.001
-    gmm.update_means = True
-    gmm.update_variances = True
-    gmm.update_weights = True
-    gmm.random_state = np.random.RandomState(seed=12345)
-    gmm = gmm.fit(ar)
+    for transform in (to_numpy, to_dask_array):
+        ar = transform(ar)
 
-    gmm = loadGMM()
-    gmm.convergence_threshold = 0.001
-    gmm.update_means = True
-    gmm.update_variances = True
-    gmm.update_weights = True
-    gmm = gmm.fit(ar)
+        gmm = loadGMM()
 
-    # Generate reference
-    # gmm.save(HDF5File(resource_filename("bob.learn.em", "data/gmm_ML.hdf5"), "w"))
+        # test rng handling
+        gmm.convergence_threshold = 0.001
+        gmm.update_means = True
+        gmm.update_variances = True
+        gmm.update_weights = True
+        gmm.random_state = np.random.RandomState(seed=12345)
+        gmm = gmm.fit(ar)
 
-    gmm_ref = GMMMachine.from_hdf5(
-        HDF5File(resource_filename("bob.learn.em", "data/gmm_ML.hdf5"), "r")
-    )
-    assert gmm == gmm_ref
+        gmm = loadGMM()
+        gmm.convergence_threshold = 0.001
+        gmm.update_means = True
+        gmm.update_variances = True
+        gmm.update_weights = True
+        gmm = gmm.fit(ar)
+
+        # Generate reference (uncomment only here, after the fit)
+        # gmm.save(HDF5File(resource_filename("bob.learn.em", "data/gmm_ML.hdf5"), "w"))
+
+        assert_gmm_equal(gmm, gmm_ref)
 
 
 def test_gmm_ML_2():
@@ -813,33 +837,6 @@ def test_gmm_ML_2():
         resource_filename("bob.learn.em", "data/dataNormalized.hdf5")
     )
 
-    # Initialize GMMMachine
-    gmm = GMMMachine(n_gaussians=5)
-    gmm.means = load_array(
-        resource_filename("bob.learn.em", "data/meansAfterKMeans.hdf5")
-    ).astype("float64")
-    gmm.variances = load_array(
-        resource_filename("bob.learn.em", "data/variancesAfterKMeans.hdf5")
-    ).astype("float64")
-    gmm.weights = np.exp(
-        load_array(
-            resource_filename("bob.learn.em", "data/weightsAfterKMeans.hdf5")
-        ).astype("float64")
-    )
-
-    threshold = 0.001
-    gmm.variance_thresholds = threshold
-
-    # Initialize ML Trainer
-    gmm.mean_var_update_threshold = 0.001
-    gmm.max_fitting_steps = 25
-    gmm.convergence_threshold = 0.000001
-    gmm.update_means = True
-    gmm.update_variances = True
-    gmm.update_weights = True
-
-    # Run ML
-    gmm = gmm.fit(ar)
     # Test results
     # Load torch3vision reference
     meansML_ref = load_array(
@@ -852,10 +849,42 @@ def test_gmm_ML_2():
         resource_filename("bob.learn.em", "data/weightsAfterML.hdf5")
     )
 
-    # Compare to current results
-    np.testing.assert_allclose(gmm.means, meansML_ref, atol=3e-3)
-    np.testing.assert_allclose(gmm.variances, variancesML_ref, atol=3e-3)
-    np.testing.assert_allclose(gmm.weights, weightsML_ref, atol=1e-4)
+    for transform in (to_numpy, to_dask_array):
+        ar = transform(ar)
+        # Initialize GMMMachine
+        gmm = GMMMachine(n_gaussians=5)
+        gmm.means = load_array(
+            resource_filename("bob.learn.em", "data/meansAfterKMeans.hdf5")
+        ).astype("float64")
+        gmm.variances = load_array(
+            resource_filename("bob.learn.em", "data/variancesAfterKMeans.hdf5")
+        ).astype("float64")
+        gmm.weights = np.exp(
+            load_array(
+                resource_filename(
+                    "bob.learn.em", "data/weightsAfterKMeans.hdf5"
+                )
+            ).astype("float64")
+        )
+
+        threshold = 0.001
+        gmm.variance_thresholds = threshold
+
+        # Initialize ML Trainer
+        gmm.mean_var_update_threshold = 0.001
+        gmm.max_fitting_steps = 25
+        gmm.convergence_threshold = 0.000001
+        gmm.update_means = True
+        gmm.update_variances = True
+        gmm.update_weights = True
+
+        # Run ML
+        gmm = gmm.fit(ar)
+
+        # Compare to current results
+        np.testing.assert_allclose(gmm.means, meansML_ref, atol=3e-3)
+        np.testing.assert_allclose(gmm.variances, variancesML_ref, atol=3e-3)
+        np.testing.assert_allclose(gmm.weights, weightsML_ref, atol=1e-4)
 
 
 def test_gmm_MAP_1():
@@ -890,8 +919,6 @@ def test_gmm_MAP_1():
     gmm.update_variances = False
     gmm.update_weights = False
 
-    gmm = gmm.fit(ar)
-
     # Generate reference
     # gmm.save(HDF5File(resource_filename("bob.learn.em", "data/gmm_MAP.hdf5"), "w"))
 
@@ -899,9 +926,15 @@ def test_gmm_MAP_1():
         HDF5File(resource_filename("bob.learn.em", "data/gmm_MAP.hdf5"), "r")
     )
 
-    np.testing.assert_almost_equal(gmm.means, gmm_ref.means, decimal=3)
-    np.testing.assert_almost_equal(gmm.variances, gmm_ref.variances, decimal=3)
-    np.testing.assert_almost_equal(gmm.weights, gmm_ref.weights, decimal=3)
+    for transform in (to_numpy, to_dask_array):
+        model = deepcopy(gmm).fit(transform(ar))  # fresh copy per backend
+        np.testing.assert_almost_equal(model.means, gmm_ref.means, decimal=3)
+        np.testing.assert_almost_equal(
+            model.variances, gmm_ref.variances, decimal=3
+        )
+        np.testing.assert_almost_equal(
+            model.weights, gmm_ref.weights, decimal=3
+        )
 
 
 def test_gmm_MAP_2():
@@ -934,14 +967,16 @@ def test_gmm_MAP_2():
     gmm_adapted.variances = variances
     gmm_adapted.weights = weights
 
-    gmm_adapted = gmm_adapted.fit(data)
-
     new_means = load_array(
         resource_filename("bob.learn.em", "data/new_adapted_mean.hdf5")
     )
 
-    # Compare to matlab reference
-    np.testing.assert_allclose(new_means.T, gmm_adapted.means, rtol=1e-4)
+    for transform in (to_numpy, to_dask_array):
+        # Adapt a fresh copy so each backend starts from the same initial model
+        model = deepcopy(gmm_adapted).fit(transform(data))
+
+        # Compare to matlab reference
+        np.testing.assert_allclose(new_means.T, model.means, rtol=1e-4)
 
 
 def test_gmm_MAP_3():
@@ -981,9 +1016,6 @@ def test_gmm_MAP_3():
     )
     gmm.variance_thresholds = threshold
 
-    # Train
-    gmm = gmm.fit(ar)
-
     # Test results
     # Load torch3vision reference
     meansMAP_ref = load_array(
@@ -996,13 +1028,18 @@ def test_gmm_MAP_3():
         resource_filename("bob.learn.em", "data/weightsAfterMAP.hdf5")
     )
 
-    # Compare to current results
-    # Gaps are quite large. This might be explained by the fact that there is no
-    # adaptation of a given Gaussian in torch3 when the corresponding responsibilities
-    # are below the responsibilities threshold
-    np.testing.assert_allclose(gmm.means, meansMAP_ref, atol=2e-1)
-    np.testing.assert_allclose(gmm.variances, variancesMAP_ref, atol=1e-4)
-    np.testing.assert_allclose(gmm.weights, weightsMAP_ref, atol=1e-4)
+    for transform in (to_numpy, to_dask_array):
+        # Train a fresh copy for each backend: refitting `gmm` itself would
+        # presumably resume from the parameters left by the previous run
+        model = deepcopy(gmm).fit(transform(ar))
+
+        # Compare to current results
+        # Gaps are quite large. This might be explained by the fact that there is no
+        # adaptation of a given Gaussian in torch3 when the corresponding responsibilities
+        # are below the responsibilities threshold
+        np.testing.assert_allclose(model.means, meansMAP_ref, atol=2e-1)
+        np.testing.assert_allclose(model.variances, variancesMAP_ref, atol=1e-4)
+        np.testing.assert_allclose(model.weights, weightsMAP_ref, atol=1e-4)
 
 
 def test_gmm_test():
@@ -1028,126 +1065,10 @@ def test_gmm_test():
 
     # Test against the model
     score_mean_ref = -1.50379e06
-    score = gmm.log_likelihood(ar).sum()
-    score /= len(ar)
-
-    # Compare current results to torch3vision
-    assert abs(score - score_mean_ref) / score_mean_ref < 1e-4
-
-
-def test_gmm_ML_dask():
-    # Trains a GMMMachine with dask array data; compares to a reference
-
-    ar = da.array(
-        load_array(
-            resource_filename("bob.learn.em", "data/dataNormalized.hdf5")
-        )
-    )
-
-    # Initialize GMMMachine
-    gmm = GMMMachine(n_gaussians=5)
-    gmm.means = load_array(
-        resource_filename("bob.learn.em", "data/meansAfterKMeans.hdf5")
-    ).astype("float64")
-    gmm.variances = load_array(
-        resource_filename("bob.learn.em", "data/variancesAfterKMeans.hdf5")
-    ).astype("float64")
-    gmm.weights = np.exp(
-        load_array(
-            resource_filename("bob.learn.em", "data/weightsAfterKMeans.hdf5")
-        ).astype("float64")
-    )
-
-    threshold = 0.001
-    gmm.variance_thresholds = threshold
-
-    # Initialize ML Trainer
-    gmm.mean_var_update_threshold = 0.001
-    gmm.max_fitting_steps = 25
-    gmm.convergence_threshold = 0.00001
-    gmm.update_means = True
-    gmm.update_variances = True
-    gmm.update_weights = True
-
-    # Run ML
-    gmm.fit(ar)
-
-    # Test results
-    # Load torch3vision reference
-    meansML_ref = load_array(
-        resource_filename("bob.learn.em", "data/meansAfterML.hdf5")
-    )
-    variancesML_ref = load_array(
-        resource_filename("bob.learn.em", "data/variancesAfterML.hdf5")
-    )
-    weightsML_ref = load_array(
-        resource_filename("bob.learn.em", "data/weightsAfterML.hdf5")
-    )
-
-    # Compare to current results
-    np.testing.assert_allclose(gmm.means, meansML_ref, atol=3e-3)
-    np.testing.assert_allclose(gmm.variances, variancesML_ref, atol=3e-3)
-    np.testing.assert_allclose(gmm.weights, weightsML_ref, atol=1e-4)
-
-
-def test_gmm_MAP_dask():
-    # Test a GMMMachine for MAP with a dask array as data.
-    ar = da.array(
-        load_array(resource_filename("bob.learn.em", "data/dataforMAP.hdf5"))
-    )
-
-    # Initialize GMMMachine
-    n_gaussians = 5
-    prior_gmm = GMMMachine(n_gaussians)
-    prior_gmm.means = load_array(
-        resource_filename("bob.learn.em", "data/meansAfterML.hdf5")
-    )
-    prior_gmm.variances = load_array(
-        resource_filename("bob.learn.em", "data/variancesAfterML.hdf5")
-    )
-    prior_gmm.weights = load_array(
-        resource_filename("bob.learn.em", "data/weightsAfterML.hdf5")
-    )
-
-    threshold = 0.001
-    prior_gmm.variance_thresholds = threshold
-
-    # Initialize MAP Trainer
-    prior = 0.001
-    accuracy = 0.00001
-    gmm = GMMMachine(
-        n_gaussians,
-        trainer="map",
-        ubm=prior_gmm,
-        convergence_threshold=prior,
-        max_fitting_steps=1,
-        update_means=True,
-        update_variances=False,
-        update_weights=False,
-        mean_var_update_threshold=accuracy,
-        map_relevance_factor=None,
-    )
-    gmm.variance_thresholds = threshold
-
-    # Train
-    gmm = gmm.fit(ar)
-
-    # Test results
-    # Load torch3vision reference
-    meansMAP_ref = load_array(
-        resource_filename("bob.learn.em", "data/meansAfterMAP.hdf5")
-    )
-    variancesMAP_ref = load_array(
-        resource_filename("bob.learn.em", "data/variancesAfterMAP.hdf5")
-    )
-    weightsMAP_ref = load_array(
-        resource_filename("bob.learn.em", "data/weightsAfterMAP.hdf5")
-    )
+    for transform in (to_numpy, to_dask_array):
+        ar = transform(ar)
+        score = gmm.log_likelihood(ar).sum()
+        score /= len(ar)
 
-    # Compare to current results
-    # Gaps are quite large. This might be explained by the fact that there is no
-    # adaptation of a given Gaussian in torch3 when the corresponding responsibilities
-    # are below the responsibilities threshold
-    np.testing.assert_allclose(gmm.means, meansMAP_ref, atol=2e-1)
-    np.testing.assert_allclose(gmm.variances, variancesMAP_ref, atol=1e-4)
-    np.testing.assert_allclose(gmm.weights, weightsMAP_ref, atol=1e-4)
+        # Relative error vs. torch3vision; the reference is negative, hence abs()
+        assert abs(score - score_mean_ref) / abs(score_mean_ref) < 1e-4
diff --git a/bob/learn/em/test/test_kmeans.py b/bob/learn/em/test/test_kmeans.py
index 8edc205..613a562 100644
--- a/bob/learn/em/test/test_kmeans.py
+++ b/bob/learn/em/test/test_kmeans.py
@@ -33,7 +33,7 @@ def to_dask_array(*args):
     for x in args:
         x = np.asarray(x)
         chunks = list(x.shape)
-        chunks[0] //= 2
+        chunks[0] = int(np.ceil(chunks[0] / 2))
         result.append(da.from_array(x, chunks=chunks))
     if len(result) == 1:
         return result[0]
-- 
GitLab