From 6bafc7c7395e321aecd0c520dd538bd7b77f2001 Mon Sep 17 00:00:00 2001
From: Yannick DAYER <yannick.dayer@idiap.ch>
Date: Mon, 13 Dec 2021 10:44:18 +0100
Subject: [PATCH] Added kmeans tests

---
 bob/learn/em/cluster/k_means.py  | 17 ++++----
 bob/learn/em/test/test_kmeans.py | 68 +++++++++++++++++++++++++++-----
 2 files changed, 68 insertions(+), 17 deletions(-)

diff --git a/bob/learn/em/cluster/k_means.py b/bob/learn/em/cluster/k_means.py
index 7df4074..abdf6a6 100644
--- a/bob/learn/em/cluster/k_means.py
+++ b/bob/learn/em/cluster/k_means.py
@@ -78,6 +78,7 @@ class KMeansMachine(BaseEstimator):
         self.average_min_distance = np.inf
         self.zeroeth_order_statistics = None
         self.first_order_statistics = None
+        self.centroids_ = None
 
     def get_centroids_distance(self, x: np.ndarray) -> np.ndarray:
         """Returns the distance values between x and each cluster's centroid.
@@ -117,18 +118,18 @@ class KMeansMachine(BaseEstimator):
         return np.min(self.get_centroids_distance(x), axis=0)
 
     def __eq__(self, obj) -> bool:
-        if hasattr(self, "centroids_") and hasattr(obj, "centroids_"):
-            return np.allclose(self.centroids_, obj.centroids_, rtol=0, atol=0)
-        else:
-            raise ValueError("centroids_ was not set. You should call 'fit' first.")
+        return self.is_similar_to(obj, r_epsilon=0, a_epsilon=0)
 
     def is_similar_to(self, obj, r_epsilon=1e-05, a_epsilon=1e-08) -> bool:
-        if hasattr(self, "centroids_") and hasattr(obj, "centroids_"):
+        if self.centroids_ is not None and obj.centroids_ is not None:
             return np.allclose(
                 self.centroids_, obj.centroids_, rtol=r_epsilon, atol=a_epsilon
             )
         else:
-            raise ValueError("centroids_ was not set. You should call 'fit' first.")
+            logger.warning(
+                "KMeansMachine `centroids_` was not set. You should call 'fit' first."
+            )
+            return False
 
     def get_variances_and_weights_for_each_cluster(self, data: np.ndarray):
         """Returns the clusters variance and weight for data clustered by the machine.
@@ -255,8 +256,8 @@ class KMeansMachine(BaseEstimator):
 
     def partial_fit(self, X, y=None):
         """Applies one e-m step of k-means on the data."""
-        if not hasattr(self, "centroids_"):
-            logger.debug(f"First call to 'partial_fit'. Initializing...")
+        if self.centroids_ is None:
+            logger.debug("First call to 'partial_fit'. Initializing...")
             self.initialize(data=X)
 
         self.e_step(data=X)
