From 5ca0c664508aadf7104a89621f78f0fd7c15de6b Mon Sep 17 00:00:00 2001
From: Yannick DAYER <yannick.dayer@idiap.ch>
Date: Wed, 29 May 2024 15:11:17 +0200
Subject: [PATCH] tests: sort the returned gmm means for comparison.

---
 tests/test_gmm.py | 34 +++++++++++++++++-----------------
 1 file changed, 17 insertions(+), 17 deletions(-)

diff --git a/tests/test_gmm.py b/tests/test_gmm.py
index a7e67c2..689415e 100644
--- a/tests/test_gmm.py
+++ b/tests/test_gmm.py
@@ -242,21 +242,21 @@ def test_GMMMachine():
         gmm.save(HDF5File(filename, "w"))
         # Using from_hdf5
         gmm1 = GMMMachine.from_hdf5(HDF5File(filename, "r"))
-        assert type(gmm1.n_gaussians) is np.int64
-        assert type(gmm1.update_means) is np.bool_
-        assert type(gmm1.update_variances) is np.bool_
-        assert type(gmm1.update_weights) is np.bool_
-        assert type(gmm1.trainer) is str
+        assert isinstance(gmm1.n_gaussians, np.int64)
+        assert isinstance(gmm1.update_means, np.bool_)
+        assert isinstance(gmm1.update_variances, np.bool_)
+        assert isinstance(gmm1.update_weights, np.bool_)
+        assert isinstance(gmm1.trainer, str)
         assert gmm1.ubm is None
         assert_gmm_equal(gmm, gmm1)
         # Using load
         gmm1 = GMMMachine(n_gaussians=gmm.n_gaussians)
         gmm1.load(HDF5File(filename, "r"))
-        assert type(gmm1.n_gaussians) is np.int64
-        assert type(gmm1.update_means) is np.bool_
-        assert type(gmm1.update_variances) is np.bool_
-        assert type(gmm1.update_weights) is np.bool_
-        assert type(gmm1.trainer) is str
+        assert isinstance(gmm1.n_gaussians, np.int64)
+        assert isinstance(gmm1.update_means, np.bool_)
+        assert isinstance(gmm1.update_variances, np.bool_)
+        assert isinstance(gmm1.update_weights, np.bool_)
+        assert isinstance(gmm1.trainer, str)
         assert gmm1.ubm is None
         assert_gmm_equal(gmm, gmm1)
 
@@ -525,18 +525,18 @@ def test_gmm_kmeans_parallel_init():
         for transform in (to_numpy, to_dask_array):
             data = transform(data)
             machine = machine.fit(data)
+
+            sorted_indices = np.argsort(machine.means[:, 0])
+            means = machine.means[sorted_indices]
+            variances = machine.variances[sorted_indices]
             expected_means = np.array(
-                [[1.25, 1.25], [-1.25, 0.25], [2.25, 2.25]]
+                [[-1.25, 0.25], [1.25, 1.25], [2.25, 2.25]]
             )
             expected_variances = np.array(
                 [[1 / 16, 1 / 16], [1 / 16, 1 / 16], [1 / 16, 1 / 16]]
             )
-            np.testing.assert_almost_equal(
-                machine.means, expected_means, decimal=3
-            )
-            np.testing.assert_almost_equal(
-                machine.variances, expected_variances
-            )
+            np.testing.assert_almost_equal(means, expected_means, decimal=3)
+            np.testing.assert_almost_equal(variances, expected_variances)
 
 
 def test_likelihood():
-- 
GitLab