From 53243cdb74b4a73d0a1607ff133a7d397ce66235 Mon Sep 17 00:00:00 2001
From: Tiago Freitas Pereira <tiagofrepereira@gmail.com>
Date: Mon, 12 Jan 2015 10:46:45 +0100
Subject: [PATCH] Binding and redesigning the KMeansMachine class

---
 bob/learn/misc/cpp/KMeansMachine.cpp          |   7 +-
 .../include/bob.learn.misc/KMeansMachine.h    |   5 +-
 bob/learn/misc/kmeans_machine.cpp             | 498 ++++++++++++++++++
 bob/learn/misc/main.cpp                       |   3 +-
 bob/learn/misc/main.h                         |  12 +
 bob/learn/misc/test_kmeans.py                 |  15 +-
 setup.py                                      |   3 +-
 7 files changed, 531 insertions(+), 12 deletions(-)
 create mode 100644 bob/learn/misc/kmeans_machine.cpp

diff --git a/bob/learn/misc/cpp/KMeansMachine.cpp b/bob/learn/misc/cpp/KMeansMachine.cpp
index d30cee4..7b4a5a2 100644
--- a/bob/learn/misc/cpp/KMeansMachine.cpp
+++ b/bob/learn/misc/cpp/KMeansMachine.cpp
@@ -110,15 +110,16 @@ void bob::learn::misc::KMeansMachine::setMean(const size_t i, const blitz::Array
   m_means(i,blitz::Range::all()) = mean;
 }
 
