From c4a7ac7181152a2ca6aec5dc7d32b51aa039cf69 Mon Sep 17 00:00:00 2001 From: Amir MOHAMMADI <amir.mohammadi@idiap.ch> Date: Fri, 9 Mar 2018 12:13:13 +0100 Subject: [PATCH] Add a predictions algorithm useful for deeplearning --- bob/pad/base/algorithm/Predictions.py | 14 +++++++++++++ bob/pad/base/algorithm/__init__.py | 6 ++++++ bob/pad/base/test/test_algorithms.py | 29 +++++++++++++++++++++------ doc/implemented.rst | 1 + 4 files changed, 44 insertions(+), 6 deletions(-) create mode 100644 bob/pad/base/algorithm/Predictions.py diff --git a/bob/pad/base/algorithm/Predictions.py b/bob/pad/base/algorithm/Predictions.py new file mode 100644 index 0000000..6eda214 --- /dev/null +++ b/bob/pad/base/algorithm/Predictions.py @@ -0,0 +1,14 @@ +from bob.pad.base.algorithm import Algorithm + + +class Predictions(Algorithm): + """An algorithm that takes the precomputed predictions and uses them for + scoring.""" + + def __init__(self, **kwargs): + super(Predictions, self).__init__( + **kwargs) + + def score(self, predictions): + # Assuming the predictions are the output of a softmax layer + return [predictions[1]] diff --git a/bob/pad/base/algorithm/__init__.py b/bob/pad/base/algorithm/__init__.py index 05d1626..be0ba4b 100644 --- a/bob/pad/base/algorithm/__init__.py +++ b/bob/pad/base/algorithm/__init__.py @@ -3,7 +3,10 @@ from .SVM import SVM from .OneClassGMM import OneClassGMM from .LogRegr import LogRegr from .SVMCascadePCA import SVMCascadePCA +from .Predictions import Predictions + +# to fix sphinx warnings of not able to find classes, when path is shortened def __appropriate__(*args): """Says object was actually declared here, and not in the import module. Fixing sphinx warnings of not being able to find classes, when path is @@ -28,5 +31,8 @@ __appropriate__( OneClassGMM, LogRegr, SVMCascadePCA, + Predictions, ) + +# gets sphinx autodoc done right - don't remove it __all__ = [_ for _ in dir() if not _.startswith('_')] diff --git a/bob/pad/base/test/test_algorithms.py b/bob/pad/base/test/test_algorithms.py index 555c185..830c0e6 100644 --- a/bob/pad/base/test/test_algorithms.py +++ b/bob/pad/base/test/test_algorithms.py @@ -6,9 +6,6 @@ from __future__ import print_function import numpy as np -from bob.io.base.test_utils import datafile -from bob.io.base import load - import bob.io.image # for image loading functionality import bob.bio.video import bob.pad.base @@ -18,8 +15,28 @@ from bob.pad.base.algorithm import OneClassGMM import random -from bob.pad.base.utils import convert_array_to_list_of_frame_cont, convert_list_of_frame_cont_to_array, \ +from bob.pad.base.utils import ( + convert_array_to_list_of_frame_cont, + convert_list_of_frame_cont_to_array, convert_frame_cont_to_array +) + +from bob.pad.base.database import PadFile +from bob.pad.base.algorithm import Predictions +from bob.pad.base import padfile_to_label + + +def test_prediction(): + alg = Predictions() + sample = [0, 1] + assert alg.score(sample)[0] == sample[1] + + +def test_padfile_to_label(): + f = PadFile(client_id='', path='', attack_type=None, file_id=1) + assert padfile_to_label(f) is True, padfile_to_label(f) + f = PadFile(client_id='', path='', attack_type='print', file_id=1) + assert padfile_to_label(f) is False, padfile_to_label(f) def test_video_svm_pad_algorithm(): @@ -144,8 +161,9 @@ def test_video_gmm_pad_algorithm(): assert (np.min(scores_attack) + 38.831260843070098) < 0.000001 assert (np.max(scores_attack) + 5.3633030621521272) < 0.000001 + def test_convert_list_of_frame_cont_to_array(): - + N = 1000 mu = 1 sigma = 1 @@ -155,4 +173,3 @@ def test_convert_list_of_frame_cont_to_array(): assert isinstance(features_array[0], np.ndarray) features_fm = convert_array_to_list_of_frame_cont(real_array) assert isinstance(features_fm[0], bob.bio.video.FrameContainer) - diff --git a/doc/implemented.rst b/doc/implemented.rst index d4d73e6..ea0874b 100644 --- a/doc/implemented.rst +++ b/doc/implemented.rst @@ -17,6 +17,7 @@ Only one base class that is presentation attack detection specific, ``Algorithm` .. autosummary:: bob.pad.base.algorithm.Algorithm + bob.pad.base.algorithm.Predictions Implementations ~~~~~~~~~~~~~~~ -- GitLab