diff --git a/bob/learn/em/test/test_kmeans.py b/bob/learn/em/test/test_kmeans.py
index c39bd4b..cbf5f09 100644
--- a/bob/learn/em/test/test_kmeans.py
+++ b/bob/learn/em/test/test_kmeans.py
@@ -10,6 +10,8 @@
 Tries each test with a numpy array and the equivalent dask array.
 """
 
+import copy
+
 import dask.array as da
 import numpy as np
 
@@ -38,24 +40,47 @@ def test_KMeansMachine():
     # Test a KMeansMachine
 
     means = np.array([[3, 70, 0], [4, 72, 0]], "float64")
-    mean = np.array([3, 70, 1], "float64")
+    test_val = np.array([3, 70, 1], "float64")
+    test_arr = np.array([[3, 70, 1], [5, 72, 0]], "float64")
 
     for transform in (to_numpy, to_dask_array):
-        means, mean = transform(means, mean)
+        means, test_val, test_arr = transform(means, test_val, test_arr)
 
         # Initializes a KMeansMachine
         km = KMeansMachine(2)
         km.centroids_ = means
 
         # Distance and closest mean
-        np.testing.assert_almost_equal(km.transform(mean)[0], 1, decimal=10)
-        np.testing.assert_almost_equal(km.transform(mean)[1], 6, decimal=10)
+        np.testing.assert_equal(km.transform(test_val)[0], np.array([1]))
+        np.testing.assert_equal(km.transform(test_val)[1], np.array([6]))
+
+        (index, dist) = km.get_closest_centroid(test_val)
+        assert index == 0
+        np.testing.assert_equal(dist, np.array([[1.0]]))
+
+        (indices, dists) = km.get_closest_centroid(test_arr)
+        np.testing.assert_equal(indices, np.array([0, 1]))
+        np.testing.assert_equal(dists, np.array([[1, 8], [6, 1]]))
 
-        (index, dist) = km.get_closest_centroid(mean)
+        index = km.predict(test_val)
+        assert index == 0
 
-        assert index == 0, index
-        np.testing.assert_almost_equal(dist, 1.0, decimal=10)
-        np.testing.assert_almost_equal(km.get_min_distance(mean), 1, decimal=10)
+        indices = km.predict(test_arr)
+        np.testing.assert_equal(indices, np.array([0, 1]))
+
+        np.testing.assert_equal(km.get_min_distance(test_val), np.array([1]))
+        np.testing.assert_equal(km.get_min_distance(test_arr), np.array([1, 1]))
+
+        # Check __eq__ and is_similar_to
+        km2 = KMeansMachine(2)
+        assert km != km2
+        assert not km.is_similar_to(km2)
+        km2 = copy.deepcopy(km)
+        assert km == km2
+        assert km.is_similar_to(km2)
+        km2.centroids_[0, 0] += 1
+        assert km != km2
+        assert not km.is_similar_to(km2)
 
 
 def test_KMeansMachine_var_and_weight():
@@ -78,6 +103,8 @@ def test_kmeans_fit():
     np.random.seed(0)
     data1 = np.random.normal(loc=1, size=(2000, 3))
     data2 = np.random.normal(loc=-1, size=(2000, 3))
+    print(data1.min(), data1.max())
+    print(data2.min(), data2.max())
     data = np.concatenate([data1, data2], axis=0)
 
     for transform in (to_numpy, to_dask_array):
@@ -90,6 +117,29 @@ def test_kmeans_fit():
         ]
         np.testing.assert_almost_equal(centroids, expected, decimal=7)
 
+        # Early stop: fit is limited to max_iter iterations
+        machine = KMeansMachine(2, max_iter=2)
+        machine.fit(data)
+
+
+def test_kmeans_fit_partial():
+    np.random.seed(0)
+    data1 = np.random.normal(loc=1, size=(2000, 3))
+    data2 = np.random.normal(loc=-1, size=(2000, 3))
+    data = np.concatenate([data1, data2], axis=0)
+
+    for transform in (to_numpy, to_dask_array):
+        data = transform(data)
+        machine = KMeansMachine(2, random_state=0)
+        for _ in range(20):
+            machine.partial_fit(data)
+        centroids = machine.centroids_[np.argsort(machine.centroids_[:, 0])]
+        expected = [
+            [-1.07173464, -1.06200356, -1.00724920],
+            [0.99479125, 0.99665564, 0.97689017],
+        ]
+        np.testing.assert_almost_equal(centroids, expected, decimal=7)
+
 
 def test_kmeans_fit_init_pp():
     np.random.seed(0)
@@ -143,4 +193,4 @@ def test_kmeans_parameters():
             [-1.07173464, -1.06200356, -1.00724920],
             [0.99479125, 0.99665564, 0.97689017],
         ]
-        np.testing.assert_almost_equal(centroids, expected, decimal=7)
\ No newline at end of file
+        np.testing.assert_almost_equal(centroids, expected, decimal=7)
-- 
GitLab