Commit c4a7ac71 authored by Amir MOHAMMADI's avatar Amir MOHAMMADI

Add a predictions algorithm useful for deeplearning

parent 9d65b96a
Pipeline #17482 passed with stage
in 35 minutes and 1 second
from bob.pad.base.algorithm import Algorithm
class Predictions(Algorithm):
"""An algorithm that takes the precomputed predictions and uses them for
scoring."""
def __init__(self, **kwargs):
super(Predictions, self).__init__(
**kwargs)
def score(self, predictions):
# Assuming the predictions are the output of a softmax layer
return [predictions[1]]
......@@ -3,7 +3,10 @@ from .SVM import SVM
from .OneClassGMM import OneClassGMM
from .LogRegr import LogRegr
from .SVMCascadePCA import SVMCascadePCA
from .Predictions import Predictions
# to fix sphinx warnings of not able to find classes, when path is shortened
def __appropriate__(*args):
"""Says object was actually declared here, and not in the import module.
Fixing sphinx warnings of not being able to find classes, when path is
......@@ -28,5 +31,8 @@ __appropriate__(
OneClassGMM,
LogRegr,
SVMCascadePCA,
Predictions,
)
# gets sphinx autodoc done right - don't remove it
__all__ = [_ for _ in dir() if not _.startswith('_')]
......@@ -6,9 +6,6 @@ from __future__ import print_function
import numpy as np
from bob.io.base.test_utils import datafile
from bob.io.base import load
import bob.io.image # for image loading functionality
import bob.bio.video
import bob.pad.base
......@@ -18,8 +15,28 @@ from bob.pad.base.algorithm import OneClassGMM
import random
from bob.pad.base.utils import convert_array_to_list_of_frame_cont, convert_list_of_frame_cont_to_array, \
from bob.pad.base.utils import (
convert_array_to_list_of_frame_cont,
convert_list_of_frame_cont_to_array,
convert_frame_cont_to_array
)
from bob.pad.base.database import PadFile
from bob.pad.base.algorithm import Predictions
from bob.pad.base import padfile_to_label
def test_prediction():
alg = Predictions()
sample = [0, 1]
assert alg.score(sample)[0] == sample[1]
def test_padfile_to_label():
f = PadFile(client_id='', path='', attack_type=None, file_id=1)
assert padfile_to_label(f) is True, padfile_to_label(f)
f = PadFile(client_id='', path='', attack_type='print', file_id=1)
assert padfile_to_label(f) is False, padfile_to_label(f)
def test_video_svm_pad_algorithm():
......@@ -144,8 +161,9 @@ def test_video_gmm_pad_algorithm():
assert (np.min(scores_attack) + 38.831260843070098) < 0.000001
assert (np.max(scores_attack) + 5.3633030621521272) < 0.000001
def test_convert_list_of_frame_cont_to_array():
N = 1000
mu = 1
sigma = 1
......@@ -155,4 +173,3 @@ def test_convert_list_of_frame_cont_to_array():
assert isinstance(features_array[0], np.ndarray)
features_fm = convert_array_to_list_of_frame_cont(real_array)
assert isinstance(features_fm[0], bob.bio.video.FrameContainer)
......@@ -17,6 +17,7 @@ Only one base class that is presentation attack detection specific, ``Algorithm`
.. autosummary::
bob.pad.base.algorithm.Algorithm
bob.pad.base.algorithm.Predictions
Implementations
~~~~~~~~~~~~~~~
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment