Skip to content
Snippets Groups Projects
Verified Commit 1921b8ad authored by Yannick DAYER's avatar Yannick DAYER
Browse files

tests(kmeans): adapt results to last changes.

parent dbe31f76
No related branches found
No related tags found
No related merge requests found
...@@ -94,27 +94,26 @@ def test_KMeansMachine_var_and_weight(): ...@@ -94,27 +94,26 @@ def test_KMeansMachine_var_and_weight():
np.testing.assert_equal(weights, weights_result) np.testing.assert_equal(weights, weights_result)
def test_kmeans_fit(): def test_kmeans_fit_parallel():
np.random.seed(0) rs_gen = np.random.RandomState(0)
data1 = np.random.normal(loc=1, size=(2000, 3)) data1 = rs_gen.normal(loc=1, size=(2000, 3))
data2 = np.random.normal(loc=-1, size=(2000, 3)) data2 = rs_gen.normal(loc=-1, size=(2000, 3))
print(data1.min(), data1.max()) data = np.vstack([data1, data2])
print(data2.min(), data2.max())
data = np.concatenate([data1, data2], axis=0)
for transform in (to_numpy, to_dask_array): for transform in (to_numpy, to_dask_array):
data = transform(data) data = transform(data)
machine = KMeansMachine(2, random_state=0).fit(data) machine = KMeansMachine(2, random_state=0).fit(data)
centroids = machine.centroids_[np.argsort(machine.centroids_[:, 0])] centroids = machine.centroids_[np.argsort(machine.centroids_[:, 0])]
expected = [ expected = [
[-1.07173464, -1.06200356, -1.00724920], [-1.07173464, -1.06200356, -1.0072492],
[0.99479125, 0.99665564, 0.97689017], [0.99479125, 0.99665564, 0.97689017],
] ]
np.testing.assert_almost_equal(centroids, expected, decimal=7) np.testing.assert_almost_equal(centroids, expected, decimal=7)
# Early stop # Early stop
machine = KMeansMachine(2, max_iter=2) machine = KMeansMachine(2, max_iter=2, random_state=0)
machine.fit(data) machine.fit(data)
np.testing.assert_almost_equal(centroids, expected, decimal=7)
def test_kmeans_fit_init_pp(): def test_kmeans_fit_init_pp():
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment