Skip to content
Snippets Groups Projects
Commit 15c88bed authored by Amir MOHAMMADI's avatar Amir MOHAMMADI
Browse files

Merge branch 'predictions' into 'master'

Add a predictions algorithm useful for deeplearning

See merge request !38
parents 9d65b96a c4a7ac71
No related branches found
No related tags found
1 merge request!38Add a predictions algorithm useful for deeplearning
Pipeline #
from bob.pad.base.algorithm import Algorithm
class Predictions(Algorithm):
"""An algorithm that takes the precomputed predictions and uses them for
scoring."""
def __init__(self, **kwargs):
super(Predictions, self).__init__(
**kwargs)
def score(self, predictions):
# Assuming the predictions are the output of a softmax layer
return [predictions[1]]
...@@ -3,7 +3,10 @@ from .SVM import SVM ...@@ -3,7 +3,10 @@ from .SVM import SVM
from .OneClassGMM import OneClassGMM from .OneClassGMM import OneClassGMM
from .LogRegr import LogRegr from .LogRegr import LogRegr
from .SVMCascadePCA import SVMCascadePCA from .SVMCascadePCA import SVMCascadePCA
from .Predictions import Predictions
# to fix sphinx warnings of not able to find classes, when path is shortened
def __appropriate__(*args): def __appropriate__(*args):
"""Says object was actually declared here, and not in the import module. """Says object was actually declared here, and not in the import module.
Fixing sphinx warnings of not being able to find classes, when path is Fixing sphinx warnings of not being able to find classes, when path is
...@@ -28,5 +31,8 @@ __appropriate__( ...@@ -28,5 +31,8 @@ __appropriate__(
OneClassGMM, OneClassGMM,
LogRegr, LogRegr,
SVMCascadePCA, SVMCascadePCA,
Predictions,
) )
# gets sphinx autodoc done right - don't remove it
__all__ = [_ for _ in dir() if not _.startswith('_')] __all__ = [_ for _ in dir() if not _.startswith('_')]
...@@ -6,9 +6,6 @@ from __future__ import print_function ...@@ -6,9 +6,6 @@ from __future__ import print_function
import numpy as np import numpy as np
from bob.io.base.test_utils import datafile
from bob.io.base import load
import bob.io.image # for image loading functionality import bob.io.image # for image loading functionality
import bob.bio.video import bob.bio.video
import bob.pad.base import bob.pad.base
...@@ -18,8 +15,28 @@ from bob.pad.base.algorithm import OneClassGMM ...@@ -18,8 +15,28 @@ from bob.pad.base.algorithm import OneClassGMM
import random import random
from bob.pad.base.utils import convert_array_to_list_of_frame_cont, convert_list_of_frame_cont_to_array, \ from bob.pad.base.utils import (
convert_array_to_list_of_frame_cont,
convert_list_of_frame_cont_to_array,
convert_frame_cont_to_array convert_frame_cont_to_array
)
from bob.pad.base.database import PadFile
from bob.pad.base.algorithm import Predictions
from bob.pad.base import padfile_to_label
def test_prediction():
alg = Predictions()
sample = [0, 1]
assert alg.score(sample)[0] == sample[1]
def test_padfile_to_label():
f = PadFile(client_id='', path='', attack_type=None, file_id=1)
assert padfile_to_label(f) is True, padfile_to_label(f)
f = PadFile(client_id='', path='', attack_type='print', file_id=1)
assert padfile_to_label(f) is False, padfile_to_label(f)
def test_video_svm_pad_algorithm(): def test_video_svm_pad_algorithm():
...@@ -144,8 +161,9 @@ def test_video_gmm_pad_algorithm(): ...@@ -144,8 +161,9 @@ def test_video_gmm_pad_algorithm():
assert (np.min(scores_attack) + 38.831260843070098) < 0.000001 assert (np.min(scores_attack) + 38.831260843070098) < 0.000001
assert (np.max(scores_attack) + 5.3633030621521272) < 0.000001 assert (np.max(scores_attack) + 5.3633030621521272) < 0.000001
def test_convert_list_of_frame_cont_to_array(): def test_convert_list_of_frame_cont_to_array():
N = 1000 N = 1000
mu = 1 mu = 1
sigma = 1 sigma = 1
...@@ -155,4 +173,3 @@ def test_convert_list_of_frame_cont_to_array(): ...@@ -155,4 +173,3 @@ def test_convert_list_of_frame_cont_to_array():
assert isinstance(features_array[0], np.ndarray) assert isinstance(features_array[0], np.ndarray)
features_fm = convert_array_to_list_of_frame_cont(real_array) features_fm = convert_array_to_list_of_frame_cont(real_array)
assert isinstance(features_fm[0], bob.bio.video.FrameContainer) assert isinstance(features_fm[0], bob.bio.video.FrameContainer)
...@@ -17,6 +17,7 @@ Only one base class that is presentation attack detection specific, ``Algorithm` ...@@ -17,6 +17,7 @@ Only one base class that is presentation attack detection specific, ``Algorithm`
.. autosummary:: .. autosummary::
bob.pad.base.algorithm.Algorithm bob.pad.base.algorithm.Algorithm
bob.pad.base.algorithm.Predictions
Implementations Implementations
~~~~~~~~~~~~~~~ ~~~~~~~~~~~~~~~
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment