diff --git a/bob/pad/base/algorithm/Predictions.py b/bob/pad/base/algorithm/Predictions.py new file mode 100644 index 0000000000000000000000000000000000000000..6eda214a60d7be1bd3ab69fd3300494345f9c7cd --- /dev/null +++ b/bob/pad/base/algorithm/Predictions.py @@ -0,0 +1,14 @@ +from bob.pad.base.algorithm import Algorithm + + +class Predictions(Algorithm): + """An algorithm that takes the precomputed predictions and uses them for + scoring.""" + + def __init__(self, **kwargs): + super(Predictions, self).__init__( + **kwargs) + + def score(self, predictions): + # Assuming the predictions are the output of a softmax layer + return [predictions[1]] diff --git a/bob/pad/base/algorithm/__init__.py b/bob/pad/base/algorithm/__init__.py index 05d16266a81910299064bca66b092cde61df45e9..be0ba4b1d150769484aae6d9384a01db109df650 100644 --- a/bob/pad/base/algorithm/__init__.py +++ b/bob/pad/base/algorithm/__init__.py @@ -3,7 +3,10 @@ from .SVM import SVM from .OneClassGMM import OneClassGMM from .LogRegr import LogRegr from .SVMCascadePCA import SVMCascadePCA +from .Predictions import Predictions + +# to fix sphinx warnings of not able to find classes, when path is shortened def __appropriate__(*args): """Says object was actually declared here, and not in the import module. Fixing sphinx warnings of not being able to find classes, when path is @@ -28,5 +31,8 @@ __appropriate__( OneClassGMM, LogRegr, SVMCascadePCA, + Predictions, ) + +# gets sphinx autodoc done right - don't remove it __all__ = [_ for _ in dir() if not _.startswith('_')] diff --git a/bob/pad/base/test/test_algorithms.py b/bob/pad/base/test/test_algorithms.py index 555c18550b9cf5c54e95a7d72218702301b24a7a..830c0e6e2f768a7a5f80776f02d99f82a4e9f1a9 100644 --- a/bob/pad/base/test/test_algorithms.py +++ b/bob/pad/base/test/test_algorithms.py @@ -6,9 +6,6 @@ from __future__ import print_function import numpy as np -from bob.io.base.test_utils import datafile -from bob.io.base import load - import bob.io.image # for image loading functionality import bob.bio.video import bob.pad.base @@ -18,8 +15,28 @@ from bob.pad.base.algorithm import OneClassGMM import random -from bob.pad.base.utils import convert_array_to_list_of_frame_cont, convert_list_of_frame_cont_to_array, \ +from bob.pad.base.utils import ( + convert_array_to_list_of_frame_cont, + convert_list_of_frame_cont_to_array, convert_frame_cont_to_array +) + +from bob.pad.base.database import PadFile +from bob.pad.base.algorithm import Predictions +from bob.pad.base import padfile_to_label + + +def test_prediction(): + alg = Predictions() + sample = [0, 1] + assert alg.score(sample)[0] == sample[1] + + +def test_padfile_to_label(): + f = PadFile(client_id='', path='', attack_type=None, file_id=1) + assert padfile_to_label(f) is True, padfile_to_label(f) + f = PadFile(client_id='', path='', attack_type='print', file_id=1) + assert padfile_to_label(f) is False, padfile_to_label(f) def test_video_svm_pad_algorithm(): @@ -144,8 +161,9 @@ def test_video_gmm_pad_algorithm(): assert (np.min(scores_attack) + 38.831260843070098) < 0.000001 assert (np.max(scores_attack) + 5.3633030621521272) < 0.000001 + def test_convert_list_of_frame_cont_to_array(): - + N = 1000 mu = 1 sigma = 1 @@ -155,4 +173,3 @@ def test_convert_list_of_frame_cont_to_array(): assert isinstance(features_array[0], np.ndarray) features_fm = convert_array_to_list_of_frame_cont(real_array) assert isinstance(features_fm[0], bob.bio.video.FrameContainer) - diff --git a/doc/implemented.rst b/doc/implemented.rst index d4d73e6b71912b99e7983ebe56cf0773e9f3e90c..ea0874b9662143181f43fa6aa08c1b141f29dec7 100644 --- a/doc/implemented.rst +++ b/doc/implemented.rst @@ -17,6 +17,7 @@ Only one base class that is presentation attack detection specific, ``Algorithm` .. autosummary:: bob.pad.base.algorithm.Algorithm + bob.pad.base.algorithm.Predictions Implementations ~~~~~~~~~~~~~~~