Created and binded a method to read the number of support vectors per class....

Created and binded a method to read the number of support vectors per class. Also added unit tests for that
parent 27874dee
...@@ -194,6 +194,8 @@ size_t bob::learn::libsvm::Machine::numberOfClasses() const { ...@@ -194,6 +194,8 @@ size_t bob::learn::libsvm::Machine::numberOfClasses() const {
return svm_get_nr_class(m_model.get()); return svm_get_nr_class(m_model.get());
} }
int bob::learn::libsvm::Machine::classLabel(size_t i) const { int bob::learn::libsvm::Machine::classLabel(size_t i) const {
if (i >= (size_t)svm_get_nr_class(m_model.get())) { if (i >= (size_t)svm_get_nr_class(m_model.get())) {
...@@ -205,6 +207,19 @@ int bob::learn::libsvm::Machine::classLabel(size_t i) const { ...@@ -205,6 +207,19 @@ int bob::learn::libsvm::Machine::classLabel(size_t i) const {
} }
int bob::learn::libsvm::Machine::classNSupportVectors(size_t i) const {
if (i >= (size_t)svm_get_nr_class(m_model.get())) {
boost::format s("request data for the class %d in SVM with %d classes is not legal");
s % (int)i % svm_get_nr_class(m_model.get());
throw std::runtime_error(s.str());
}
return m_model->nSV[i];
}
bob::learn::libsvm::machine_t bob::learn::libsvm::Machine::machineType() const { bob::learn::libsvm::machine_t bob::learn::libsvm::Machine::machineType() const {
return (machine_t)svm_get_svm_type(m_model.get()); return (machine_t)svm_get_svm_type(m_model.get());
} }
......
...@@ -154,12 +154,20 @@ namespace bob { namespace learn { namespace libsvm { ...@@ -154,12 +154,20 @@ namespace bob { namespace learn { namespace libsvm {
*/ */
size_t numberOfClasses() const; size_t numberOfClasses() const;
/** /**
* Returns the class label (as stored inside the svm_model object) for a * Returns the class label (as stored inside the svm_model object) for a
* given class 'i'. * given class 'i'.
*/ */
int classLabel(size_t i) const; int classLabel(size_t i) const;
/**
* Returns the number of suport vectors for a
* given class 'i'.
*/
int classNSupportVectors(size_t i) const;
/** /**
* SVM type * SVM type
*/ */
......
...@@ -274,6 +274,22 @@ static PyObject* PyBobLearnLibsvmMachine_getLabels ...@@ -274,6 +274,22 @@ static PyObject* PyBobLearnLibsvmMachine_getLabels
return retval; return retval;
} }
PyDoc_STRVAR(s_n_support_vectors_str, "n_support_vectors");
PyDoc_STRVAR(s_n_support_vectors_doc, "Will output the number of support vectors per class");
static PyObject* PyBobLearnLibsvmMachine_getNSupportVectors
(PyBobLearnLibsvmMachineObject* self, void* /*closure*/) {
PyObject* retval = PyList_New(self->cxx->numberOfClasses());
for (size_t k=0; k<self->cxx->numberOfClasses(); ++k) {
PyList_SET_ITEM(retval, k, Py_BuildValue("i", self->cxx->classNSupportVectors(k)));
}
return retval;
}
PyDoc_STRVAR(s_machine_type_str, "machine_type"); PyDoc_STRVAR(s_machine_type_str, "machine_type");
PyDoc_STRVAR(s_machine_type_doc, "The type of SVM machine contained"); PyDoc_STRVAR(s_machine_type_doc, "The type of SVM machine contained");
...@@ -360,6 +376,15 @@ static PyGetSetDef PyBobLearnLibsvmMachine_getseters[] = { ...@@ -360,6 +376,15 @@ static PyGetSetDef PyBobLearnLibsvmMachine_getseters[] = {
s_labels_doc, s_labels_doc,
0 0
}, },
{
s_n_support_vectors_str,
(getter)PyBobLearnLibsvmMachine_getNSupportVectors,
0,
s_n_support_vectors_doc,
0
},
{ {
s_machine_type_str, s_machine_type_str,
(getter)PyBobLearnLibsvmMachine_getMachineType, (getter)PyBobLearnLibsvmMachine_getMachineType,
......
...@@ -72,6 +72,7 @@ def test_can_load(): ...@@ -72,6 +72,7 @@ def test_can_load():
machine = Machine(HEART_MACHINE) machine = Machine(HEART_MACHINE)
nose.tools.eq_(machine.shape, (13,1)) nose.tools.eq_(machine.shape, (13,1))
nose.tools.eq_(machine.n_support_vectors, [64,68])
nose.tools.eq_(machine.kernel_type, 'RBF') nose.tools.eq_(machine.kernel_type, 'RBF')
nose.tools.eq_(machine.machine_type, 'C_SVC') nose.tools.eq_(machine.machine_type, 'C_SVC')
nose.tools.eq_(len(machine.labels), 2) nose.tools.eq_(len(machine.labels), 2)
...@@ -89,6 +90,7 @@ def test_can_save(): ...@@ -89,6 +90,7 @@ def test_can_save():
# make sure that the save machine is the same as before # make sure that the save machine is the same as before
machine = Machine(tmp) machine = Machine(tmp)
nose.tools.eq_(machine.shape, (13,1)) nose.tools.eq_(machine.shape, (13,1))
nose.tools.eq_(machine.n_support_vectors, [64,68])
nose.tools.eq_(machine.kernel_type, 'RBF') nose.tools.eq_(machine.kernel_type, 'RBF')
nose.tools.eq_(machine.machine_type, 'C_SVC') nose.tools.eq_(machine.machine_type, 'C_SVC')
nose.tools.eq_(len(machine.labels), 2) nose.tools.eq_(len(machine.labels), 2)
...@@ -108,6 +110,7 @@ def test_can_save_hdf5(): ...@@ -108,6 +110,7 @@ def test_can_save_hdf5():
# make sure that the save machine is the same as before # make sure that the save machine is the same as before
machine = Machine(bob.io.base.HDF5File(tmp)) machine = Machine(bob.io.base.HDF5File(tmp))
nose.tools.eq_(machine.shape, (13,1)) nose.tools.eq_(machine.shape, (13,1))
nose.tools.eq_(machine.n_support_vectors, [64,68])
nose.tools.eq_(machine.kernel_type, 'RBF') nose.tools.eq_(machine.kernel_type, 'RBF')
nose.tools.eq_(machine.machine_type, 'C_SVC') nose.tools.eq_(machine.machine_type, 'C_SVC')
nose.tools.eq_(len(machine.labels), 2) nose.tools.eq_(len(machine.labels), 2)
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment