diff --git a/bob/learn/em/map_gmm_trainer.cpp b/bob/learn/em/map_gmm_trainer.cpp index ed4b6ed759e866a940f2064cabcfc47e636ddc32..178daaa4e39fb4d05ef49a7b01fc9217de544b21 100644 --- a/bob/learn/em/map_gmm_trainer.cpp +++ b/bob/learn/em/map_gmm_trainer.cpp @@ -306,9 +306,10 @@ static auto initialize = bob::extension::FunctionDoc( "", true ) -.add_prototype("gmm_machine, [data]") +.add_prototype("gmm_machine, [data], [rng]") .add_parameter("gmm_machine", ":py:class:`bob.learn.em.GMMMachine`", "GMMMachine Object") -.add_parameter("data", "object", "Ignored."); +.add_parameter("data", "object", "Ignored.") +.add_parameter("rng", "object", "Ignored."); static PyObject* PyBobLearnEMMAPGMMTrainer_initialize(PyBobLearnEMMAPGMMTrainerObject* self, PyObject* args, PyObject* kwargs) { BOB_TRY @@ -317,9 +318,10 @@ static PyObject* PyBobLearnEMMAPGMMTrainer_initialize(PyBobLearnEMMAPGMMTrainerO PyBobLearnEMGMMMachineObject* gmm_machine = 0; PyObject* data; + PyObject* rng; - if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O!|O", kwlist, &PyBobLearnEMGMMMachine_Type, &gmm_machine, - &data)) return 0; + if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O!|OO", kwlist, &PyBobLearnEMGMMMachine_Type, &gmm_machine, + &data, &rng)) return 0; self->cxx->initialize(*gmm_machine->cxx); diff --git a/bob/learn/em/ml_gmm_trainer.cpp b/bob/learn/em/ml_gmm_trainer.cpp index 734fe2b9e910702519ef6129830e1a38df1e309a..78da00fc77ba395da0b0ced9e93023f48da9f254 100644 --- a/bob/learn/em/ml_gmm_trainer.cpp +++ b/bob/learn/em/ml_gmm_trainer.cpp @@ -192,9 +192,10 @@ static auto initialize = bob::extension::FunctionDoc( "", true ) -.add_prototype("gmm_machine, [data]") +.add_prototype("gmm_machine, [data], [rng]") .add_parameter("gmm_machine", ":py:class:`bob.learn.em.GMMMachine`", "GMMMachine Object") -.add_parameter("data", "object", "Ignored."); +.add_parameter("data", "object", "Ignored.") +.add_parameter("rng", "object", "Ignored."); static PyObject* PyBobLearnEMMLGMMTrainer_initialize(PyBobLearnEMMLGMMTrainerObject* self, PyObject* args, PyObject* kwargs) { BOB_TRY @@ -202,9 +203,10 @@ static PyObject* PyBobLearnEMMLGMMTrainer_initialize(PyBobLearnEMMLGMMTrainerObj char** kwlist = initialize.kwlist(0); PyBobLearnEMGMMMachineObject* gmm_machine = 0; PyObject* data; + PyObject* rng; - if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O!|O", kwlist, &PyBobLearnEMGMMMachine_Type, &gmm_machine, - &data)) return 0; + if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O!|OO", kwlist, &PyBobLearnEMGMMMachine_Type, &gmm_machine, + &data, &rng)) return 0; self->cxx->initialize(*gmm_machine->cxx); BOB_CATCH_MEMBER("cannot perform the initialize method", 0) diff --git a/bob/learn/em/train.py b/bob/learn/em/train.py index 41114cca50bc21f43085a65dee3595f7332d2dac..beb93373d7aeff03594ee81dd37fca6cf505f1b2 100644 --- a/bob/learn/em/train.py +++ b/bob/learn/em/train.py @@ -45,9 +45,7 @@ def train(trainer, machine, data, max_iterations=50, convergence_threshold=None, # Initialization if initialize: - if rng is not None and \ - (not isinstance(trainer, (bob.learn.em.ML_GMMTrainer, - bob.learn.em.MAP_GMMTrainer))): + if rng is not None: trainer.initialize(machine, data, rng) else: trainer.initialize(machine, data) diff --git a/version.txt b/version.txt index ecb8de97318459bc1b0446fa71a7ef56159b3ff9..1a78b34537ab70360822a5d0733a88ab70b02247 100644 --- a/version.txt +++ b/version.txt @@ -1 +1 @@ -2.0.14b0 \ No newline at end of file +2.1.0b0