diff --git a/tests/test_kmeans.py b/tests/test_kmeans.py index 613a562bf864b4d66b31ae5a410fc6ad6298e164..15ccfe10fd8525805bf8c0ca2295420d848dfa08 100644 --- a/tests/test_kmeans.py +++ b/tests/test_kmeans.py @@ -94,27 +94,26 @@ def test_KMeansMachine_var_and_weight(): np.testing.assert_equal(weights, weights_result) -def test_kmeans_fit(): - np.random.seed(0) - data1 = np.random.normal(loc=1, size=(2000, 3)) - data2 = np.random.normal(loc=-1, size=(2000, 3)) - print(data1.min(), data1.max()) - print(data2.min(), data2.max()) - data = np.concatenate([data1, data2], axis=0) +def test_kmeans_fit_parallel(): + rs_gen = np.random.RandomState(0) + data1 = rs_gen.normal(loc=1, size=(2000, 3)) + data2 = rs_gen.normal(loc=-1, size=(2000, 3)) + data = np.vstack([data1, data2]) for transform in (to_numpy, to_dask_array): data = transform(data) machine = KMeansMachine(2, random_state=0).fit(data) centroids = machine.centroids_[np.argsort(machine.centroids_[:, 0])] expected = [ - [-1.07173464, -1.06200356, -1.00724920], + [-1.07173464, -1.06200356, -1.0072492], [0.99479125, 0.99665564, 0.97689017], ] np.testing.assert_almost_equal(centroids, expected, decimal=7) # Early stop - machine = KMeansMachine(2, max_iter=2) + machine = KMeansMachine(2, max_iter=2, random_state=0) machine.fit(data) + np.testing.assert_almost_equal(centroids, expected, decimal=7) def test_kmeans_fit_init_pp():