From 112a753a777ff2b8506531446ecf0dfadbe83af9 Mon Sep 17 00:00:00 2001
From: Tiago Freitas Pereira <tiagofrepereira@gmail.com>
Date: Mon, 12 Jan 2015 15:00:06 +0100
Subject: [PATCH] More method binded for the KMeansMachine class

---
 bob/learn/misc/kmeans_machine.cpp | 155 +++++++++++++++++++++++++++++-
 bob/learn/misc/test_kmeans.py     |  47 ++++++---
 2 files changed, 186 insertions(+), 16 deletions(-)

diff --git a/bob/learn/misc/kmeans_machine.cpp b/bob/learn/misc/kmeans_machine.cpp
index 273705b..39e467d 100644
--- a/bob/learn/misc/kmeans_machine.cpp
+++ b/bob/learn/misc/kmeans_machine.cpp
@@ -366,7 +366,7 @@ static auto get_mean = bob::extension::FunctionDoc(
   ".. note:: An exception is thrown if i is out of range.", 
   true
 )
-.add_prototype("i","mean index")
+.add_prototype("i")
 .add_parameter("i", "int", "Index of the mean")
 .add_return("mean","array_like <float, 1D>","Mean array");
 static PyObject* PyBobLearnMiscKMeansMachine_get_mean(PyBobLearnMiscKMeansMachineObject* self, PyObject* args, PyObject* kwargs) {
@@ -379,10 +379,42 @@ static PyObject* PyBobLearnMiscKMeansMachine_get_mean(PyBobLearnMiscKMeansMachin
  
   return PyBlitzArrayCxx_AsConstNumpy(self->cxx->getMean(i));
 
-  BOB_CATCH_MEMBER("cannot compute the likelihood", 0)
+  BOB_CATCH_MEMBER("cannot get the mean", 0)
+}
+
+
+/*** set_mean ***/
+static auto set_mean = bob::extension::FunctionDoc(
+  "set_mean",
+  "Set the i'th mean.",
+  ".. note:: An exception is thrown if i is out of range.", 
+  true
+)
+.add_prototype("i,mean")
+.add_parameter("i", "int", "Index of the mean")
+.add_parameter("mean", "array_like <float, 1D>", "Mean array");
+static PyObject* PyBobLearnMiscKMeansMachine_set_mean(PyBobLearnMiscKMeansMachineObject* self, PyObject* args, PyObject* kwargs) {
+  BOB_TRY
+  
+  char** kwlist = set_mean.kwlist(0);
+
+  int i = 0;
+  PyBlitzArrayObject* mean = 0;
+  if (!PyArg_ParseTupleAndKeywords(args, kwargs, "iO&", kwlist, &i, &PyBlitzArray_Converter, &mean)) Py_RETURN_NONE;
+  
+  //protects acquired resources through this scope
+  auto mean_ = make_safe(mean);
+
+  //setting the mean
+  self->cxx->setMean(i, *PyBlitzArrayCxx_AsBlitz<double,1>(mean));
+
+  BOB_CATCH_MEMBER("cannot set the mean", 0)
+  
+  Py_RETURN_NONE;
 }
 
 
+
 /*** get_distance_from_mean ***/
 static auto get_distance_from_mean = bob::extension::FunctionDoc(
   "get_distance_from_mean",
@@ -408,7 +440,6 @@ static PyObject* PyBobLearnMiscKMeansMachine_get_distance_from_mean(PyBobLearnMi
   //protects acquired resources through this scope
   auto input_ = make_safe(input);
 
-  //return PyBlitzArrayCxx_AsConstNumpy(self->cxx->getMean(i));
   double output = self->cxx->getDistanceFromMean(*PyBlitzArrayCxx_AsBlitz<double,1>(input),i);
   return Py_BuildValue("d", output);
 
@@ -416,6 +447,99 @@ static PyObject* PyBobLearnMiscKMeansMachine_get_distance_from_mean(PyBobLearnMi
 }
 
 
+/*** get_closest_mean ***/
+static auto get_closest_mean = bob::extension::FunctionDoc(
+  "get_closest_mean",
+  "Calculate the index of the mean that is closest (in terms of square Euclidean distance) to the data sample, x.",
+  "",
+  true
+)
+.add_prototype("input","output")
+.add_parameter("input", "array_like <float, 1D>", "The data sample (feature vector)")
+.add_return("output", "(int, int)", "Tuple containing the closest mean and the minimum distance from the input");
+static PyObject* PyBobLearnMiscKMeansMachine_get_closest_mean(PyBobLearnMiscKMeansMachineObject* self, PyObject* args, PyObject* kwargs) {
+  BOB_TRY
+  
+  char** kwlist = get_closest_mean.kwlist(0);
+
+  PyBlitzArrayObject* input = 0;
+  if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O&", kwlist, &PyBlitzArray_Converter, &input)) Py_RETURN_NONE;
+
+  //protects acquired resources through this scope
+  auto input_ = make_safe(input);
+
+  size_t closest_mean = 0;
+  double min_distance = -1;   
+  self->cxx->getClosestMean(*PyBlitzArrayCxx_AsBlitz<double,1>(input), closest_mean, min_distance);
+    
+  return Py_BuildValue("(i,d)", closest_mean, min_distance);
+
+  BOB_CATCH_MEMBER("cannot compute the closest mean", 0)
+}
+
+
+/*** get_min_distance ***/
+static auto get_min_distance = bob::extension::FunctionDoc(
+  "get_min_distance",
+  "Output the minimum (Square Euclidean) distance between the input and the closest mean ",
+  "",
+  true
+)
+.add_prototype("input","output")
+.add_parameter("input", "array_like <float, 1D>", "The data sample (feature vector)")
+.add_return("output", "double", "The minimum distance");
+static PyObject* PyBobLearnMiscKMeansMachine_get_min_distance(PyBobLearnMiscKMeansMachineObject* self, PyObject* args, PyObject* kwargs) {
+  BOB_TRY
+  
+  char** kwlist = get_min_distance.kwlist(0);
+
+  PyBlitzArrayObject* input = 0;
+  if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O&", kwlist, &PyBlitzArray_Converter, &input)) Py_RETURN_NONE;
+
+  //protects acquired resources through this scope
+  auto input_ = make_safe(input);
+
+  double min_distance = 0;   
+  min_distance = self->cxx->getMinDistance(*PyBlitzArrayCxx_AsBlitz<double,1>(input));
+
+  return Py_BuildValue("d", min_distance);
+
+  BOB_CATCH_MEMBER("cannot compute the min distance", 0)
+}
+
+/**** get_variances_and_weights_for_each_cluster ***/
+static auto get_variances_and_weights_for_each_cluster = bob::extension::FunctionDoc(
+  "get_variances_and_weights_for_each_cluster",
+  "For each mean, find the subset of the samples that is closest to that mean, and calculate"
+  " 1) the variance of that subset (the cluster variance)" 
+  " 2) the proportion of the samples represented by that subset (the cluster weight)",
+  "",
+  true
+)
+.add_prototype("input","output")
+.add_parameter("input", "array_like <float, 2D>", "The data sample (feature vector)")
+.add_return("output", "(array_like <float, 2D>, array_like <float, 1D>)", "A tuple with the variances and the weights respectively");
+static PyObject* PyBobLearnMiscKMeansMachine_get_variances_and_weights_for_each_cluster(PyBobLearnMiscKMeansMachineObject* self, PyObject* args, PyObject* kwargs) {
+  BOB_TRY
+  
+  char** kwlist =  get_variances_and_weights_for_each_cluster.kwlist(0);
+
+  PyBlitzArrayObject* input = 0;
+  if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O&", kwlist, &PyBlitzArray_Converter, &input)) Py_RETURN_NONE;
+
+  //protects acquired resources through this scope
+  auto input_ = make_safe(input);
+
+  blitz::Array<double,2> variances(self->cxx->getNMeans(),self->cxx->getNInputs());
+  blitz::Array<double,1> weights(self->cxx->getNMeans());
+  
+  self->cxx->getVariancesAndWeightsForEachCluster(*PyBlitzArrayCxx_AsBlitz<double,2>(input),variances,weights);
+
+  return Py_BuildValue("(O,O)",PyBlitzArrayCxx_AsConstNumpy(variances), PyBlitzArrayCxx_AsConstNumpy(weights));
+
+  BOB_CATCH_MEMBER("cannot compute the variances and weights for each cluster", 0)
+}
+
 
 
 static PyMethodDef PyBobLearnMiscKMeansMachine_methods[] = {
@@ -449,12 +573,37 @@ static PyMethodDef PyBobLearnMiscKMeansMachine_methods[] = {
     METH_VARARGS|METH_KEYWORDS,
     get_mean.doc()
   },  
+  {
+    set_mean.name(),
+    (PyCFunction)PyBobLearnMiscKMeansMachine_set_mean,
+    METH_VARARGS|METH_KEYWORDS,
+    set_mean.doc()
+  },  
   {
     get_distance_from_mean.name(),
     (PyCFunction)PyBobLearnMiscKMeansMachine_get_distance_from_mean,
     METH_VARARGS|METH_KEYWORDS,
     get_distance_from_mean.doc()
   },  
+  {
+    get_closest_mean.name(),
+    (PyCFunction)PyBobLearnMiscKMeansMachine_get_closest_mean,
+    METH_VARARGS|METH_KEYWORDS,
+    get_closest_mean.doc()
+  },  
+  {
+    get_min_distance.name(),
+    (PyCFunction)PyBobLearnMiscKMeansMachine_get_min_distance,
+    METH_VARARGS|METH_KEYWORDS,
+    get_min_distance.doc()
+  },  
+
+  {
+    get_variances_and_weights_for_each_cluster.name(),
+    (PyCFunction)PyBobLearnMiscKMeansMachine_get_variances_and_weights_for_each_cluster,
+    METH_VARARGS|METH_KEYWORDS,
+    get_variances_and_weights_for_each_cluster.doc()
+  },  
 
   {0} /* Sentinel */
 };
diff --git a/bob/learn/misc/test_kmeans.py b/bob/learn/misc/test_kmeans.py
index d55ed8b..69c3e6a 100644
--- a/bob/learn/misc/test_kmeans.py
+++ b/bob/learn/misc/test_kmeans.py
@@ -31,24 +31,19 @@ def test_KMeansMachine():
 
   # Sets and gets
   assert (km.means == means).all()
-  assert (km.get_mean(0) == means[0,:]).all()
+  assert (km.get_mean(0) == means[0,:]).all()  
   assert (km.get_mean(1) == means[1,:]).all()
-  #km.set_mean(0, mean)
-  #assert (km.get_mean(0) == mean).all()
+  km.set_mean(0, mean)
+  assert (km.get_mean(0) == mean).all()
 
   # Distance and closest mean
   eps = 1e-10
 
-  print mean
-  print km.means
-
-  print km.get_distance_from_mean(mean, 0)
-
-
-
   assert equals( km.get_distance_from_mean(mean, 0), 0, eps)
-  assert equals( km.get_distance_from_mean(mean, 1), 6, eps)
+  assert equals( km.get_distance_from_mean(mean, 1), 6, eps)  
+  
   (index, dist) = km.get_closest_mean(mean)
+  
   assert index == 0
   assert equals( dist, 0, eps)
   assert equals( km.get_min_distance(mean), 0, eps)
@@ -61,8 +56,7 @@ def test_KMeansMachine():
 
   # Resize
   km.resize(4,5)
-  assert km.dim_c == 4
-  assert km.dim_d == 5
+  assert km.shape == (4,5)
 
   # Copy constructor and comparison operators
   km.resize(2,3)
@@ -78,3 +72,30 @@ def test_KMeansMachine():
 
   # Clean-up
   os.unlink(filename)
+  
+  
+def test_KMeansMachine2():
+  print "Computing" 
+  kmeans             = bob.learn.misc.KMeansMachine(2,2)
+  kmeans.means       = numpy.array([[1.2,1.3],[0.2,-0.3]])
+
+  data               = numpy.array([
+                                  [1.,1],
+                                  [1.2, 3],
+                                  [0,0],
+                                  [0.3,0.2],
+                                  [0.2,0]
+                                 ])
+  print "Computing" 
+  variances, weights = kmeans.get_variances_and_weights_for_each_cluster(data)
+  print "Computed" 
+
+  variances_result = numpy.array([[ 0.01,1.],
+                                  [ 0.01555556 ,0.00888889]])
+  weights_result = numpy.array([ 0.4, 0.6])
+
+  assert True
+  
+  #assert weights_result   == weights
+  #assert variances_result == variances
+ 
-- 
GitLab