From 5ca0c664508aadf7104a89621f78f0fd7c15de6b Mon Sep 17 00:00:00 2001 From: Yannick DAYER <yannick.dayer@idiap.ch> Date: Wed, 29 May 2024 15:11:17 +0200 Subject: [PATCH] tests: sort the returned gmm means for comparison. --- tests/test_gmm.py | 34 +++++++++++++++++----------------- 1 file changed, 17 insertions(+), 17 deletions(-) diff --git a/tests/test_gmm.py b/tests/test_gmm.py index a7e67c2..689415e 100644 --- a/tests/test_gmm.py +++ b/tests/test_gmm.py @@ -242,21 +242,21 @@ def test_GMMMachine(): gmm.save(HDF5File(filename, "w")) # Using from_hdf5 gmm1 = GMMMachine.from_hdf5(HDF5File(filename, "r")) - assert type(gmm1.n_gaussians) is np.int64 - assert type(gmm1.update_means) is np.bool_ - assert type(gmm1.update_variances) is np.bool_ - assert type(gmm1.update_weights) is np.bool_ - assert type(gmm1.trainer) is str + assert isinstance(gmm1.n_gaussians, np.int64) + assert isinstance(gmm1.update_means, np.bool_) + assert isinstance(gmm1.update_variances, np.bool_) + assert isinstance(gmm1.update_weights, np.bool_) + assert isinstance(gmm1.trainer, str) assert gmm1.ubm is None assert_gmm_equal(gmm, gmm1) # Using load gmm1 = GMMMachine(n_gaussians=gmm.n_gaussians) gmm1.load(HDF5File(filename, "r")) - assert type(gmm1.n_gaussians) is np.int64 - assert type(gmm1.update_means) is np.bool_ - assert type(gmm1.update_variances) is np.bool_ - assert type(gmm1.update_weights) is np.bool_ - assert type(gmm1.trainer) is str + assert isinstance(gmm1.n_gaussians, np.int64) + assert isinstance(gmm1.update_means, np.bool_) + assert isinstance(gmm1.update_variances, np.bool_) + assert isinstance(gmm1.update_weights, np.bool_) + assert isinstance(gmm1.trainer, str) assert gmm1.ubm is None assert_gmm_equal(gmm, gmm1) @@ -525,18 +525,18 @@ def test_gmm_kmeans_parallel_init(): for transform in (to_numpy, to_dask_array): data = transform(data) machine = machine.fit(data) + + sorted_indices = np.argsort(machine.means[:, 0]) + means = machine.means[sorted_indices] + variances = machine.variances[sorted_indices] expected_means = np.array( - [[1.25, 1.25], [-1.25, 0.25], [2.25, 2.25]] + [[-1.25, 0.25], [1.25, 1.25], [2.25, 2.25]] ) expected_variances = np.array( [[1 / 16, 1 / 16], [1 / 16, 1 / 16], [1 / 16, 1 / 16]] ) - np.testing.assert_almost_equal( - machine.means, expected_means, decimal=3 - ) - np.testing.assert_almost_equal( - machine.variances, expected_variances - ) + np.testing.assert_almost_equal(means, expected_means, decimal=3) + np.testing.assert_almost_equal(variances, expected_variances) def test_likelihood(): -- GitLab