From 8bf560b0bc8f17b76707c91f6cb7aa41e2e8fa4d Mon Sep 17 00:00:00 2001
From: Yannick DAYER <yannick.dayer@idiap.ch>
Date: Tue, 7 Dec 2021 18:18:05 +0100
Subject: [PATCH] Added a k-means test for machine parameters

---
 bob/learn/em/cluster/k_means.py  |  4 +++-
 bob/learn/em/test/test_kmeans.py | 41 ++++++++++++++++++++++++--------
 2 files changed, 34 insertions(+), 11 deletions(-)

diff --git a/bob/learn/em/cluster/k_means.py b/bob/learn/em/cluster/k_means.py
index 72da1f1..7df4074 100644
--- a/bob/learn/em/cluster/k_means.py
+++ b/bob/learn/em/cluster/k_means.py
@@ -156,8 +156,10 @@ class KMeansMachine(BaseEstimator):
         weights_count = np.bincount(closest_centroid_indices, minlength=n_cluster)
         weights = weights_count / weights_count.sum()
 
-        # Accumulate
+        # FIX for `too many indices for array` error if using `np.eye(n_cluster)` alone:
         dask_compatible_eye = np.eye(n_cluster) * np.array(1, like=data)
+
+        # Accumulate
         means_sum = np.sum(
             dask_compatible_eye[closest_centroid_indices][:, :, None] * data[:, None],
             axis=0,
diff --git a/bob/learn/em/test/test_kmeans.py b/bob/learn/em/test/test_kmeans.py
index 426b3e5..c39bd4b 100644
--- a/bob/learn/em/test/test_kmeans.py
+++ b/bob/learn/em/test/test_kmeans.py
@@ -6,6 +6,8 @@
 # Copyright (C) 2011-2014 Idiap Research Institute, Martigny, Switzerland
 
 """Tests the KMeans machine
+
+Tries each test with a numpy array and the equivalent dask array.
 """
 
 import dask.array as da
@@ -46,14 +48,14 @@ def test_KMeansMachine():
         km.centroids_ = means
 
         # Distance and closest mean
-        np.testing.assert_almost_equal(km.transform(mean)[0], 1)
-        np.testing.assert_almost_equal(km.transform(mean)[1], 6)
+        np.testing.assert_almost_equal(km.transform(mean)[0], 1, decimal=10)
+        np.testing.assert_almost_equal(km.transform(mean)[1], 6, decimal=10)
 
         (index, dist) = km.get_closest_centroid(mean)
 
         assert index == 0, index
-        np.testing.assert_almost_equal(dist, 1.0)
-        np.testing.assert_almost_equal(km.get_min_distance(mean), 1)
+        np.testing.assert_almost_equal(dist, 1.0, decimal=10)
+        np.testing.assert_almost_equal(km.get_min_distance(mean), 1, decimal=10)
 
 
 def test_KMeansMachine_var_and_weight():
@@ -68,11 +70,8 @@ def test_KMeansMachine_var_and_weight():
         variances_result = np.array([[0.01, 1.0], [0.01555556, 0.00888889]])
         weights_result = np.array([0.4, 0.6])
 
-        np.testing.assert_almost_equal(variances, variances_result)
-        np.testing.assert_almost_equal(weights, weights_result)
-
-
-np.set_printoptions(precision=9)
+        np.testing.assert_almost_equal(variances, variances_result, decimal=7)
+        np.testing.assert_equal(weights, weights_result)
 
 
 def test_kmeans_fit():
@@ -89,7 +88,7 @@ def test_kmeans_fit():
             [-1.07173464, -1.06200356, -1.00724920],
             [0.99479125, 0.99665564, 0.97689017],
         ]
-        np.testing.assert_almost_equal(centroids, expected)
+        np.testing.assert_almost_equal(centroids, expected, decimal=7)
 
 
 def test_kmeans_fit_init_pp():
@@ -123,3 +122,25 @@ def test_kmeans_fit_init_random():
             [0.99529015, 0.99570570, 0.97580858],
         ]
         np.testing.assert_almost_equal(centroids, expected, decimal=7)
+
+def test_kmeans_parameters():
+    np.random.seed(0)
+    data1 = np.random.normal(loc=1, size=(2000, 3))
+    data2 = np.random.normal(loc=-1, size=(2000, 3))
+    data = np.concatenate([data1, data2], axis=0)
+    for transform in (to_numpy, to_dask_array):
+        data = transform(data)
+        machine = KMeansMachine(
+            n_clusters=2,
+            init_method="k-means||",
+            convergence_threshold=1e-5,
+            max_iter=5,
+            random_state=0,
+            init_max_iter=5,
+        ).fit(data)
+        centroids = machine.centroids_[np.argsort(machine.centroids_[:, 0])]
+        expected = [
+            [-1.07173464, -1.06200356, -1.00724920],
+            [0.99479125, 0.99665564, 0.97689017],
+        ]
+        np.testing.assert_almost_equal(centroids, expected, decimal=7)
\ No newline at end of file
-- 
GitLab