Skip to content
Snippets Groups Projects
Commit b312e6b9 authored by Anjith GEORGE's avatar Anjith GEORGE
Browse files

Added tests for ScikitClassifier

parent f01d2bda
No related branches found
No related tags found
1 merge request!64Scikit wrapper
......@@ -37,6 +37,7 @@ __appropriate__(
SVMCascadePCA,
Predictions,
VideoPredictions,
ScikitClassifier,
MLP,
PadLDA
)
......
......@@ -14,6 +14,7 @@ from bob.pad.base.algorithm import SVM
from bob.pad.base.algorithm import OneClassGMM
from bob.pad.base.algorithm import MLP
from bob.pad.base.algorithm import PadLDA
from bob.pad.base.algorithm import ScikitClassifier
import random
......@@ -219,3 +220,40 @@ def test_LDA():
lda = PadLDA()
lda.train_projector(training_features, '/tmp/lda.hdf5')
assert lda.machine.shape == (2, 1)
def test_ScikitClassifier():
random.seed(7)
N = 20000
mu = 1
sigma = 1
real_array = np.transpose(
np.vstack([[random.gauss(mu, sigma) for _ in range(N)],
[random.gauss(mu, sigma) for _ in range(N)]]))
mu = 5
sigma = 1
attack_array = np.transpose(
np.vstack([[random.gauss(mu, sigma) for _ in range(N)],
[random.gauss(mu, sigma) for _ in range(N)]]))
training_features = [real_array, attack_array]
from sklearn.preprocessing import StandardScaler
from sklearn.mixture import GaussianMixture
_scaler = StandardScaler()
_clf = GaussianMixture(n_components=10, covariance_type='full')
sk = ScikitClassifier(clf=_clf, scaler=_scaler, frame_level_scores_flag=False, one_class=True)
sk.train_projector(training_features, '/tmp/sk.hdf5')
# Model path `/tmp/sk_skmodel.obj`
# Scaler path `/tmp/sk_scaler.obj`
assert sk.clf.n_components==10
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment