#!/usr/bin/env python2
# -*- coding: utf-8 -*-
"""
@author: Anjith George
"""

# ==============================================================================

from .ScikitClassifier import ScikitClassifier

from sklearn.mixture import GaussianMixture

from sklearn.preprocessing import StandardScaler


class OneClassGMM(ScikitClassifier):
    """
    This class is designed to train a OneClassGMM based PAD system. The OneClassGMM is trained
    using data of one class (real class) only. The procedure is the following:

    1. First, the training data is mean-std normalized using mean and std of the
       real class only.

    2. Second, the OneClassGMM with ``n_components`` Gaussians is trained using samples
       of the real class.

    3. The input features are next classified using pre-trained OneClassGMM machine.

    **Parameters:**

    ``n_components`` : :py:class:`int`
        Number of Gaussians in the OneClassGMM. Default: 1 .

    ``random_state`` : :py:class:`int`
        A seed for the random number generator used in the initialization of
        the OneClassGMM. Default: 3 .

    ``frame_level_scores_flag`` : :py:class:`bool`
        Return scores for each frame individually if True. Otherwise, return a
        single score per video. Default: False.

    ``covariance_type`` : :py:class:`str`
        Covariance parametrisation of the mixture components. The value is
        forwarded unchanged to :py:class:`sklearn.mixture.GaussianMixture`;
        see its documentation for the accepted values. Default: 'full'.

    ``reg_covar`` : :py:class:`float`
        Covariance regularization term. The value is forwarded unchanged to
        :py:class:`sklearn.mixture.GaussianMixture`. Default: 1e-06.
    """

    def __init__(self,
                 n_components=1,
                 random_state=3,
                 frame_level_scores_flag=False,
                 covariance_type='full',
                 reg_covar=1e-06,
                 ):

        # The mixture model is configured purely from the constructor
        # arguments; nothing is fitted at construction time.
        gmm = GaussianMixture(n_components=n_components,
                              random_state=random_state,
                              covariance_type=covariance_type,
                              reg_covar=reg_covar)

        # ``norm_on_bonafide`` and ``one_class`` are always enabled here and
        # are not exposed to callers.
        # NOTE(review): presumably these make the base class fit the scaler
        # and the model on real-class samples only, as the class docstring
        # describes -- confirm against ScikitClassifier.
        ScikitClassifier.__init__(self,
                                  clf=gmm,
                                  scaler=StandardScaler(),
                                  frame_level_scores_flag=frame_level_scores_flag,
                                  norm_on_bonafide=True,
                                  one_class=True)