Commit af33fa09 authored by Tiago de Freitas Pereira's avatar Tiago de Freitas Pereira

Merge pull request #5 from acostapazo/master

Added ONE_CLASS support
parents 3c1d4084 0b720f7f
......@@ -111,15 +111,29 @@ static boost::shared_ptr<svm_problem> data2problem
std::ptr_fun(delete_problem));
//choose labels.
if ((data.size() <= 1) | (data.size() > 16)) {
boost::format m("Only supports SVMs for binary or multi-class classification problems (up to 16 classes). You passed me a list of %d arraysets.");
m % data.size();
throw std::runtime_error(m.str());
if(param.svm_type==ONE_CLASS)
{
if ((data.size() != 1)) {
boost::format m("Only support a singular entry for one class. Your are training ONE_CLASS svm classifier. You passed me a list of %d arraysets.");
m % data.size();
throw std::runtime_error(m.str());
}
}
else {
if ((data.size() <= 1) | (data.size() > 16)) {
boost::format m("Only supports SVMs for binary or multi-class classification problems (up to 16 classes). You passed me a list of %d arraysets.");
m % data.size();
throw std::runtime_error(m.str());
}
}
std::vector<double> labels;
labels.reserve(data.size());
if (data.size() == 2) {
if (data.size() == 1) {
//oc-svm only support one class.
labels.push_back(+1.);
}
else if (data.size() == 2) {
//keep libsvm ordering
labels.push_back(+1.);
labels.push_back(-1.);
......@@ -266,9 +280,11 @@ bob::learn::libsvm::Machine* bob::learn::libsvm::Trainer::train
/**
 * Trains a machine without input normalization.
 *
 * Delegates to the three-argument overload using identity scaling:
 * a zero offset and a unit divisor for every feature dimension.
 */
bob::learn::libsvm::Machine* bob::learn::libsvm::Trainer::train
(const std::vector<blitz::Array<double,2> >& data) const {
  // Size the normalization vectors to the feature dimension of the input.
  const int feature_count = data[0].extent(blitz::secondDim);
  blitz::Array<double,1> offset(feature_count);
  offset = 0.;
  blitz::Array<double,1> scale(feature_count);
  scale = 1.;
  return train(data, offset, scale);
}
svm_type one_class
kernel_type rbf
gamma 0.0769231
nr_class 2
total_sv 63
rho 24.8884
SV
1 1:0.708333 2:1 3:1 4:-0.320755 5:-0.105023 6:-1 7:1 8:-0.419847 9:-1 10:-0.225806 12:1 13:-1
1 1:0.166667 2:1 3:-0.333333 4:-0.433962 5:-0.383562 6:-1 7:-1 8:0.0687023 9:-1 10:-0.903226 11:-1 12:-1 13:1
1 1:0.416667 2:-1 3:1 4:0.0566038 5:0.283105 6:-1 7:1 8:0.267176 9:-1 10:0.290323 12:1 13:1
1 1:0.333333 2:1 3:-1 4:-0.245283 5:-0.506849 6:-1 7:-1 8:0.129771 9:-1 10:-0.16129 12:0.333333 13:-1
1 2:1 3:1 4:-0.132075 5:-0.648402 6:1 7:1 8:0.282443 9:1 11:1 12:-1 13:1
1 1:0.25 2:1 3:1 4:0.433962 5:-0.086758 6:-1 7:1 8:0.0534351 9:1 10:0.0967742 11:1 12:-1 13:1
1 1:-0.208333 2:1 3:1 4:-0.320755 5:-0.406393 6:1 7:1 8:0.206107 9:1 10:-1 11:-1 12:0.333333 13:1
1 1:0.25 2:1 3:-1 4:0.245283 5:-0.328767 6:-1 7:1 8:-0.175573 9:-1 10:-1 11:-1 12:-1 13:-1
1 1:-0.541667 2:1 3:1 4:0.0943396 5:-0.557078 6:-1 7:-1 8:0.679389 9:-1 10:-1 11:-1 12:-1 13:1
0.9402393203208671 1:0.25 2:1 3:0.333333 4:-0.396226 5:-0.579909 6:1 7:-1 8:-0.0381679 9:-1 10:-0.290323 12:-0.333333 13:0.5
1 1:-0.166667 2:1 3:0.333333 4:-0.54717 5:-0.894977 6:-1 7:1 8:-0.160305 9:-1 10:-0.741935 11:-1 12:1 13:-1
1 1:-0.375 2:1 3:1 4:-0.698113 5:-0.675799 6:-1 7:1 8:0.618321 9:-1 10:-1 11:-1 12:-0.333333 13:-1
1 1:0.541667 2:1 3:-0.333333 4:0.245283 5:-0.452055 6:-1 7:-1 8:-0.251908 9:1 10:-1 12:1 13:0.5
1 1:0.5 2:-1 3:1 4:0.0566038 5:-0.547945 6:-1 7:1 8:-0.343511 9:-1 10:-0.677419 12:1 13:1
1 1:0.25 2:-1 3:1 4:0.509434 5:-0.438356 6:-1 7:-1 8:0.0992366 9:1 10:-1 12:-1 13:-1
1 1:-0.0833333 2:-1 3:1 4:-0.320755 5:-0.182648 6:-1 7:-1 8:0.0839695 9:1 10:-0.612903 12:-1 13:1
1 1:0.208333 2:-1 3:-0.333333 4:-0.207547 5:-0.118721 6:1 7:1 8:0.236641 9:-1 10:-1 11:-1 12:0.333333 13:-1
1 1:-0.25 2:1 3:0.333333 4:-0.735849 5:-0.465753 6:-1 7:-1 8:0.236641 9:-1 10:-1 11:-1 12:-1 13:-1
1 1:-0.333333 2:1 3:1 4:-0.0943396 5:-0.164384 6:-1 7:1 8:0.160305 9:1 10:-1 12:1 13:1
0.5282590018257561 1:-0.75 2:1 3:1 4:-0.509434 5:-0.671233 6:-1 7:-1 8:-0.0992366 9:1 10:-0.483871 12:-1 13:1
1 1:0.333333 2:-1 3:1 4:-0.320755 5:-0.0684932 6:-1 7:1 8:0.496183 9:-1 10:-1 11:-1 12:-1 13:-1
1 1:0.583333 2:1 3:1 4:-0.509434 5:-0.493151 6:-1 7:-1 8:-1 9:-1 10:-0.677419 12:-1 13:-1
1 1:0.166667 2:1 3:1 4:0.339623 5:-0.255708 6:1 7:1 8:-0.19084 9:-1 10:-0.677419 12:1 13:1
0.1642833781063557 1:0.291667 2:-1 3:1 4:0.0566038 5:-0.39726 6:-1 7:1 8:0.312977 9:-1 10:-0.16129 12:0.333333 13:1
1 1:0.0833333 2:-1 3:1 4:0.622642 5:-0.0821918 6:-1 8:-0.29771 9:1 10:0.0967742 12:-1 13:-1
1 1:0.291667 2:-1 3:1 4:0.207547 5:-0.182648 6:-1 7:1 8:0.374046 9:-1 10:-1 11:-1 12:-1 13:-1
1 1:0.125 2:-1 3:1 4:1 5:-0.260274 6:1 7:1 8:-0.0534351 9:1 10:0.290323 11:1 12:0.333333 13:1
1 1:0.125 2:1 3:1 4:-0.320755 5:-0.283105 6:1 7:1 8:-0.51145 9:1 10:-0.483871 11:1 12:-1 13:1
1 1:-0.166667 2:1 3:0.333333 4:-0.509434 5:-0.716895 6:-1 7:-1 8:0.0381679 9:-1 10:-0.354839 12:1 13:1
1 1:0.291667 2:1 3:1 4:-0.566038 5:-0.525114 6:1 7:-1 8:0.358779 9:1 10:-0.548387 11:-1 12:0.333333 13:1
1 1:0.416667 2:-1 3:1 4:-0.735849 5:-0.347032 6:-1 7:-1 8:0.496183 9:1 10:-0.419355 12:0.333333 13:-1
0.7969957449775902 1:0.541667 2:1 3:1 4:-0.660377 5:-0.607306 6:-1 7:1 8:-0.0687023 9:1 10:-0.967742 11:-1 12:-0.333333 13:-1
1 1:0.458333 2:1 3:1 4:-0.509434 5:-0.452055 6:-1 7:1 8:-0.618321 9:1 10:-0.290323 11:1 12:-0.333333 13:-1
1 1:0.125 2:1 3:1 4:-0.415094 5:-0.438356 6:1 7:1 8:0.114504 9:1 10:-0.612903 12:-0.333333 13:-1
1 1:0.0416667 2:1 3:-0.333333 4:0.849057 5:-0.283105 6:-1 7:1 8:0.89313 9:-1 10:-1 11:-1 12:-0.333333 13:1
0.7413132872854584 1:-0.0416667 2:1 3:1 4:-0.660377 5:-0.525114 6:-1 7:-1 8:0.358779 9:-1 10:-1 11:-1 12:-0.333333 13:-1
1 1:0.0833333 2:1 3:1 4:-0.132075 5:-0.584475 6:-1 7:-1 8:-0.389313 9:1 10:0.806452 11:1 12:-1 13:1
1 1:0.541667 2:-1 3:1 4:0.584906 5:-0.534247 6:1 7:-1 8:0.435115 9:1 10:-0.677419 12:0.333333 13:1
1 1:-0.625 2:1 3:-1 4:-0.509434 5:-0.520548 6:-1 7:-1 8:0.694656 9:1 10:0.225806 12:-1 13:1
1 1:0.375 2:-1 3:1 4:0.0566038 5:-0.461187 6:-1 7:-1 8:0.267176 9:1 10:-0.548387 12:-1 13:-1
1 1:0.5 2:1 3:-1 4:-0.169811 5:-0.287671 6:1 7:1 8:0.572519 9:-1 10:-0.548387 12:-0.333333 13:-1
1 1:0.375 2:-1 3:1 4:-0.169811 5:-0.232877 6:1 7:-1 8:-0.465649 9:-1 10:-0.387097 12:1 13:-1
1 1:-0.0833333 2:1 3:1 4:-0.132075 5:-0.214612 6:-1 7:-1 8:-0.221374 9:1 10:0.354839 12:1 13:1
1 1:-0.291667 2:1 3:0.333333 4:0.0566038 5:-0.520548 6:-1 7:-1 8:0.160305 9:-1 10:0.16129 12:-1 13:-1
1 1:0.583333 2:1 3:1 4:-0.415094 5:-0.415525 6:1 7:-1 8:0.40458 9:-1 10:-0.935484 12:0.333333 13:1
1 1:0.125 2:-1 3:1 4:-0.245283 5:0.292237 6:-1 7:1 8:0.206107 9:1 10:-0.387097 12:0.333333 13:1
1 1:-0.5 2:1 3:1 4:-0.698113 5:-0.789954 6:-1 7:1 8:0.328244 9:-1 10:-1 11:-1 12:-1 13:1
1 1:0.708333 2:1 3:1 4:-0.0377358 5:-0.780822 6:-1 7:-1 8:-0.175573 9:1 10:-0.16129 11:1 12:-1 13:1
1 1:-0.75 2:1 3:1 4:-0.396226 5:-0.287671 6:-1 7:1 8:0.29771 9:1 10:-1 11:-1 12:-1 13:1
1 1:1 2:1 3:1 4:-0.415094 5:-0.187215 6:-1 7:1 8:0.389313 9:1 10:-1 11:-1 12:1 13:-1
0.3930955814189303 1:-0.0833333 2:1 3:1 4:-0.132075 5:-0.210046 6:-1 7:-1 8:0.557252 9:1 10:-0.483871 11:-1 12:-1 13:1
1 1:0.25 2:1 3:-1 4:0.433962 5:-0.260274 6:-1 7:1 8:0.343511 9:-1 10:-0.935484 12:-1 13:1
1 1:0.416667 2:1 3:1 4:-0.320755 5:-0.0684932 6:1 7:1 8:-0.0687023 9:1 10:-0.419355 11:-1 12:1 13:1
1 1:0.375 2:-1 3:0.333333 4:-0.320755 5:-0.374429 6:-1 7:-1 8:-0.603053 9:-1 10:-0.612903 12:-0.333333 13:1
1 1:-0.416667 2:-1 3:1 4:-0.283019 5:-0.0182648 6:1 7:1 8:-0.00763359 9:1 10:-0.0322581 12:-1 13:1
1 1:0.333333 2:-1 3:1 4:-0.0377358 5:-0.173516 6:-1 7:1 8:0.145038 9:1 10:-0.677419 12:-1 13:1
1 1:0.375 2:-1 3:1 4:0.245283 5:-0.826484 6:-1 7:1 8:0.129771 9:-1 10:1 11:1 12:1 13:1
1 1:0.625 2:1 3:0.333333 4:0.622642 5:-0.324201 6:1 7:1 8:0.206107 9:1 10:-0.483871 12:-1 13:1
1 1:0.375 2:-1 3:1 4:-0.132075 5:-0.351598 6:-1 7:1 8:0.358779 9:-1 10:0.16129 11:1 12:0.333333 13:-1
1 1:0.458333 2:1 3:0.333333 4:-0.132075 5:-0.0456621 6:-1 7:-1 8:0.328244 9:-1 10:-1 11:-1 12:-1 13:-1
0.4358136860650423 1:0.208333 2:1 3:-0.333333 4:-0.509434 5:-0.278539 6:-1 7:1 8:0.358779 9:-1 10:-0.419355 12:-1 13:-1
1 1:-0.208333 2:1 3:-0.333333 4:-0.698113 5:-0.52968 6:-1 7:-1 8:0.480916 9:-1 10:-0.677419 11:1 12:-1 13:1
1 1:0.583333 2:1 3:1 4:0.245283 5:-0.269406 6:-1 7:1 8:-0.435115 9:1 10:-0.516129 12:1 13:-1
......@@ -28,6 +28,7 @@ def tempname(suffix, prefix='bobtest_'):
return name
TEST_MACHINE_NO_PROBS = F('heart_no_probs.svmmodel')
TEST_MACHINE_ONE_CLASS = F('heart_one_class.svmmodel')
HEART_DATA = F('heart.svmdata') #13 inputs
HEART_MACHINE = F('heart.svmmodel') #supports probabilities
......@@ -176,3 +177,41 @@ def test_training_with_probability():
curr_scores = numpy.array(curr_scores)
prev_scores = numpy.array(prev_scores)
#assert numpy.all(abs(curr_scores-prev_scores) < 1e-8)
def test_training_one_class():
  # We reuse the heart data file here out of convenience. All that is needed
  # is to gather the input into a list containing a single 2D array: that one
  # array holds the data for the single class, one sample per row.
  svm_file = File(HEART_DATA)
  labels, data = svm_file.read_all()
  positives = numpy.vstack([row for idx, row in enumerate(data) if labels[idx] > 0])

  # The data is pre-scaled so features stay within [-1, +1]. libsvm,
  # apparently, suggests scaling all features this way. Our libsvm bindings
  # do not include scaling; if you want to implement that generically,
  # please do it.
  trainer = Trainer(machine_type='ONE_CLASS')
  machine = trainer.train([positives])  # ordering only affects labels
  previous = Machine(TEST_MACHINE_ONE_CLASS)

  # The freshly-trained machine must match the reference model exactly.
  nose.tools.eq_(machine.machine_type, previous.machine_type)
  nose.tools.eq_(machine.kernel_type, previous.kernel_type)
  nose.tools.eq_(machine.gamma, previous.gamma)
  nose.tools.eq_(machine.shape, previous.shape)
  assert numpy.all(abs(machine.input_subtract - previous.input_subtract) < 1e-8)
  assert numpy.all(abs(machine.input_divide - previous.input_divide) < 1e-8)

  # Predicted labels and scores must agree with the reference machine.
  assert numpy.array_equal(machine.predict_class(data), previous.predict_class(data))

  curr_labels, curr_scores = machine.predict_class_and_scores(data)
  prev_labels, prev_scores = previous.predict_class_and_scores(data)
  assert numpy.array_equal(curr_labels, prev_labels)
  assert numpy.all(abs(numpy.array(curr_scores) - numpy.array(prev_scores)) < 1e-8)
......@@ -38,7 +38,7 @@ machine_type, str\n\
\n\
* ``'C_SVC'`` (the default)\n\
* ``'NU_SVC'``\n\
* ``'ONE_CLASS'`` (**unsupported**)\n\
* ``'ONE_CLASS'`` \n\
* ``'EPSILON_SVR'`` (**unsupported** regression)\n\
* ``'NU_SVR'`` (**unsupported** regression)\n\
\n\
......@@ -634,15 +634,23 @@ static PyObject* PyBobLearnLibsvmTrainer_train
}
Xseq_.push_back(make_safe(bz)); ///< prevents data deletion
Xseq.push_back(*PyBlitzArrayCxx_AsBlitz<double,2>(bz)); ///< only a view!
Xseq.push_back(*PyBlitzArrayCxx_AsBlitz<double,2>(bz)); ///< only a view!
}
if (PyErr_Occurred()) return 0;
if (Xseq.size() < 2) {
// TODO(review): revisit these size checks; we probably need a different check when the machine type is ONE_CLASS
if ( (Xseq.size() < 2) && (self->cxx->getMachineType()!=bob::learn::libsvm::machine_t::ONE_CLASS) ) {
PyErr_Format(PyExc_RuntimeError, "`%s' requires an iterable for parameter `X' leading to, at least, two entries (representing two classes), but you have passed something that has only %" PY_FORMAT_SIZE_T "d entries", Py_TYPE(self)->tp_name, Xseq.size());
return 0;
}
if ( (Xseq.size() < 1) && (self->cxx->getMachineType()==bob::learn::libsvm::machine_t::ONE_CLASS) ) {
PyErr_Format(PyExc_RuntimeError, "`%s' requires an iterable for parameter `X' leading to, at least, one entry (representing one class), but you have passed something that has only %" PY_FORMAT_SIZE_T "d entries", Py_TYPE(self)->tp_name, Xseq.size());
return 0;
}
if (subtract && !divide) {
PyErr_Format(PyExc_RuntimeError, "`%s' requires you provide both `subtract' and `divide' or neither, but you provided only `subtract'", Py_TYPE(self)->tp_name);
......@@ -663,16 +671,19 @@ static PyObject* PyBobLearnLibsvmTrainer_train
}
/** all basic checks are done, can call the machine now **/
//std::cout << "all basic checks are done, can call the machine now..." << std::endl;
try {
bob::learn::libsvm::Machine* machine;
if (subtract && divide) {
machine = self->cxx->train(Xseq,
*PyBlitzArrayCxx_AsBlitz<double,1>(subtract),
*PyBlitzArrayCxx_AsBlitz<double,1>(divide)
);
if(self->cxx->getMachineType()==bob::learn::libsvm::machine_t::ONE_CLASS)
{
if (subtract && divide) machine = self->cxx->train(Xseq,*PyBlitzArrayCxx_AsBlitz<double,1>(subtract),*PyBlitzArrayCxx_AsBlitz<double,1>(divide));
else machine = self->cxx->train(Xseq);
}
else {
machine = self->cxx->train(Xseq);
else {
if (subtract && divide) machine = self->cxx->train(Xseq,*PyBlitzArrayCxx_AsBlitz<double,1>(subtract),*PyBlitzArrayCxx_AsBlitz<double,1>(divide));
else machine = self->cxx->train(Xseq);
}
return PyBobLearnLibsvmMachine_NewFromMachine(machine);
}
......
......@@ -54,7 +54,6 @@ bob::learn::libsvm::machine_t PyBobLearnLibsvm_CStringAsMachineType(const char*
return bob::learn::libsvm::NU_SVC;
}
else if (s_ == "ONE_CLASS") {
PyErr_Format(PyExc_NotImplementedError, "support for `%s' is not currently implemented by these bindings - choose from %s", s, available);
return bob::learn::libsvm::ONE_CLASS;
}
else if (s_ == "EPSILON_SVR") {
......
......@@ -152,6 +152,40 @@ instead, this can be set before calling the
>>> trainer.kernel_type = 'LINEAR'
One Class SVM
=============
The package also allows you to train a One-Class Support Vector Machine. To train this kind of classifier, consider the following example.
.. doctest::
:options: +NORMALIZE_WHITESPACE
>>> oc_pos = 0.4 * numpy.random.randn(100, 2).astype(numpy.float64)
>>> oc_data = [oc_pos]
>>> print(oc_data) # doctest: +SKIP
As in the above example, an SVM [1]_ for a one-class problem can be trained easily using the
:py:class:`bob.learn.libsvm.Trainer` class and selecting the appropriate machine_type (ONE_CLASS).
.. doctest::
:options: +NORMALIZE_WHITESPACE
>>> oc_trainer = bob.learn.libsvm.Trainer(machine_type='ONE_CLASS')
>>> oc_machine = oc_trainer.train(oc_data)
Then, as explained before, a :py:class:`bob.learn.libsvm.Machine` can be used to classify the new entries.
.. doctest::
:options: +NORMALIZE_WHITESPACE
>>> oc_test = 0.4 * numpy.random.randn(20, 2).astype(numpy.float64)
>>> oc_outliers = numpy.random.uniform(low=-4, high=4, size=(20, 2)).astype(numpy.float64)
>>> predicted_label_oc_test = oc_machine(oc_test)
>>> predicted_label_oc_outliers = oc_machine(oc_outliers)
>>> print(predicted_label_oc_test) # doctest: +SKIP
>>> print(predicted_label_oc_outliers) # doctest: +SKIP
Acknowledgements
----------------
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment