Created and binded a method to read the number of support vectors per class....

Created and binded a method to read the number of support vectors per class. Also added unit tests for that
parent 27874dee
......@@ -194,6 +194,8 @@ size_t bob::learn::libsvm::Machine::numberOfClasses() const {
return svm_get_nr_class(m_model.get());
}
int bob::learn::libsvm::Machine::classLabel(size_t i) const {
if (i >= (size_t)svm_get_nr_class(m_model.get())) {
......@@ -205,6 +207,19 @@ int bob::learn::libsvm::Machine::classLabel(size_t i) const {
}
int bob::learn::libsvm::Machine::classNSupportVectors(size_t i) const {
if (i >= (size_t)svm_get_nr_class(m_model.get())) {
boost::format s("request data for the class %d in SVM with %d classes is not legal");
s % (int)i % svm_get_nr_class(m_model.get());
throw std::runtime_error(s.str());
}
return m_model->nSV[i];
}
bob::learn::libsvm::machine_t bob::learn::libsvm::Machine::machineType() const {
return (machine_t)svm_get_svm_type(m_model.get());
}
......
......@@ -154,12 +154,20 @@ namespace bob { namespace learn { namespace libsvm {
*/
size_t numberOfClasses() const;
/**
* Returns the class label (as stored inside the svm_model object) for a
* given class 'i'.
*/
int classLabel(size_t i) const;
/**
* Returns the number of suport vectors for a
* given class 'i'.
*/
int classNSupportVectors(size_t i) const;
/**
* SVM type
*/
......
......@@ -274,6 +274,22 @@ static PyObject* PyBobLearnLibsvmMachine_getLabels
return retval;
}
PyDoc_STRVAR(s_n_support_vectors_str, "n_support_vectors");
PyDoc_STRVAR(s_n_support_vectors_doc, "Will output the number of support vectors per class");
static PyObject* PyBobLearnLibsvmMachine_getNSupportVectors
(PyBobLearnLibsvmMachineObject* self, void* /*closure*/) {
PyObject* retval = PyList_New(self->cxx->numberOfClasses());
for (size_t k=0; k<self->cxx->numberOfClasses(); ++k) {
PyList_SET_ITEM(retval, k, Py_BuildValue("i", self->cxx->classNSupportVectors(k)));
}
return retval;
}
PyDoc_STRVAR(s_machine_type_str, "machine_type");
PyDoc_STRVAR(s_machine_type_doc, "The type of SVM machine contained");
......@@ -360,6 +376,15 @@ static PyGetSetDef PyBobLearnLibsvmMachine_getseters[] = {
s_labels_doc,
0
},
{
s_n_support_vectors_str,
(getter)PyBobLearnLibsvmMachine_getNSupportVectors,
0,
s_n_support_vectors_doc,
0
},
{
s_machine_type_str,
(getter)PyBobLearnLibsvmMachine_getMachineType,
......
......@@ -72,6 +72,7 @@ def test_can_load():
machine = Machine(HEART_MACHINE)
nose.tools.eq_(machine.shape, (13,1))
nose.tools.eq_(machine.n_support_vectors, [64,68])
nose.tools.eq_(machine.kernel_type, 'RBF')
nose.tools.eq_(machine.machine_type, 'C_SVC')
nose.tools.eq_(len(machine.labels), 2)
......@@ -89,6 +90,7 @@ def test_can_save():
# make sure that the save machine is the same as before
machine = Machine(tmp)
nose.tools.eq_(machine.shape, (13,1))
nose.tools.eq_(machine.n_support_vectors, [64,68])
nose.tools.eq_(machine.kernel_type, 'RBF')
nose.tools.eq_(machine.machine_type, 'C_SVC')
nose.tools.eq_(len(machine.labels), 2)
......@@ -108,6 +110,7 @@ def test_can_save_hdf5():
# make sure that the save machine is the same as before
machine = Machine(bob.io.base.HDF5File(tmp))
nose.tools.eq_(machine.shape, (13,1))
nose.tools.eq_(machine.n_support_vectors, [64,68])
nose.tools.eq_(machine.kernel_type, 'RBF')
nose.tools.eq_(machine.machine_type, 'C_SVC')
nose.tools.eq_(len(machine.labels), 2)
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment