From fe3864d082f246bdb4470efa75143e32a138767a Mon Sep 17 00:00:00 2001 From: Tiago Freitas Pereira <tiagofrepereira@gmail.com> Date: Fri, 26 Dec 2014 16:22:23 -0200 Subject: [PATCH] Fixed the set_variance_threshold method --- bob/learn/misc/gmm_machine.cpp | 18 ++++++++++++------ bob/learn/misc/test_gmm.py | 5 ++++- 2 files changed, 16 insertions(+), 7 deletions(-) diff --git a/bob/learn/misc/gmm_machine.cpp b/bob/learn/misc/gmm_machine.cpp index 05e4ad0..356a800 100644 --- a/bob/learn/misc/gmm_machine.cpp +++ b/bob/learn/misc/gmm_machine.cpp @@ -645,14 +645,20 @@ static PyObject* PyBobLearnMiscGMMMachine_setVarianceThresholds_method(PyBobLear char** kwlist = set_variance_thresholds.kwlist(0); - PyBlitzArrayObject* input = 0; + PyBlitzArrayObject* input_array = 0; + double input_number = 0; + if(PyArg_ParseTupleAndKeywords(args, kwargs, "d", kwlist, &input_number)){ + self->cxx->setVarianceThresholds(input_number); + } + else if(PyArg_ParseTupleAndKeywords(args, kwargs, "O&", kwlist, &PyBlitzArray_Converter,&input_array)) { + //protects acquired resources through this scope + auto input_ = make_safe(input_array); + self->cxx->setVarianceThresholds(*PyBlitzArrayCxx_AsBlitz<double,1>(input_array)); + } + else + return 0; - if(!PyArg_ParseTupleAndKeywords(args, kwargs, "O&", kwlist, &PyBlitzArray_Converter,&input)) - Py_RETURN_NONE; - //protects acquired resources through this scope - auto input_ = make_safe(input); - self->cxx->setVarianceThresholds(*PyBlitzArrayCxx_AsBlitz<double,1>(input)); BOB_CATCH_MEMBER("cannot accumulate set the variance threshold", 0) Py_RETURN_NONE; diff --git a/bob/learn/misc/test_gmm.py b/bob/learn/misc/test_gmm.py index c36f85b..a9e0910 100644 --- a/bob/learn/misc/test_gmm.py +++ b/bob/learn/misc/test_gmm.py @@ -148,10 +148,13 @@ def test_GMMMachine_1(): gmm.set_variance_thresholds(varianceThresholds1D) assert (gmm.variance_thresholds[0,:] == varianceThresholds1D).all() assert (gmm.variance_thresholds[1,:] == varianceThresholds1D).all() + gmm.set_variance_thresholds(0.005) - #assert (gmm.variance_thresholds == 0.005).all() + assert (gmm.variance_thresholds == 0.005).all() # Checks Gaussians access + gmm.means = newMeans + gmm.variances = newVariances assert (gmm.get_gaussian(0).mean == newMeans[0,:]).all() assert (gmm.get_gaussian(1).mean == newMeans[1,:]).all() assert (gmm.get_gaussian(0).variance == newVariances[0,:]).all() -- GitLab