Commit 94b2413a authored by Guillaume HEUSCH's avatar Guillaume HEUSCH
Browse files

[algorithms] added LDA and MLP for classification (work with numpy arrays only)

parent d79c8ab9
Pipeline #17258 failed with stage
in 19 minutes and 42 seconds
#!/usr/bin/env python
# vim: set fileencoding=utf-8 :
import numpy
from bob.pad.base.algorithm import Algorithm
import bob.learn.mlp
import bob.io.base
class MLP(Algorithm):
"""
This class interfaces an MLP classifier used for PAD
"""
def __init__(self, hidden_units=(10, 10), max_iter=1000, **kwargs):
Algorithm.__init__(self,
performs_projection=True,
requires_projector_training=True,
**kwargs)
self.hidden_units = hidden_units
self.max_iter = max_iter
self.mlp = None
def train_projector(self, training_features, projector_file):
"""
Trains the MLP
**Parameters**
training_features:
"""
# training_features[0] - training features for the REAL class.
# training_features[1] - training features for the ATTACK class.
# The data
batch_size = len(training_features[0]) + len(training_features[1])
print(batch_size)
label_real = numpy.ones((len(training_features[0]), 1), dtype='float64')
label_attack = numpy.zeros((len(training_features[1]), 1), dtype='float64')
real = numpy.array(training_features[0])
attack = numpy.array(training_features[1])
X = numpy.vstack([real, attack])
Y = numpy.vstack([label_real, label_attack])
# The machine
input_dim = real.shape[1]
shape = []
shape.append(input_dim)
for i in range(len(self.hidden_units)):
shape.append(self.hidden_units[i])
shape.append(1)
shape = tuple(shape)
self.mlp = bob.learn.mlp.Machine(shape)
self.mlp.output_activation = bob.learn.activation.Logistic()
self.mlp.randomize()
# The trainer
trainer = bob.learn.mlp.BackProp(batch_size, bob.learn.mlp.CrossEntropyLoss(self.mlp.output_activation), self.mlp, train_biases=True)
n_iter = 0
previous_cost = 0
current_cost = 1
precision = 0.001
while (n_iter < self.max_iter) or (abs(previous_cost - current_cost) < precision):
previous_cost = current_cost
trainer.train(self.mlp, X, Y)
current_cost = trainer.cost(self.mlp, X, Y)
n_iter += 1
print("Iteration {} -> cost = {} (previous = {})".format(n_iter, trainer.cost(self.mlp, X, Y), previous_cost))
f = bob.io.base.HDF5File(projector_file, 'w')
self.mlp.save(f)
def project(self, feature):
"""
Project the given feature
"""
return self.mlp(feature)
def score(self, toscore):
return [toscore[0]]
#!/usr/bin/env python
# vim: set fileencoding=utf-8 :
import numpy
from bob.bio.base.algorithm import LDA
class PadLDA(LDA):
"""
This class is a wrapper for bob.bio.base.algorithm.LDA,
to be used in a PAD context.
**Parameters**
"""
def __init__(self,
lda_subspace_dimension = None, # if set, the LDA subspace will be truncated to the given number of dimensions; by default it is limited to the number of classes in the training set
pca_subspace_dimension = None, # if set, a PCA subspace truncation is performed before applying LDA; might be integral or float
use_pinv = False,
**kwargs
):
super(PadLDA, self).__init__(
lda_subspace_dimension = lda_subspace_dimension,
pca_subspace_dimension = pca_subspace_dimension,
use_pinv = use_pinv,
**kwargs
)
def read_toscore_object(self, toscore_object_file):
"""read_toscore_object(toscore_object_file) -> toscore_object
Reads the toscore_object feature from a file.
By default, the toscore_object feature is identical to the projected feature.
Hence, this base class implementation simply calls :py:meth:`read_feature`.
If your algorithm requires different behavior, please overwrite this function.
**Parameters:**
toscore_object_file : str or :py:class:`bob.io.base.HDF5File`
The file open for reading, or the file name to read from.
**Returns:**
toscore_object : object
The toscore_object that was read from file.
"""
return self.read_feature(toscore_object_file)
def score(self, toscore):
return [toscore[0]]
......@@ -3,6 +3,8 @@ from .SVM import SVM
from .OneClassGMM import OneClassGMM
from .LogRegr import LogRegr
from .SVMCascadePCA import SVMCascadePCA
from .MLP import MLP
from .PadLDA import PadLDA
def __appropriate__(*args):
"""Says object was actually declared here, and not in the import module.
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment