From cad231fbc79688c0309f96b9f5b2d3d9b9beeadd Mon Sep 17 00:00:00 2001
From: Guillaume HEUSCH <guillaume.heusch@idiap.ch>
Date: Fri, 23 Feb 2018 17:09:02 +0100
Subject: [PATCH] [extractor] added the extraction of the freq spectrum for
 pulse signal

---
 bob/pad/face/extractor/FrequencySpectrum.py | 77 +++++++++++++++++++++
 bob/pad/face/extractor/__init__.py          |  2 +
 2 files changed, 79 insertions(+)
 create mode 100644 bob/pad/face/extractor/FrequencySpectrum.py

diff --git a/bob/pad/face/extractor/FrequencySpectrum.py b/bob/pad/face/extractor/FrequencySpectrum.py
new file mode 100644
index 00000000..d22dcde0
--- /dev/null
+++ b/bob/pad/face/extractor/FrequencySpectrum.py
@@ -0,0 +1,77 @@
+#!/usr/bin/env python
+# encoding: utf-8
+
+import numpy
+
+from bob.bio.base.extractor import Extractor
+
+import logging
+logger = logging.getLogger("bob.pad.face")
+
+from scipy.signal import welch
+
+
+class FrequencySpectrum(Extractor, object):
+  """
+  Compute the Frequency Spectrum of the given signal.
+
+  The computation is made using Welch's algorithm.
+
+  **Parameters:**
+
+  framerate: int
+    The sampling frequency of the signal (i.e the framerate ...) 
+
+  nsegments: int
+    Number of overlapping segments in Welch procedure
+
+  nfft: int
+    Number of points to compute the FFT
+
+  debug: boolean
+    Plot stuff
+  """
+  def __init__(self, framerate=25, nsegments=12, nfft=256, debug=False, **kwargs):
+
+    super(FrequencySpectrum, self).__init__()
+    
+    self.framerate = framerate
+    self.nsegments = nsegments
+    self.nfft = nfft
+    self.debug = debug
+
+  def __call__(self, signal):
+    """
+    Compute the frequency spectrum for the given signal.
+
+    **Parameters:**
+
+    signal: numpy.array 
+      The signal
+
+    **Returns:**
+
+      freq: numpy.array 
+       the frequency spectrum 
+    """
+    output_dim = int((self.nfft / 2) + 1)
+   
+    # we have a single pulse signal
+    if signal.ndim == 1:
+      f, psd = welch(signal, self.framerate, nperseg=self.nsegments, nfft=self.nfft)
+
+    # we have 3 pulse signal (Li's preprocessing)
+    # in this case, return the signal corresponding to the green channel
+    if signal.ndim == 2 and (signal.shape[1] == 3):
+      psds = numpy.zeros((3, output_dim))
+      for i in range(3):
+        f, psds[i] = welch(signal[:, i], self.framerate, nperseg=self.nsegments, nfft=self.nfft)
+      psd = psds[1]
+      
+    if self.debug: 
+      from matplotlib import pyplot
+      pyplot.semilogy(f, psd, 'k')
+      pyplot.title('Power spectrum of the signal')
+      pyplot.show()
+
+    return psd
diff --git a/bob/pad/face/extractor/__init__.py b/bob/pad/face/extractor/__init__.py
index 310989b4..bd3f3bb2 100644
--- a/bob/pad/face/extractor/__init__.py
+++ b/bob/pad/face/extractor/__init__.py
@@ -5,6 +5,8 @@ from .VideoDataLoader import VideoDataLoader
 from .VideoQualityMeasure import VideoQualityMeasure
 from .FrameDiffFeatures import FrameDiffFeatures
 
+from .FrequencySpectrum import FrequencySpectrum
+
 def __appropriate__(*args):
     """Says object was actually declared here, and not in the import module.
     Fixing sphinx warnings of not being able to find classes, when path is
-- 
GitLab