-void bob::learn::misc::KMeansMachine::getMean(const size_t i, blitz::Array<double,1> &mean) const
+const blitz::Array<double,1> bob::learn::misc::KMeansMachine::getMean(const size_t i) const
 {
   if(i>=m_n_means) {
     boost::format m("cannot get mean with index %lu: out of bounds [0,%lu[");
     m % i % m_n_means;
     throw std::runtime_error(m.str());
   }
-  bob::core::array::assertSameDimensionLength(mean.extent(0), m_means.extent(1));
-  mean = m_means(i,blitz::Range::all());
+
+  return m_means(i,blitz::Range::all());
+
 }
 
 double bob::learn::misc::KMeansMachine::getDistanceFromMean(const blitz::Array<double,1> &x,
diff --git a/bob/learn/misc/include/bob.learn.misc/KMeansMachine.h b/bob/learn/misc/include/bob.learn.misc/KMeansMachine.h
index 3f12213..5f8f5bf 100644
--- a/bob/learn/misc/include/bob.learn.misc/KMeansMachine.h
+++ b/bob/learn/misc/include/bob.learn.misc/KMeansMachine.h
@@ -12,7 +12,6 @@
 #include <cfloat>
 
 #include <bob.io.base/HDF5File.h>
-#include <bob.learn.misc/Machine.h>
 
 namespace bob { namespace learn { namespace misc {
 
@@ -20,7 +19,7 @@ namespace bob { namespace learn { namespace misc {
  * @brief This class implements a k-means classifier.
  * @details See Section 9.1 of Bishop, "Pattern recognition and machine learning", 2006
  */
-class KMeansMachine: public Machine<blitz::Array<double,1>, double> {
+class KMeansMachine {
   public:
     /**
      * Default constructor. Builds an otherwise invalid 0 x 0 k-means
@@ -118,7 +117,7 @@ class KMeansMachine: public Machine<blitz::Array<double,1>, double> {
      * @param[in]   i    The index of the mean
      * @param[out] mean The mean, a 1D array, with a length equal to the number of feature dimensions.
      */
-    void getMean(const size_t i, blitz::Array<double,1>& mean) const;
+    const blitz::Array<double,1> getMean(const size_t i) const;
 
     /**
      * Get the means (i.e. a 2D array, with as many rows as means, and as
diff --git a/bob/learn/misc/kmeans_machine.cpp b/bob/learn/misc/kmeans_machine.cpp
new file mode 100644
index 0000000..273705b
--- /dev/null
+++ b/bob/learn/misc/kmeans_machine.cpp
@@ -0,0 +1,498 @@
+/**
+ * @author Tiago de Freitas Pereira <tiago.pereira@idiap.ch>
+ * @date Fri 26 Dec 16:18:00 2014
+ *
+ * @brief Python API for bob::learn::em
+ *
+ * Copyright (C) 2011-2014 Idiap Research Institute, Martigny, Switzerland
+ */
+
+#include "main.h"
+
+/******************************************************************/
+/************ Constructor Section *********************************/
+/******************************************************************/
+
+static auto KMeansMachine_doc = bob::extension::ClassDoc(
+  BOB_EXT_MODULE_PREFIX ".KMeansMachine",
+  "This class implements a k-means classifier.\n"
+  "See Section 9.1 of Bishop, \"Pattern recognition and machine learning\", 2006"
+).add_constructor(
+  bob::extension::FunctionDoc(
+    "__init__",
+    "Creates a KMeansMachine",
+    "",
+    true
+  )
+  .add_prototype("n_gaussians,n_inputs","")
+  .add_prototype("other","")
+  .add_prototype("hdf5","")
+  .add_prototype("","")
+
+  .add_parameter("n_means", "int", "Number of means")
+  .add_parameter("n_inputs", "int", "Dimension of the feature vector")
+  .add_parameter("other", ":py:class:`bob.learn.misc.KMeansMachine`", "A KMeansMachine object to be copied.")
+  .add_parameter("hdf5", ":py:class:`bob.io.base.HDF5File`", "An HDF5 file open for reading")
+
+);
+
+
+static int PyBobLearnMiscKMeansMachine_init_number(PyBobLearnMiscKMeansMachineObject* self, PyObject* args, PyObject* kwargs) {
+
+  char** kwlist = KMeansMachine_doc.kwlist(0);
+  int n_inputs    = 1;
+  int n_means = 1;
+  //Parsing the input argments
+  if (!PyArg_ParseTupleAndKeywords(args, kwargs, "ii", kwlist, &n_means, &n_inputs))
+    return -1;
+
+  if(n_means < 0){
+    PyErr_Format(PyExc_TypeError, "means argument must be greater than or equal to zero");
+    KMeansMachine_doc.print_usage();
+    return -1;
+  }
+
+  if(n_inputs < 0){
+    PyErr_Format(PyExc_TypeError, "input argument must be greater than or equal to zero");
+    KMeansMachine_doc.print_usage();
+    return -1;
+   }
+
+  self->cxx.reset(new bob::learn::misc::KMeansMachine(n_means, n_inputs));
+  return 0;
+}
+
+
+static int PyBobLearnMiscKMeansMachine_init_copy(PyBobLearnMiscKMeansMachineObject* self, PyObject* args, PyObject* kwargs) {
+
+  char** kwlist = KMeansMachine_doc.kwlist(1);
+  PyBobLearnMiscKMeansMachineObject* tt;
+  if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O!", kwlist, &PyBobLearnMiscKMeansMachine_Type, &tt)){
+    KMeansMachine_doc.print_usage();
+    return -1;
+  }
+
+  self->cxx.reset(new bob::learn::misc::KMeansMachine(*tt->cxx));
+  return 0;
+}
+
+
+static int PyBobLearnMiscKMeansMachine_init_hdf5(PyBobLearnMiscKMeansMachineObject* self, PyObject* args, PyObject* kwargs) {
+
+  char** kwlist = KMeansMachine_doc.kwlist(2);
+
+  PyBobIoHDF5FileObject* config = 0;
+  if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O&", kwlist, &PyBobIoHDF5File_Converter, &config)){
+    KMeansMachine_doc.print_usage();
+    return -1;
+  }
+
+  self->cxx.reset(new bob::learn::misc::KMeansMachine(*(config->f)));
+
+  return 0;
+}
+
+
+static int PyBobLearnMiscKMeansMachine_init(PyBobLearnMiscKMeansMachineObject* self, PyObject* args, PyObject* kwargs) {
+  BOB_TRY
+
+  // get the number of command line arguments
+  int nargs = (args?PyTuple_Size(args):0) + (kwargs?PyDict_Size(kwargs):0);
+  
+  switch (nargs) {
+
+    case 0: //default initializer ()
+      self->cxx.reset(new bob::learn::misc::KMeansMachine());
+      return 0;
+
+    case 1:{
+      //Reading the input argument
+      PyObject* arg = 0;
+      if (PyTuple_Size(args))
+        arg = PyTuple_GET_ITEM(args, 0);
+      else {
+        PyObject* tmp = PyDict_Values(kwargs);
+        auto tmp_ = make_safe(tmp);
+        arg = PyList_GET_ITEM(tmp, 0);
+      }
+
+      // If the constructor input is Gaussian object
+     if (PyBobLearnMiscKMeansMachine_Check(arg))
+       return PyBobLearnMiscKMeansMachine_init_copy(self, args, kwargs);
+      // If the constructor input is a HDF5
+     else if (PyBobIoHDF5File_Check(arg))
+       return PyBobLearnMiscKMeansMachine_init_hdf5(self, args, kwargs);
+    }
+    case 2:
+      return PyBobLearnMiscKMeansMachine_init_number(self, args, kwargs);
+    default:
+      PyErr_Format(PyExc_RuntimeError, "number of arguments mismatch - %s requires 0, 1 or 2 arguments, but you provided %d (see help)", Py_TYPE(self)->tp_name, nargs);
+      KMeansMachine_doc.print_usage();
+      return -1;
+  }
+  BOB_CATCH_MEMBER("cannot create KMeansMachine", 0)
+  return 0;
+}
+
+
+
+static void PyBobLearnMiscKMeansMachine_delete(PyBobLearnMiscKMeansMachineObject* self) {
+  self->cxx.reset();
+  Py_TYPE(self)->tp_free((PyObject*)self);
+}
+
+static PyObject* PyBobLearnMiscKMeansMachine_RichCompare(PyBobLearnMiscKMeansMachineObject* self, PyObject* other, int op) {
+  BOB_TRY
+
+  if (!PyBobLearnMiscKMeansMachine_Check(other)) {
+    PyErr_Format(PyExc_TypeError, "cannot compare `%s' with `%s'", Py_TYPE(self)->tp_name, Py_TYPE(other)->tp_name);
+    return 0;
+  }
+  auto other_ = reinterpret_cast<PyBobLearnMiscKMeansMachineObject*>(other);
+  switch (op) {
+    case Py_EQ:
+      if (*self->cxx==*other_->cxx) Py_RETURN_TRUE; else Py_RETURN_FALSE;
+    case Py_NE:
+      if (*self->cxx==*other_->cxx) Py_RETURN_FALSE; else Py_RETURN_TRUE;
+    default:
+      Py_INCREF(Py_NotImplemented);
+      return Py_NotImplemented;
+  }
+  BOB_CATCH_MEMBER("cannot compare KMeansMachine objects", 0)
+}
+
+int PyBobLearnMiscKMeansMachine_Check(PyObject* o) {
+  return PyObject_IsInstance(o, reinterpret_cast<PyObject*>(&PyBobLearnMiscKMeansMachine_Type));
+}
+
+
+/******************************************************************/
+/************ Variables Section ***********************************/
+/******************************************************************/
+
+/***** shape *****/
+static auto shape = bob::extension::VariableDoc(
+  "shape",
+  "(int,int)",
+  "A tuple that represents the number of means and dimensionality of the feature vector``(n_means, dim)``.",
+  ""
+);
+PyObject* PyBobLearnMiscKMeansMachine_getShape(PyBobLearnMiscKMeansMachineObject* self, void*) {
+  BOB_TRY
+  return Py_BuildValue("(i,i)", self->cxx->getNMeans(), self->cxx->getNInputs());
+  BOB_CATCH_MEMBER("shape could not be read", 0)
+}
+
+/***** MEAN *****/
+
+static auto means = bob::extension::VariableDoc(
+  "means",
+  "array_like <float, 2D>",
+  "The means",
+  ""
+);
+PyObject* PyBobLearnMiscKMeansMachine_getMeans(PyBobLearnMiscKMeansMachineObject* self, void*){
+  BOB_TRY
+  return PyBlitzArrayCxx_AsConstNumpy(self->cxx->getMeans());
+  BOB_CATCH_MEMBER("means could not be read", 0)
+}
+int PyBobLearnMiscKMeansMachine_setMeans(PyBobLearnMiscKMeansMachineObject* self, PyObject* value, void*){
+  BOB_TRY
+  PyBlitzArrayObject* o;
+  if (!PyBlitzArray_Converter(value, &o)){
+    PyErr_Format(PyExc_RuntimeError, "%s %s expects a 2D array of floats", Py_TYPE(self)->tp_name, means.name());
+    return -1;
+  }
+  auto o_ = make_safe(o);
+  auto b = PyBlitzArrayCxx_AsBlitz<double,2>(o, "means");
+  if (!b) return -1;
+  self->cxx->setMeans(*b);
+  return 0;
+  BOB_CATCH_MEMBER("means could not be set", -1)
+}
+
+
+static PyGetSetDef PyBobLearnMiscKMeansMachine_getseters[] = { 
+  {
+   shape.name(),
+   (getter)PyBobLearnMiscKMeansMachine_getShape,
+   0,
+   shape.doc(),
+   0
+  },
+  {
+   means.name(),
+   (getter)PyBobLearnMiscKMeansMachine_getMeans,
+   (setter)PyBobLearnMiscKMeansMachine_setMeans,
+   means.doc(),
+   0
+  },
+  {0}  // Sentinel
+};
+
+
+/******************************************************************/
+/************ Functions Section ***********************************/
+/******************************************************************/
+
+
+/*** save ***/
+static auto save = bob::extension::FunctionDoc(
+  "save",
+  "Save the configuration of the KMeansMachine to a given HDF5 file"
+)
+.add_prototype("hdf5")
+.add_parameter("hdf5", ":py:class:`bob.io.base.HDF5File`", "An HDF5 file open for writing");
+static PyObject* PyBobLearnMiscKMeansMachine_Save(PyBobLearnMiscKMeansMachineObject* self,  PyObject* args, PyObject* kwargs) {
+
+  BOB_TRY
+  
+  // get list of arguments
+  char** kwlist = save.kwlist(0);  
+  PyBobIoHDF5FileObject* hdf5;
+  if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O&", kwlist, PyBobIoHDF5File_Converter, &hdf5)) return 0;
+
+  auto hdf5_ = make_safe(hdf5);
+  self->cxx->save(*hdf5->f);
+
+  BOB_CATCH_MEMBER("cannot save the data", 0)
+  Py_RETURN_NONE;
+}
+
+/*** load ***/
+static auto load = bob::extension::FunctionDoc(
+  "load",
+  "Load the configuration of the KMeansMachine to a given HDF5 file"
+)
+.add_prototype("hdf5")
+.add_parameter("hdf5", ":py:class:`bob.io.base.HDF5File`", "An HDF5 file open for reading");
+static PyObject* PyBobLearnMiscKMeansMachine_Load(PyBobLearnMiscKMeansMachineObject* self, PyObject* args, PyObject* kwargs) {
+  BOB_TRY
+  
+  char** kwlist = load.kwlist(0);  
+  PyBobIoHDF5FileObject* hdf5;
+  if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O&", kwlist, PyBobIoHDF5File_Converter, &hdf5)) return 0;
+  
+  auto hdf5_ = make_safe(hdf5);  
+  self->cxx->load(*hdf5->f);
+
+  BOB_CATCH_MEMBER("cannot load the data", 0)
+  Py_RETURN_NONE;
+}
+
+
+/*** is_similar_to ***/
+static auto is_similar_to = bob::extension::FunctionDoc(
+  "is_similar_to",
+  
+  "Compares this KMeansMachine with the ``other`` one to be approximately the same.",
+  "The optional values ``r_epsilon`` and ``a_epsilon`` refer to the "
+  "relative and absolute precision for the ``weights``, ``biases`` "
+  "and any other values internal to this machine."
+)
+.add_prototype("other, [r_epsilon], [a_epsilon]","output")
+.add_parameter("other", ":py:class:`bob.learn.misc.KMeansMachine`", "A KMeansMachine object to be compared.")
+.add_parameter("r_epsilon", "float", "Relative precision.")
+.add_parameter("a_epsilon", "float", "Absolute precision.")
+.add_return("output","bool","True if it is similar, otherwise false.");
+static PyObject* PyBobLearnMiscKMeansMachine_IsSimilarTo(PyBobLearnMiscKMeansMachineObject* self, PyObject* args, PyObject* kwds) {
+
+  /* Parses input arguments in a single shot */
+  char** kwlist = is_similar_to.kwlist(0);
+
+  //PyObject* other = 0;
+  PyBobLearnMiscKMeansMachineObject* other = 0;
+  double r_epsilon = 1.e-5;
+  double a_epsilon = 1.e-8;
+
+  if (!PyArg_ParseTupleAndKeywords(args, kwds, "O!|dd", kwlist,
+        &PyBobLearnMiscKMeansMachine_Type, &other,
+        &r_epsilon, &a_epsilon)){
+
+        is_similar_to.print_usage(); 
+        return 0;        
+  }
+
+  if (self->cxx->is_similar_to(*other->cxx, r_epsilon, a_epsilon))
+    Py_RETURN_TRUE;
+  else
+    Py_RETURN_FALSE;
+}
+
+
+/*** resize ***/
+static auto resize = bob::extension::FunctionDoc(
+  "resize",
+  "Allocates space for the statistics and resets to zero.",
+  0,
+  true
+)
+.add_prototype("n_means,n_inputs")
+.add_parameter("n_means", "int", "Number of means")
+.add_parameter("n_inputs", "int", "Dimensionality of the feature vector");
+static PyObject* PyBobLearnMiscKMeansMachine_resize(PyBobLearnMiscKMeansMachineObject* self, PyObject* args, PyObject* kwargs) {
+  BOB_TRY
+
+  /* Parses input arguments in a single shot */
+  char** kwlist = resize.kwlist(0);
+
+  int n_means = 0;
+  int n_inputs = 0;
+
+  if (!PyArg_ParseTupleAndKeywords(args, kwargs, "ii", kwlist, &n_means, &n_inputs)) Py_RETURN_NONE;
+
+  if (n_means <= 0){
+    PyErr_Format(PyExc_TypeError, "n_means must be greater than zero");
+    resize.print_usage();
+    return 0;
+  }
+  if (n_inputs <= 0){
+    PyErr_Format(PyExc_TypeError, "n_inputs must be greater than zero");
+    resize.print_usage();
+    return 0;
+  }
+
+  self->cxx->resize(n_means, n_inputs);
+
+  BOB_CATCH_MEMBER("cannot perform the resize method", 0)
+
+  Py_RETURN_NONE;
+}
+
+/*** get_mean ***/
+static auto get_mean = bob::extension::FunctionDoc(
+  "get_mean",
+  "Get the i'th mean.",
+  ".. note:: An exception is thrown if i is out of range.", 
+  true
+)
+.add_prototype("i","mean index")
+.add_parameter("i", "int", "Index of the mean")
+.add_return("mean","array_like <float, 1D>","Mean array");
+static PyObject* PyBobLearnMiscKMeansMachine_get_mean(PyBobLearnMiscKMeansMachineObject* self, PyObject* args, PyObject* kwargs) {
+  BOB_TRY
+  
+  char** kwlist = get_mean.kwlist(0);
+
+  int i = 0;
+  if (!PyArg_ParseTupleAndKeywords(args, kwargs, "i", kwlist, &i)) Py_RETURN_NONE;
+ 
+  return PyBlitzArrayCxx_AsConstNumpy(self->cxx->getMean(i));
+
+  BOB_CATCH_MEMBER("cannot compute the likelihood", 0)
+}
+
+
+/*** get_distance_from_mean ***/
+static auto get_distance_from_mean = bob::extension::FunctionDoc(
+  "get_distance_from_mean",
+  "Return the power of two of the square Euclidean distance of the sample, x, to the i'th mean.",
+  ".. note:: An exception is thrown if i is out of range.", 
+  true
+)
+.add_prototype("input,i","output")
+.add_parameter("input", "array_like <float, 1D>", "The data sample (feature vector)")
+.add_parameter("i", "int", "The index of the mean")
+.add_return("output","float","Square Euclidean distance of the sample, x, to the i'th mean");
+static PyObject* PyBobLearnMiscKMeansMachine_get_distance_from_mean(PyBobLearnMiscKMeansMachineObject* self, PyObject* args, PyObject* kwargs) {
+  BOB_TRY
+  
+  char** kwlist = get_distance_from_mean.kwlist(0);
+
+  PyBlitzArrayObject* input = 0;
+  int i = 0;
+  if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O&i", kwlist, &PyBlitzArray_Converter, &input, &i)){ 
+    Py_RETURN_NONE;
+  }
+
+  //protects acquired resources through this scope
+  auto input_ = make_safe(input);
+
+  //return PyBlitzArrayCxx_AsConstNumpy(self->cxx->getMean(i));
+  double output = self->cxx->getDistanceFromMean(*PyBlitzArrayCxx_AsBlitz<double,1>(input),i);
+  return Py_BuildValue("d", output);
+
+  BOB_CATCH_MEMBER("cannot compute the likelihood", 0)
+}
+
+
+
+
+static PyMethodDef PyBobLearnMiscKMeansMachine_methods[] = {
+  {
+    save.name(),
+    (PyCFunction)PyBobLearnMiscKMeansMachine_Save,
+    METH_VARARGS|METH_KEYWORDS,
+    save.doc()
+  },
+  {
+    load.name(),
+    (PyCFunction)PyBobLearnMiscKMeansMachine_Load,
+    METH_VARARGS|METH_KEYWORDS,
+    load.doc()
+  },
+  {
+    is_similar_to.name(),
+    (PyCFunction)PyBobLearnMiscKMeansMachine_IsSimilarTo,
+    METH_VARARGS|METH_KEYWORDS,
+    is_similar_to.doc()
+  },
+  {
+    resize.name(),
+    (PyCFunction)PyBobLearnMiscKMeansMachine_resize,
+    METH_VARARGS|METH_KEYWORDS,
+    resize.doc()
+  },  
+  {
+    get_mean.name(),
+    (PyCFunction)PyBobLearnMiscKMeansMachine_get_mean,
+    METH_VARARGS|METH_KEYWORDS,
+    get_mean.doc()
+  },  
+  {
+    get_distance_from_mean.name(),
+    (PyCFunction)PyBobLearnMiscKMeansMachine_get_distance_from_mean,
+    METH_VARARGS|METH_KEYWORDS,
+    get_distance_from_mean.doc()
+  },  
+
+  {0} /* Sentinel */
+};
+
+
+/******************************************************************/
+/************ Module Section **************************************/
+/******************************************************************/
+
+// Define the Gaussian type struct; will be initialized later
+PyTypeObject PyBobLearnMiscKMeansMachine_Type = {
+  PyVarObject_HEAD_INIT(0,0)
+  0
+};
+
+bool init_BobLearnMiscKMeansMachine(PyObject* module)
+{
+  // initialize the type struct
+  PyBobLearnMiscKMeansMachine_Type.tp_name = KMeansMachine_doc.name();
+  PyBobLearnMiscKMeansMachine_Type.tp_basicsize = sizeof(PyBobLearnMiscKMeansMachineObject);
+  PyBobLearnMiscKMeansMachine_Type.tp_flags = Py_TPFLAGS_DEFAULT;
+  PyBobLearnMiscKMeansMachine_Type.tp_doc = KMeansMachine_doc.doc();
+
+  // set the functions
+  PyBobLearnMiscKMeansMachine_Type.tp_new = PyType_GenericNew;
+  PyBobLearnMiscKMeansMachine_Type.tp_init = reinterpret_cast<initproc>(PyBobLearnMiscKMeansMachine_init);
+  PyBobLearnMiscKMeansMachine_Type.tp_dealloc = reinterpret_cast<destructor>(PyBobLearnMiscKMeansMachine_delete);
+  PyBobLearnMiscKMeansMachine_Type.tp_richcompare = reinterpret_cast<richcmpfunc>(PyBobLearnMiscKMeansMachine_RichCompare);
+  PyBobLearnMiscKMeansMachine_Type.tp_methods = PyBobLearnMiscKMeansMachine_methods;
+  PyBobLearnMiscKMeansMachine_Type.tp_getset = PyBobLearnMiscKMeansMachine_getseters;
+  //PyBobLearnMiscGMMMachine_Type.tp_call = reinterpret_cast<ternaryfunc>(PyBobLearnMiscGMMMachine_loglikelihood);
+
+
+  // check that everything is fine
+  if (PyType_Ready(&PyBobLearnMiscKMeansMachine_Type) < 0) return false;
+
+  // add the type to the module
+  Py_INCREF(&PyBobLearnMiscKMeansMachine_Type);
+  return PyModule_AddObject(module, "KMeansMachine", (PyObject*)&PyBobLearnMiscKMeansMachine_Type) >= 0;
+}
+
diff --git a/bob/learn/misc/main.cpp b/bob/learn/misc/main.cpp
index 64618ae..8b683b3 100644
--- a/bob/learn/misc/main.cpp
+++ b/bob/learn/misc/main.cpp
@@ -42,7 +42,8 @@ static PyObject* create_module (void) {
   if (PyModule_AddStringConstant(module, "__version__", BOB_EXT_MODULE_VERSION) < 0) return 0;
   if (!init_BobLearnMiscGaussian(module)) return 0;
   if (!init_BobLearnMiscGMMStats(module)) return 0;
-  if (!init_BobLearnMiscGMMMachine(module)) return 0;  
+  if (!init_BobLearnMiscGMMMachine(module)) return 0;
+  if (!init_BobLearnMiscKMeansMachine(module)) return 0;  
 
   static void* PyBobLearnMisc_API[PyBobLearnMisc_API_pointers];
 
diff --git a/bob/learn/misc/main.h b/bob/learn/misc/main.h
index 79830c6..2f727f3 100644
--- a/bob/learn/misc/main.h
+++ b/bob/learn/misc/main.h
@@ -20,6 +20,7 @@
 #include <bob.learn.misc/Gaussian.h>
 #include <bob.learn.misc/GMMStats.h>
 #include <bob.learn.misc/GMMMachine.h>
+#include <bob.learn.misc/KMeansMachine.h>
 
 
 #if PY_VERSION_HEX >= 0x03000000
@@ -93,5 +94,16 @@ bool init_BobLearnMiscGMMMachine(PyObject* module);
 int PyBobLearnMiscGMMMachine_Check(PyObject* o);
 
 
+// KMeansMachine
+typedef struct {
+  PyObject_HEAD
+  boost::shared_ptr<bob::learn::misc::KMeansMachine> cxx;
+} PyBobLearnMiscKMeansMachineObject;
+
+extern PyTypeObject PyBobLearnMiscKMeansMachine_Type;
+bool init_BobLearnMiscKMeansMachine(PyObject* module);
+int PyBobLearnMiscKMeansMachine_Check(PyObject* o);
+
+
 
 #endif // BOB_LEARN_EM_MAIN_H
diff --git a/bob/learn/misc/test_kmeans.py b/bob/learn/misc/test_kmeans.py
index cec6cf1..d55ed8b 100644
--- a/bob/learn/misc/test_kmeans.py
+++ b/bob/learn/misc/test_kmeans.py
@@ -27,18 +27,25 @@ def test_KMeansMachine():
   # Initializes a KMeansMachine
   km = KMeansMachine(2,3)
   km.means = means
-  assert km.dim_c == 2
-  assert km.dim_d == 3
+  assert km.shape == (2,3)
 
   # Sets and gets
   assert (km.means == means).all()
   assert (km.get_mean(0) == means[0,:]).all()
   assert (km.get_mean(1) == means[1,:]).all()
-  km.set_mean(0, mean)
-  assert (km.get_mean(0) == mean).all()
+  #km.set_mean(0, mean)
+  #assert (km.get_mean(0) == mean).all()
 
   # Distance and closest mean
   eps = 1e-10
+
+  print mean
+  print km.means
+
+  print km.get_distance_from_mean(mean, 0)
+
+
+
   assert equals( km.get_distance_from_mean(mean, 0), 0, eps)
   assert equals( km.get_distance_from_mean(mean, 1), 6, eps)
   (index, dist) = km.get_closest_mean(mean)
diff --git a/setup.py b/setup.py
index e9c4f69..e5c8797 100644
--- a/setup.py
+++ b/setup.py
@@ -59,7 +59,7 @@ setup(
           "bob/learn/misc/cpp/GMMStats.cpp",
           #"bob/learn/misc/cpp/IVectorMachine.cpp",
           #"bob/learn/misc/cpp/JFAMachine.cpp",
-          #"bob/learn/misc/cpp/KMeansMachine.cpp",
+          "bob/learn/misc/cpp/KMeansMachine.cpp",
           #"bob/learn/misc/cpp/LinearScoring.cpp",
           #"bob/learn/misc/cpp/PLDAMachine.cpp",
           #"bob/learn/misc/cpp/ZTNorm.cpp",
@@ -103,6 +103,7 @@ setup(
           "bob/learn/misc/gaussian.cpp",
           "bob/learn/misc/gmm_stats.cpp",
           "bob/learn/misc/gmm_machine.cpp",
+          "bob/learn/misc/kmeans_machine.cpp",
 
           "bob/learn/misc/main.cpp",
         ],
-- 
GitLab