Commit 2eea542f authored by André Anjos's avatar André Anjos

Merge branch 'issue-9' into 'master'

Fix segmentation fault when SVM is trained in sequenece. Fixes issue #9

See merge request !3
parents 60d4d357 fb9c89cf
Pipeline #9853 passed with stages
in 23 minutes and 4 seconds
......@@ -214,4 +214,21 @@ def test_training_one_class():
curr_scores = numpy.array(curr_scores)
prev_scores = numpy.array(prev_scores)
assert numpy.all(abs(curr_scores - prev_scores) < 1e-8)
def test_successive_training():
# Tests successive training works: i.e., training a couple of machines one
# after the other.
numpy.random.seed(10)
for i in range(2):
pos = numpy.random.normal(0., 1, size=(100, 2))
neg = numpy.random.normal(1., 1, size=(100, 2))
data = [pos, neg]
trainer = Trainer()
trainer.kernel_type = 'LINEAR'
trainer.cost = 1
trainer.train(data)
......@@ -596,23 +596,19 @@ static PyObject* PyBobLearnLibsvmTrainer_train
&PyBlitzArray_OutputConverter, &divide
)) return 0;
// do not decref X, otherwise it will be deleted by Python
//protects acquired resources through this scope
auto X_ = make_safe(X);
auto subtract_ = make_xsafe(subtract);
auto divide_ = make_xsafe(divide);
/**
// Note: strangely, if you pass dict.values(), this check does not work
if (!PyIter_Check(X)) {
PyErr_Format(PyExc_TypeError, "`%s' requires an iterable for parameter `X', but you passed `%s' which does not implement the iterator protocol", Py_TYPE(self)->tp_name, Py_TYPE(X)->tp_name);
return 0;
}
**/
/* Checks and converts all entries */
std::vector<blitz::Array<double,2> > Xseq;
std::vector<boost::shared_ptr<PyBlitzArrayObject>> Xseq_;
/* The standard way to check if a python object is iterable is this.
* PyIter_Check will only check if the object is of class ``iterable``. This
* will not work as you expect
*/
PyObject* iterator = PyObject_GetIter(X);
if (!iterator) return 0;
auto iterator_ = make_safe(iterator);
......@@ -634,7 +630,7 @@ static PyObject* PyBobLearnLibsvmTrainer_train
}
Xseq_.push_back(make_safe(bz)); ///< prevents data deletion
Xseq.push_back(*PyBlitzArrayCxx_AsBlitz<double,2>(bz)); ///< only a view!
Xseq.push_back(*PyBlitzArrayCxx_AsBlitz<double,2>(bz)); ///< only a view!
}
if (PyErr_Occurred()) return 0;
......@@ -646,7 +642,7 @@ static PyObject* PyBobLearnLibsvmTrainer_train
PyErr_Format(PyExc_RuntimeError, "`%s' requires an iterable for parameter `X' leading to, at least, two entries (representing two classes), but you have passed something that has only %" PY_FORMAT_SIZE_T "d entries", Py_TYPE(self)->tp_name, Xseq.size());
return 0;
}
if ( (Xseq.size() < 1) && (self->cxx->getMachineType()==bob::learn::libsvm::machine_t::ONE_CLASS) ) {
PyErr_Format(PyExc_RuntimeError, "`%s' requires an iterable for parameter `X' leading to, at least, one entry (representing one class), but you have passed something that has only %" PY_FORMAT_SIZE_T "d entries", Py_TYPE(self)->tp_name, Xseq.size());
return 0;
......@@ -675,16 +671,10 @@ static PyObject* PyBobLearnLibsvmTrainer_train
//std::cout << "all basic checks are done, can call the machine now..." << std::endl;
try {
bob::learn::libsvm::Machine* machine;
if(self->cxx->getMachineType()==bob::learn::libsvm::machine_t::ONE_CLASS)
{
if (subtract && divide) machine = self->cxx->train(Xseq,*PyBlitzArrayCxx_AsBlitz<double,1>(subtract),*PyBlitzArrayCxx_AsBlitz<double,1>(divide));
else machine = self->cxx->train(Xseq);
}
else {
if (subtract && divide) machine = self->cxx->train(Xseq,*PyBlitzArrayCxx_AsBlitz<double,1>(subtract),*PyBlitzArrayCxx_AsBlitz<double,1>(divide));
else machine = self->cxx->train(Xseq);
}
if (subtract && divide) machine = self->cxx->train(Xseq,*PyBlitzArrayCxx_AsBlitz<double,1>(subtract),*PyBlitzArrayCxx_AsBlitz<double,1>(divide));
else machine = self->cxx->train(Xseq);
return PyBobLearnLibsvmMachine_NewFromMachine(machine);
}
catch (std::exception& e) {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment