diff --git a/bob/learn/misc/gmm_machine.cpp b/bob/learn/misc/gmm_machine.cpp index 05e4ad0926341f0e895dcb9bbadef09708ae5d47..356a800e0a19a596e411ca5a7e5d9973ecb659a7 100644 --- a/bob/learn/misc/gmm_machine.cpp +++ b/bob/learn/misc/gmm_machine.cpp @@ -645,14 +645,20 @@ static PyObject* PyBobLearnMiscGMMMachine_setVarianceThresholds_method(PyBobLear char** kwlist = set_variance_thresholds.kwlist(0); - PyBlitzArrayObject* input = 0; + PyBlitzArrayObject* input_array = 0; + double input_number = 0; + if(PyArg_ParseTupleAndKeywords(args, kwargs, "d", kwlist, &input_number)){ + self->cxx->setVarianceThresholds(input_number); + } + else if(PyArg_ParseTupleAndKeywords(args, kwargs, "O&", kwlist, &PyBlitzArray_Converter,&input_array)) { + //protects acquired resources through this scope + auto input_ = make_safe(input_array); + self->cxx->setVarianceThresholds(*PyBlitzArrayCxx_AsBlitz<double,1>(input_array)); + } + else + return 0; - if(!PyArg_ParseTupleAndKeywords(args, kwargs, "O&", kwlist, &PyBlitzArray_Converter,&input)) - Py_RETURN_NONE; - //protects acquired resources through this scope - auto input_ = make_safe(input); - self->cxx->setVarianceThresholds(*PyBlitzArrayCxx_AsBlitz<double,1>(input)); BOB_CATCH_MEMBER("cannot accumulate set the variance threshold", 0) Py_RETURN_NONE; diff --git a/bob/learn/misc/test_gmm.py b/bob/learn/misc/test_gmm.py index c36f85b248f5d8618348d86b5bdc530235b34600..a9e0910b9e1e6974594afa14f80acb757c5a50d8 100644 --- a/bob/learn/misc/test_gmm.py +++ b/bob/learn/misc/test_gmm.py @@ -148,10 +148,13 @@ def test_GMMMachine_1(): gmm.set_variance_thresholds(varianceThresholds1D) assert (gmm.variance_thresholds[0,:] == varianceThresholds1D).all() assert (gmm.variance_thresholds[1,:] == varianceThresholds1D).all() + gmm.set_variance_thresholds(0.005) - #assert (gmm.variance_thresholds == 0.005).all() + assert (gmm.variance_thresholds == 0.005).all() # Checks Gaussians access + gmm.means = newMeans + gmm.variances = newVariances assert (gmm.get_gaussian(0).mean == newMeans[0,:]).all() assert (gmm.get_gaussian(1).mean == newMeans[1,:]).all() assert (gmm.get_gaussian(0).variance == newVariances[0,:]).all()