From c4a7ac7181152a2ca6aec5dc7d32b51aa039cf69 Mon Sep 17 00:00:00 2001
From: Amir MOHAMMADI <amir.mohammadi@idiap.ch>
Date: Fri, 9 Mar 2018 12:13:13 +0100
Subject: [PATCH] Add a predictions algorithm useful for deeplearning

---
 bob/pad/base/algorithm/Predictions.py | 14 +++++++++++++
 bob/pad/base/algorithm/__init__.py    |  6 ++++++
 bob/pad/base/test/test_algorithms.py  | 29 +++++++++++++++++++++------
 doc/implemented.rst                   |  1 +
 4 files changed, 44 insertions(+), 6 deletions(-)
 create mode 100644 bob/pad/base/algorithm/Predictions.py

diff --git a/bob/pad/base/algorithm/Predictions.py b/bob/pad/base/algorithm/Predictions.py
new file mode 100644
index 0000000..6eda214
--- /dev/null
+++ b/bob/pad/base/algorithm/Predictions.py
@@ -0,0 +1,14 @@
+from bob.pad.base.algorithm import Algorithm
+
+
+class Predictions(Algorithm):
+    """An algorithm that takes the precomputed predictions and uses them for
+    scoring."""
+
+    def __init__(self, **kwargs):
+        super(Predictions, self).__init__(
+            **kwargs)
+
+    def score(self, predictions):
+        # Assuming the predictions are the output of a softmax layer
+        return [predictions[1]]
diff --git a/bob/pad/base/algorithm/__init__.py b/bob/pad/base/algorithm/__init__.py
index 05d1626..be0ba4b 100644
--- a/bob/pad/base/algorithm/__init__.py
+++ b/bob/pad/base/algorithm/__init__.py
@@ -3,7 +3,10 @@ from .SVM import SVM
 from .OneClassGMM import OneClassGMM
 from .LogRegr import LogRegr
 from .SVMCascadePCA import SVMCascadePCA
+from .Predictions import Predictions
 
+
+# to fix sphinx warnings of not able to find classes, when path is shortened
 def __appropriate__(*args):
     """Says object was actually declared here, and not in the import module.
     Fixing sphinx warnings of not being able to find classes, when path is
@@ -28,5 +31,8 @@ __appropriate__(
     OneClassGMM,
     LogRegr,
     SVMCascadePCA,
+    Predictions,
 )
+
+# gets sphinx autodoc done right - don't remove it
 __all__ = [_ for _ in dir() if not _.startswith('_')]
diff --git a/bob/pad/base/test/test_algorithms.py b/bob/pad/base/test/test_algorithms.py
index 555c185..830c0e6 100644
--- a/bob/pad/base/test/test_algorithms.py
+++ b/bob/pad/base/test/test_algorithms.py
@@ -6,9 +6,6 @@ from __future__ import print_function
 
 import numpy as np
 
-from bob.io.base.test_utils import datafile
-from bob.io.base import load
-
 import bob.io.image  # for image loading functionality
 import bob.bio.video
 import bob.pad.base
@@ -18,8 +15,28 @@ from bob.pad.base.algorithm import OneClassGMM
 
 import random
 
-from bob.pad.base.utils import convert_array_to_list_of_frame_cont, convert_list_of_frame_cont_to_array, \
+from bob.pad.base.utils import (
+    convert_array_to_list_of_frame_cont,
+    convert_list_of_frame_cont_to_array,
     convert_frame_cont_to_array
+)
+
+from bob.pad.base.database import PadFile
+from bob.pad.base.algorithm import Predictions
+from bob.pad.base import padfile_to_label
+
+
+def test_prediction():
+    alg = Predictions()
+    sample = [0, 1]
+    assert alg.score(sample)[0] == sample[1]
+
+
+def test_padfile_to_label():
+    f = PadFile(client_id='', path='', attack_type=None, file_id=1)
+    assert padfile_to_label(f) is True, padfile_to_label(f)
+    f = PadFile(client_id='', path='', attack_type='print', file_id=1)
+    assert padfile_to_label(f) is False, padfile_to_label(f)
 
 
 def test_video_svm_pad_algorithm():
@@ -144,8 +161,9 @@ def test_video_gmm_pad_algorithm():
     assert (np.min(scores_attack) + 38.831260843070098) < 0.000001
     assert (np.max(scores_attack) + 5.3633030621521272) < 0.000001
 
+
 def test_convert_list_of_frame_cont_to_array():
-  
+
   N = 1000
   mu = 1
   sigma = 1
@@ -155,4 +173,3 @@ def test_convert_list_of_frame_cont_to_array():
   assert isinstance(features_array[0], np.ndarray)
   features_fm = convert_array_to_list_of_frame_cont(real_array)
   assert isinstance(features_fm[0], bob.bio.video.FrameContainer)
-
diff --git a/doc/implemented.rst b/doc/implemented.rst
index d4d73e6..ea0874b 100644
--- a/doc/implemented.rst
+++ b/doc/implemented.rst
@@ -17,6 +17,7 @@ Only one base class that is presentation attack detection specific, ``Algorithm`
 
 .. autosummary::
    bob.pad.base.algorithm.Algorithm
+   bob.pad.base.algorithm.Predictions
 
 Implementations
 ~~~~~~~~~~~~~~~
-- 
GitLab