diff --git a/bob/learn/em/ivector_trainer.cpp b/bob/learn/em/ivector_trainer.cpp index 9e630548508398b13759be0f5b214af80519468e..7524afb524d4f582a67f55c19389006ca2459506 100644 --- a/bob/learn/em/ivector_trainer.cpp +++ b/bob/learn/em/ivector_trainer.cpp @@ -406,6 +406,29 @@ static PyObject* PyBobLearnEMIVectorTrainer_m_step(PyBobLearnEMIVectorTrainerObj Py_RETURN_NONE; } +/*** reset_accumulators ***/ +static auto reset_accumulators = bob::extension::FunctionDoc( + "reset_accumulators", + "Reset the statistics accumulators to the correct size and a value of zero.", + 0, + true +) +.add_prototype("ivector_machine") +.add_parameter("ivector_machine", ":py:class:`bob.learn.em.IVectorMachine`", "The IVector machine containing the right dimensions"); +static PyObject* PyBobLearnEMIVectorTrainer_reset_accumulators(PyBobLearnEMIVectorTrainerObject* self, PyObject* args, PyObject* kwargs) { + BOB_TRY + + /* Parses input arguments in a single shot */ + char** kwlist = reset_accumulators.kwlist(0); + + PyBobLearnEMIVectorMachineObject* machine; + if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O!", kwlist, &PyBobLearnEMIVectorMachine_Type, &machine)) return 0; + + self->cxx->resetAccumulators(*machine->cxx); + Py_RETURN_NONE; + + BOB_CATCH_MEMBER("cannot perform the reset_accumulators method", 0) +} static PyMethodDef PyBobLearnEMIVectorTrainer_methods[] = { @@ -427,6 +450,12 @@ static PyMethodDef PyBobLearnEMIVectorTrainer_methods[] = { METH_VARARGS|METH_KEYWORDS, m_step.doc() }, + { + reset_accumulators.name(), + (PyCFunction)PyBobLearnEMIVectorTrainer_reset_accumulators, + METH_VARARGS|METH_KEYWORDS, + reset_accumulators.doc() + }, {0} /* Sentinel */ }; diff --git a/bob/learn/em/test/test_ivector_trainer.py b/bob/learn/em/test/test_ivector_trainer.py index 753925366d298b89792c076d00a6cc8636e1433d..44a8142effa083f52c54d23f2a90ab22da2e28f5 100644 --- a/bob/learn/em/test/test_ivector_trainer.py +++ b/bob/learn/em/test/test_ivector_trainer.py @@ -244,11 +244,11 @@ def test_trainer_nosigma(): # M-Step trainer.m_step(m) assert numpy.allclose(t_ref[it], m.t, 1e-5) - + #testing exceptions nose.tools.assert_raises(RuntimeError, trainer.e_step, m, [1,2,2]) - + def test_trainer_update_sigma(): # Ubm