diff --git a/tests/test_gmm.py b/tests/test_gmm.py index a7e67c257f50dfaa7421b03657eee7e08b929c34..689415e20a36a16e956e3c462304646a9a15b3be 100644 --- a/tests/test_gmm.py +++ b/tests/test_gmm.py @@ -242,21 +242,21 @@ def test_GMMMachine(): gmm.save(HDF5File(filename, "w")) # Using from_hdf5 gmm1 = GMMMachine.from_hdf5(HDF5File(filename, "r")) - assert type(gmm1.n_gaussians) is np.int64 - assert type(gmm1.update_means) is np.bool_ - assert type(gmm1.update_variances) is np.bool_ - assert type(gmm1.update_weights) is np.bool_ - assert type(gmm1.trainer) is str + assert isinstance(gmm1.n_gaussians, np.int64) + assert isinstance(gmm1.update_means, np.bool_) + assert isinstance(gmm1.update_variances, np.bool_) + assert isinstance(gmm1.update_weights, np.bool_) + assert isinstance(gmm1.trainer, str) assert gmm1.ubm is None assert_gmm_equal(gmm, gmm1) # Using load gmm1 = GMMMachine(n_gaussians=gmm.n_gaussians) gmm1.load(HDF5File(filename, "r")) - assert type(gmm1.n_gaussians) is np.int64 - assert type(gmm1.update_means) is np.bool_ - assert type(gmm1.update_variances) is np.bool_ - assert type(gmm1.update_weights) is np.bool_ - assert type(gmm1.trainer) is str + assert isinstance(gmm1.n_gaussians, np.int64) + assert isinstance(gmm1.update_means, np.bool_) + assert isinstance(gmm1.update_variances, np.bool_) + assert isinstance(gmm1.update_weights, np.bool_) + assert isinstance(gmm1.trainer, str) assert gmm1.ubm is None assert_gmm_equal(gmm, gmm1) @@ -525,18 +525,18 @@ def test_gmm_kmeans_parallel_init(): for transform in (to_numpy, to_dask_array): data = transform(data) machine = machine.fit(data) + + sorted_indices = np.argsort(machine.means[:, 0]) + means = machine.means[sorted_indices] + variances = machine.variances[sorted_indices] expected_means = np.array( - [[1.25, 1.25], [-1.25, 0.25], [2.25, 2.25]] + [[-1.25, 0.25], [1.25, 1.25], [2.25, 2.25]] ) expected_variances = np.array( [[1 / 16, 1 / 16], [1 / 16, 1 / 16], [1 / 16, 1 / 16]] ) - np.testing.assert_almost_equal( - machine.means, expected_means, decimal=3 - ) - np.testing.assert_almost_equal( - machine.variances, expected_variances - ) + np.testing.assert_almost_equal(means, expected_means, decimal=3) + np.testing.assert_almost_equal(variances, expected_variances) def test_likelihood():