From fe3864d082f246bdb4470efa75143e32a138767a Mon Sep 17 00:00:00 2001
From: Tiago Freitas Pereira <tiagofrepereira@gmail.com>
Date: Fri, 26 Dec 2014 16:22:23 -0200
Subject: [PATCH] Fixed the set_variance_threshold method

---
 bob/learn/misc/gmm_machine.cpp | 18 ++++++++++++------
 bob/learn/misc/test_gmm.py     |  5 ++++-
 2 files changed, 16 insertions(+), 7 deletions(-)

diff --git a/bob/learn/misc/gmm_machine.cpp b/bob/learn/misc/gmm_machine.cpp
index 05e4ad0..356a800 100644
--- a/bob/learn/misc/gmm_machine.cpp
+++ b/bob/learn/misc/gmm_machine.cpp
@@ -645,14 +645,20 @@ static PyObject* PyBobLearnMiscGMMMachine_setVarianceThresholds_method(PyBobLear
 
   char** kwlist = set_variance_thresholds.kwlist(0);
 
-  PyBlitzArrayObject* input = 0;
+  PyBlitzArrayObject* input_array = 0;
+  double input_number = 0;
+  if(PyArg_ParseTupleAndKeywords(args, kwargs, "d", kwlist, &input_number)){
+    self->cxx->setVarianceThresholds(input_number);
+  }
+  else if(PyArg_ParseTupleAndKeywords(args, kwargs, "O&", kwlist, &PyBlitzArray_Converter,&input_array)) {
+    //protects acquired resources through this scope
+    auto input_ = make_safe(input_array);
+    self->cxx->setVarianceThresholds(*PyBlitzArrayCxx_AsBlitz<double,1>(input_array));
+  }
+  else
+    return 0;
 
- if(!PyArg_ParseTupleAndKeywords(args, kwargs, "O&", kwlist, &PyBlitzArray_Converter,&input))
-    Py_RETURN_NONE;
 
-  //protects acquired resources through this scope
-  auto input_ = make_safe(input);
-  self->cxx->setVarianceThresholds(*PyBlitzArrayCxx_AsBlitz<double,1>(input));
 
   BOB_CATCH_MEMBER("cannot accumulate set the variance threshold", 0)
   Py_RETURN_NONE;
diff --git a/bob/learn/misc/test_gmm.py b/bob/learn/misc/test_gmm.py
index c36f85b..a9e0910 100644
--- a/bob/learn/misc/test_gmm.py
+++ b/bob/learn/misc/test_gmm.py
@@ -148,10 +148,13 @@ def test_GMMMachine_1():
   gmm.set_variance_thresholds(varianceThresholds1D)
   assert (gmm.variance_thresholds[0,:] == varianceThresholds1D).all()
   assert (gmm.variance_thresholds[1,:] == varianceThresholds1D).all()
+
   gmm.set_variance_thresholds(0.005)
-  #assert (gmm.variance_thresholds == 0.005).all()
+  assert (gmm.variance_thresholds == 0.005).all()
 
   # Checks Gaussians access
+  gmm.means     = newMeans
+  gmm.variances = newVariances
   assert (gmm.get_gaussian(0).mean == newMeans[0,:]).all()
   assert (gmm.get_gaussian(1).mean == newMeans[1,:]).all()
   assert (gmm.get_gaussian(0).variance == newVariances[0,:]).all()
-- 
GitLab