diff --git a/bob/pad/base/algorithm/OCSVM.py b/bob/pad/base/algorithm/OCSVM.py new file mode 100644 index 0000000000000000000000000000000000000000..a01a8584b065b27fdf24e7963c8575d095a1c2cc --- /dev/null +++ b/bob/pad/base/algorithm/OCSVM.py @@ -0,0 +1,73 @@ +#!/usr/bin/env python +# vim: set fileencoding=utf-8 : + +import numpy +from bob.pad.base.algorithm import Algorithm + +import bob.learn.libsvm +import bob.io.base + +class OCSVM(Algorithm): + """ + This class interfaces a One Class SVM classifier used for PAD + + """ + + def __init__(self, rescale=True, nu=0.01, gamma=0.1, **kwargs): + + Algorithm.__init__(self, + performs_projection=True, + requires_projector_training=True, + **kwargs) + + self.rescale = rescale + self.nu = nu + self.gamma = gamma + self.machine = None + self.trainer = bob.learn.libsvm.Trainer(machine_type='ONE_CLASS', kernel_type='RBF', probability=True) + + + def train_projector(self, training_features, projector_file): + """ + Trains the One Class SVM + + **Parameters** + + training_features: + """ + # training_features[0] - training features for the REAL class. + # training_features[1] - training features for the ATTACK class. + + # The data - "positive class only" + pos = numpy.array(training_features[0]) + + if self.rescale: + for i in range(pos.shape[0]): + min_value = numpy.min(pos[i]) + max_value = numpy.max(pos[i]) + pos[i] = ((2 * (pos[i] - min_value))/ (max_value - min_value)) - 1 + + pos = [pos] + + # train + self.machine = self.trainer.train(pos) + + print(self.machine.shape) + + f = bob.io.base.HDF5File(projector_file, 'w') + self.machine.save(f) + + + def project(self, feature): + """ + Project the given feature + """ + if self.rescale: + min_value = numpy.min(feature) + max_value = numpy.max(feature) + feature = ((2 * (feature - min_value))/ (max_value - min_value)) - 1 + + return self.machine(feature) + + def score(self, toscore): + return [toscore[0]] diff --git a/bob/pad/base/algorithm/__init__.py b/bob/pad/base/algorithm/__init__.py index b9da7c2b19eceefa9ecceb2db583e2562f412ed1..86c374c4a3185466b8ec4ec55e207e782ed0b4ea 100644 --- a/bob/pad/base/algorithm/__init__.py +++ b/bob/pad/base/algorithm/__init__.py @@ -5,6 +5,7 @@ from .LogRegr import LogRegr from .SVMCascadePCA import SVMCascadePCA from .MLP import MLP from .PadLDA import PadLDA +from .OCSVM import OCSVM def __appropriate__(*args): """Says object was actually declared here, and not in the import module.