Skip to content
Snippets Groups Projects
Commit e85eb10b authored by Guillaume HEUSCH's avatar Guillaume HEUSCH
Browse files

[algorithm] added my simple implementation of One Class SVM

parent c9b2ecc3
No related branches found
No related tags found
1 merge request!33WIP: added LDA and MLP
Pipeline #
#!/usr/bin/env python
# vim: set fileencoding=utf-8 :
import numpy
from bob.pad.base.algorithm import Algorithm
import bob.learn.libsvm
import bob.io.base
class OCSVM(Algorithm):
"""
This class interfaces a One Class SVM classifier used for PAD
"""
def __init__(self, rescale=True, nu=0.01, gamma=0.1, **kwargs):
Algorithm.__init__(self,
performs_projection=True,
requires_projector_training=True,
**kwargs)
self.rescale = rescale
self.nu = nu
self.gamma = gamma
self.machine = None
self.trainer = bob.learn.libsvm.Trainer(machine_type='ONE_CLASS', kernel_type='RBF', probability=True)
def train_projector(self, training_features, projector_file):
"""
Trains the One Class SVM
**Parameters**
training_features:
"""
# training_features[0] - training features for the REAL class.
# training_features[1] - training features for the ATTACK class.
# The data - "positive class only"
pos = numpy.array(training_features[0])
if self.rescale:
for i in range(pos.shape[0]):
min_value = numpy.min(pos[i])
max_value = numpy.max(pos[i])
pos[i] = ((2 * (pos[i] - min_value))/ (max_value - min_value)) - 1
pos = [pos]
# train
self.machine = self.trainer.train(pos)
print(self.machine.shape)
f = bob.io.base.HDF5File(projector_file, 'w')
self.machine.save(f)
def project(self, feature):
"""
Project the given feature
"""
if self.rescale:
min_value = numpy.min(feature)
max_value = numpy.max(feature)
feature = ((2 * (feature - min_value))/ (max_value - min_value)) - 1
return self.machine(feature)
def score(self, toscore):
return [toscore[0]]
...@@ -5,6 +5,7 @@ from .LogRegr import LogRegr ...@@ -5,6 +5,7 @@ from .LogRegr import LogRegr
from .SVMCascadePCA import SVMCascadePCA from .SVMCascadePCA import SVMCascadePCA
from .MLP import MLP from .MLP import MLP
from .PadLDA import PadLDA from .PadLDA import PadLDA
from .OCSVM import OCSVM
def __appropriate__(*args): def __appropriate__(*args):
"""Says object was actually declared here, and not in the import module. """Says object was actually declared here, and not in the import module.
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment