Skip to content
Snippets Groups Projects
Commit 8bf560b0 authored by Yannick DAYER
Browse files

Added a k-means test for machine parameters

parent fb7b686b
No related branches found
No related tags found
2 merge requests: !42 GMM implementation in Python, !40 Transition to a pure python implementation
Pipeline #56830 passed
......@@ -156,8 +156,10 @@ class KMeansMachine(BaseEstimator):
weights_count = np.bincount(closest_centroid_indices, minlength=n_cluster)
weights = weights_count / weights_count.sum()
# Accumulate
# FIX for `too many indices for array` error if using `np.eye(n_cluster)` alone:
dask_compatible_eye = np.eye(n_cluster) * np.array(1, like=data)
# Accumulate
means_sum = np.sum(
dask_compatible_eye[closest_centroid_indices][:, :, None] * data[:, None],
axis=0,
......
......@@ -6,6 +6,8 @@
# Copyright (C) 2011-2014 Idiap Research Institute, Martigny, Switzerland
"""Tests the KMeans machine
Tries each test with a numpy array and the equivalent dask array.
"""
import dask.array as da
......@@ -46,14 +48,14 @@ def test_KMeansMachine():
km.centroids_ = means
# Distance and closest mean
np.testing.assert_almost_equal(km.transform(mean)[0], 1)
np.testing.assert_almost_equal(km.transform(mean)[1], 6)
np.testing.assert_almost_equal(km.transform(mean)[0], 1, decimal=10)
np.testing.assert_almost_equal(km.transform(mean)[1], 6, decimal=10)
(index, dist) = km.get_closest_centroid(mean)
assert index == 0, index
np.testing.assert_almost_equal(dist, 1.0)
np.testing.assert_almost_equal(km.get_min_distance(mean), 1)
np.testing.assert_almost_equal(dist, 1.0, decimal=10)
np.testing.assert_almost_equal(km.get_min_distance(mean), 1, decimal=10)
def test_KMeansMachine_var_and_weight():
......@@ -68,11 +70,8 @@ def test_KMeansMachine_var_and_weight():
variances_result = np.array([[0.01, 1.0], [0.01555556, 0.00888889]])
weights_result = np.array([0.4, 0.6])
np.testing.assert_almost_equal(variances, variances_result)
np.testing.assert_almost_equal(weights, weights_result)
np.set_printoptions(precision=9)
np.testing.assert_almost_equal(variances, variances_result, decimal=7)
np.testing.assert_equal(weights, weights_result)
def test_kmeans_fit():
......@@ -89,7 +88,7 @@ def test_kmeans_fit():
[-1.07173464, -1.06200356, -1.00724920],
[0.99479125, 0.99665564, 0.97689017],
]
np.testing.assert_almost_equal(centroids, expected)
np.testing.assert_almost_equal(centroids, expected, decimal=7)
def test_kmeans_fit_init_pp():
......@@ -123,3 +122,25 @@ def test_kmeans_fit_init_random():
[0.99529015, 0.99570570, 0.97580858],
]
np.testing.assert_almost_equal(centroids, expected, decimal=7)
def test_kmeans_parameters():
    """Fit a two-cluster KMeansMachine with explicit parameters and check its centroids.

    The input is 2000 three-dimensional normal draws centred on +1 followed by
    2000 centred on -1, generated with a fixed numpy seed. The fit and the
    comparison are repeated for each converter in ``(to_numpy, to_dask_array)``.
    Converters are applied cumulatively: the second one receives the output of
    the first.
    """
    np.random.seed(0)
    # Draw order matters for reproducibility: the +1 cluster is sampled first.
    around_plus_one = np.random.normal(loc=1, size=(2000, 3))
    around_minus_one = np.random.normal(loc=-1, size=(2000, 3))
    samples = np.concatenate([around_plus_one, around_minus_one], axis=0)

    # Reference centroids, ordered by increasing first coordinate.
    expected = [
        [-1.07173464, -1.06200356, -1.00724920],
        [0.99479125, 0.99665564, 0.97689017],
    ]

    for convert in (to_numpy, to_dask_array):
        samples = convert(samples)
        machine = KMeansMachine(
            n_clusters=2,
            init_method="k-means||",
            convergence_threshold=1e-5,
            max_iter=5,
            random_state=0,
            init_max_iter=5,
        )
        machine = machine.fit(samples)

        # Sort the centroids on their first coordinate so the comparison does
        # not depend on the order in which the clusters were found.
        order = np.argsort(machine.centroids_[:, 0])
        centroids = machine.centroids_[order]
        np.testing.assert_almost_equal(centroids, expected, decimal=7)
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment