diff --git a/bob/learn/libsvm/test_trainer.py b/bob/learn/libsvm/test_trainer.py index 060326caa34d463886732932f52116e7c6a4971b..74f4b4d045671e3e448c54f6c739e0d0c8ff38dc 100644 --- a/bob/learn/libsvm/test_trainer.py +++ b/bob/learn/libsvm/test_trainer.py @@ -214,4 +214,21 @@ def test_training_one_class(): curr_scores = numpy.array(curr_scores) prev_scores = numpy.array(prev_scores) assert numpy.all(abs(curr_scores - prev_scores) < 1e-8) - + + +def test_successive_training(): + + # Tests successive training works: i.e., training a couple of machines one + # after the other. + + numpy.random.seed(10) + + for i in range(2): + pos = numpy.random.normal(0., 1, size=(100, 2)) + neg = numpy.random.normal(1., 1, size=(100, 2)) + data = [pos, neg] + + trainer = Trainer() + trainer.kernel_type = 'LINEAR' + trainer.cost = 1 + trainer.train(data) diff --git a/bob/learn/libsvm/trainer.cpp b/bob/learn/libsvm/trainer.cpp index 5c2a9e362c4c90e9e0b71fc23aab30f734994184..3998949de72ea46fe1b18480dc3f0fcbe1d9b5b0 100644 --- a/bob/learn/libsvm/trainer.cpp +++ b/bob/learn/libsvm/trainer.cpp @@ -596,23 +596,19 @@ static PyObject* PyBobLearnLibsvmTrainer_train &PyBlitzArray_OutputConverter, ÷ )) return 0; + // do not decref X, otherwise it will be deleted by Python //protects acquired resources through this scope - auto X_ = make_safe(X); auto subtract_ = make_xsafe(subtract); auto divide_ = make_xsafe(divide); - /** - // Note: strangely, if you pass dict.values(), this check does not work - if (!PyIter_Check(X)) { - PyErr_Format(PyExc_TypeError, "`%s' requires an iterable for parameter `X', but you passed `%s' which does not implement the iterator protocol", Py_TYPE(self)->tp_name, Py_TYPE(X)->tp_name); - return 0; - } - **/ - /* Checks and converts all entries */ std::vector<blitz::Array<double,2> > Xseq; std::vector<boost::shared_ptr<PyBlitzArrayObject>> Xseq_; + /* The standard way to check if a python object is iterable is this. + * PyIter_Check will only check if the object is of class ``iterable``. This + * will not work as you expect + */ PyObject* iterator = PyObject_GetIter(X); if (!iterator) return 0; auto iterator_ = make_safe(iterator); @@ -634,7 +630,7 @@ static PyObject* PyBobLearnLibsvmTrainer_train } Xseq_.push_back(make_safe(bz)); ///< prevents data deletion - Xseq.push_back(*PyBlitzArrayCxx_AsBlitz<double,2>(bz)); ///< only a view! + Xseq.push_back(*PyBlitzArrayCxx_AsBlitz<double,2>(bz)); ///< only a view! } if (PyErr_Occurred()) return 0; @@ -646,7 +642,7 @@ static PyObject* PyBobLearnLibsvmTrainer_train PyErr_Format(PyExc_RuntimeError, "`%s' requires an iterable for parameter `X' leading to, at least, two entries (representing two classes), but you have passed something that has only %" PY_FORMAT_SIZE_T "d entries", Py_TYPE(self)->tp_name, Xseq.size()); return 0; } - + if ( (Xseq.size() < 1) && (self->cxx->getMachineType()==bob::learn::libsvm::machine_t::ONE_CLASS) ) { PyErr_Format(PyExc_RuntimeError, "`%s' requires an iterable for parameter `X' leading to, at least, one entry (representing one class), but you have passed something that has only %" PY_FORMAT_SIZE_T "d entries", Py_TYPE(self)->tp_name, Xseq.size()); return 0; @@ -675,16 +671,10 @@ static PyObject* PyBobLearnLibsvmTrainer_train //std::cout << "all basic checks are done, can call the machine now..." << std::endl; try { bob::learn::libsvm::Machine* machine; - - if(self->cxx->getMachineType()==bob::learn::libsvm::machine_t::ONE_CLASS) - { - if (subtract && divide) machine = self->cxx->train(Xseq,*PyBlitzArrayCxx_AsBlitz<double,1>(subtract),*PyBlitzArrayCxx_AsBlitz<double,1>(divide)); - else machine = self->cxx->train(Xseq); - } - else { - if (subtract && divide) machine = self->cxx->train(Xseq,*PyBlitzArrayCxx_AsBlitz<double,1>(subtract),*PyBlitzArrayCxx_AsBlitz<double,1>(divide)); - else machine = self->cxx->train(Xseq); - } + + if (subtract && divide) machine = self->cxx->train(Xseq,*PyBlitzArrayCxx_AsBlitz<double,1>(subtract),*PyBlitzArrayCxx_AsBlitz<double,1>(divide)); + else machine = self->cxx->train(Xseq); + return PyBobLearnLibsvmMachine_NewFromMachine(machine); } catch (std::exception& e) {