Skip to content
Snippets Groups Projects
Commit f58d5fde authored by Guillaume HEUSCH
Browse files

[extractor] fixed extractors for pulse-based PAD

parent caac10bf
Branches
Tags
1 merge request: !53 WIP: rPPG as features for PAD
Pipeline #
#!/usr/bin/env python
# encoding: utf-8
import numpy
from bob.bio.base.extractor import Extractor
import logging
logger = logging.getLogger("bob.pad.face")
class FFTFeatures(Extractor, object):
    """
    Compute the frequency spectrum of the given signal.

    The computation is made using numpy's rfft routine, and the magnitude
    (absolute value) of the resulting coefficients is returned.

    **Parameters:**

    framerate: int
      The sampling frequency of the signal (i.e the framerate ...)

    nfft: int
      Number of points to compute the FFT

    concat: boolean
      When the input holds 3 pulse signals, concatenate the spectra of the
      3 channels instead of returning only the one of the green channel

    debug: boolean
      Plot stuff
    """

    def __init__(self, framerate=25, nfft=256, concat=False, debug=False, **kwargs):
        super(FFTFeatures, self).__init__(**kwargs)
        self.framerate = framerate
        self.nfft = nfft
        self.concat = concat
        self.debug = debug

    def __call__(self, signal):
        """
        Compute the frequency spectrum for the given signal.

        **Parameters:**

        signal: numpy.array
          The signal: either a single pulse (1D), or 3 pulse signals stored
          as the columns of a 2D array (Li's preprocessing).

        **Returns:**

        freq: numpy.array
          The frequency spectrum, with ``nfft // 2 + 1`` values per channel.
          ``None`` is returned when a channel that would be used contains
          NaN values.

        **Raises:**

        ValueError
          If the signal is neither 1D nor 2D with 3 columns.
        """
        if signal.ndim == 1:
            # we have a single pulse signal
            if numpy.isnan(numpy.sum(signal)):
                return None
            fft = abs(numpy.fft.rfft(signal, n=self.nfft))

        elif signal.ndim == 2 and signal.shape[1] == 3:
            # we have 3 pulse signals (Li's preprocessing): use every channel
            # when concatenating, the green one (index 1) otherwise
            channels = (0, 1, 2) if self.concat else (1,)

            # sanity check on all the channels that contribute to the
            # feature, so that no NaN can leak into a concatenated vector
            for i in channels:
                if numpy.isnan(numpy.sum(signal[:, i])):
                    return None

            spectra = [abs(numpy.fft.rfft(signal[:, i], n=self.nfft)) for i in channels]
            # float64 matches the buffer the spectra used to be stored in
            fft = numpy.concatenate(spectra).astype(numpy.float64)

        else:
            # previously this fell through and failed on an unbound variable
            raise ValueError(
                "Expected a 1D signal or a 2D signal with 3 columns, "
                "got an array of shape {}".format(signal.shape))

        if self.debug:
            from matplotlib import pyplot
            # rfftfreq yields the nfft // 2 + 1 non-negative frequencies that
            # match the output of rfft (fftfreq would yield nfft values)
            f = numpy.fft.rfftfreq(self.nfft, d=1.0 / self.framerate)
            # one row per channel, so concatenated features can be plotted too
            for spectrum in fft.reshape(-1, f.shape[0]):
                pyplot.plot(f, spectrum, 'k')
            pyplot.title('Power spectrum of the signal')
            pyplot.show()

        return fft
#!/usr/bin/env python
# encoding: utf-8
import numpy
from bob.bio.base.extractor import Extractor
import logging
logger = logging.getLogger("bob.pad.face")
from scipy.signal import welch
class FrequencySpectrum(Extractor, object):
    """
    Compute the Frequency Spectrum of the given signal.

    The computation is made using Welch's algorithm.

    **Parameters:**

    framerate: int
      The sampling frequency of the signal (i.e the framerate ...)

    nsegments: int
      Value handed to ``scipy.signal.welch`` as ``nperseg``, i.e. it is used
      as the length (in samples) of each segment in Welch procedure.
      NOTE(review): the name suggests a *number* of overlapping segments;
      confirm which one is intended before relying on it.

    nfft: int
      Number of points to compute the FFT

    debug: boolean
      Plot stuff
    """

    def __init__(self, framerate=25, nsegments=12, nfft=256, debug=False, **kwargs):
        # forward the extra arguments to the base class instead of silently
        # dropping them, as the other extractors of this package do
        super(FrequencySpectrum, self).__init__(**kwargs)
        self.framerate = framerate
        self.nsegments = nsegments
        self.nfft = nfft
        self.debug = debug

    def __call__(self, signal):
        """
        Compute the frequency spectrum for the given signal.

        **Parameters:**

        signal: numpy.array
          The signal: either a single pulse (1D), or 3 pulse signals stored
          as the columns of a 2D array (Li's preprocessing). In the latter
          case only the green channel (column 1) is used.

        **Returns:**

        freq: numpy.array
          The power spectral density estimated by Welch's algorithm.
          ``None`` is returned when the pulse contains NaN values.

        **Raises:**

        ValueError
          If the signal is neither 1D nor 2D with 3 columns.
        """
        if signal.ndim == 1:
            # we have a single pulse signal
            pulse = signal
        elif signal.ndim == 2 and signal.shape[1] == 3:
            # we have 3 pulse signals (Li's preprocessing): only the green
            # channel is returned, so it is the only one worth processing
            pulse = signal[:, 1]
        else:
            # previously this fell through and failed on an unbound variable
            raise ValueError(
                "Expected a 1D signal or a 2D signal with 3 columns, "
                "got an array of shape {}".format(signal.shape))

        # sanity check
        if numpy.isnan(numpy.sum(pulse)):
            return None

        f, psd = welch(pulse, self.framerate, nperseg=self.nsegments, nfft=self.nfft)

        if signal.ndim == 2:
            # float64 matches the buffer the spectra used to be stored in
            psd = psd.astype(numpy.float64)

        if self.debug:
            from matplotlib import pyplot
            pyplot.semilogy(f, psd, 'k')
            pyplot.title('Power spectrum of the signal')
            pyplot.show()

        return psd
...@@ -5,15 +5,13 @@ import numpy ...@@ -5,15 +5,13 @@ import numpy
from bob.bio.base.extractor import Extractor from bob.bio.base.extractor import Extractor
import logging from bob.core.log import setup
logger = logging.getLogger("bob.pad.face") logger = setup("bob.pad.face")
from scipy.fftpack import rfft from scipy.fftpack import rfft
class LTSS(Extractor, object): class LTSS(Extractor, object):
""" """Compute Long-term spectral statistics of a pulse signal.
Compute Long-term spectral statistics of a pulse signal.
The features are described in the following article: The features are described in the following article:
...@@ -29,38 +27,78 @@ class LTSS(Extractor, object): ...@@ -29,38 +27,78 @@ class LTSS(Extractor, object):
year = 2017 year = 2017
} }
**Parameters:** Attributes
----------
framerate: int framerate: int
The sampling frequency of the signal (i.e the framerate ...) The sampling frequency of the signal (i.e the framerate ...)
nfft: int nfft: int
Number of points to compute the FFT Number of points to compute the FFT
debug: bool
debug: boolean
Plot stuff Plot stuff
concat: bool
Flag if you would like to concatenate features from the 3 color channels
time: int
The length of the signal to consider (in seconds)
""" """
def __init__(self, window_size=25, framerate=25, nfft=64, concat=False, debug=False, **kwargs): def __init__(self, window_size=25, framerate=25, nfft=64, concat=False, debug=False, time=0, **kwargs):
"""Init function
Parameters
----------
window_size: int
The size of the window where FFT is computed
framerate: int
The sampling frequency of the signal (i.e the framerate ...)
nfft: int
Number of points to compute the FFT
concat: bool
Flag if you would like to concatenate features from the 3 color channels
debug: bool
Plot stuff
time: int
The length of the signal to consider (in seconds)
"""
super(LTSS, self).__init__() super(LTSS, self).__init__()
self.framerate = framerate self.framerate = framerate
self.nfft = nfft self.nfft = nfft
self.debug = debug self.debug = debug
self.window_size = window_size self.window_size = window_size
self.concat = concat self.concat = concat
self.time = time
def _get_ltss(self, signal): def _get_ltss(self, signal):
"""Compute long term spectral statistics for a signal
Parameters
----------
signal: numpy.ndarray
The signal
Returns
-------
ltss: numpy.ndarray
The spectral statistics of the signal.
"""
window_stride = int(self.window_size / 2)
# log-magnitude of DFT coefficients # log-magnitude of DFT coefficients
log_mags = [] log_mags = []
window_stride = int(self.window_size / 2)
# go through windows # go through windows
for w in range(0, (signal.shape[0] - self.window_size), window_stride): for w in range(0, (signal.shape[0] - self.window_size), window_stride):
fft = rfft(signal[w:w+self.window_size], n=self.nfft) fft = rfft(signal[w:w+self.window_size], n=self.nfft)
mags = numpy.zeros(int(self.nfft/2), dtype=numpy.float64) mags = numpy.zeros(int(self.nfft/2), dtype=numpy.float64)
mags[0] = abs(fft[0])
# XXX : bug was here (no clipping)
if abs(fft[0]) < 1:
mags[0] = 1
else:
mags[0] = abs(fft[0])
# XXX
index = 1 index = 1
for i in range(1, (fft.shape[0]-1), 2): for i in range(1, (fft.shape[0]-1), 2):
mags[index] = numpy.sqrt(fft[i]**2 + fft[i+1]**2) mags[index] = numpy.sqrt(fft[i]**2 + fft[i+1]**2)
...@@ -69,7 +107,6 @@ class LTSS(Extractor, object): ...@@ -69,7 +107,6 @@ class LTSS(Extractor, object):
index += 1 index += 1
log_mags.append(numpy.log(mags)) log_mags.append(numpy.log(mags))
# get the long term statistics
log_mags = numpy.array(log_mags) log_mags = numpy.array(log_mags)
mean = numpy.mean(log_mags, axis=0) mean = numpy.mean(log_mags, axis=0)
std = numpy.std(log_mags, axis=0) std = numpy.std(log_mags, axis=0)
...@@ -78,18 +115,18 @@ class LTSS(Extractor, object): ...@@ -78,18 +115,18 @@ class LTSS(Extractor, object):
def __call__(self, signal): def __call__(self, signal):
""" """Computes the long-term spectral statistics for given pulse signals.
Computes the long-term spectral statistics for a given signal.
**Parameters**
signal: numpy.array Parameters
----------
signal: numpy.ndarray
The signal The signal
**Returns:** Returns
-------
feature: numpy.ndarray
the computed LTSS features
feature: numpy.array
the long-term spectral statistics feature vector
""" """
# sanity check # sanity check
if signal.ndim == 1: if signal.ndim == 1:
...@@ -100,9 +137,34 @@ class LTSS(Extractor, object): ...@@ -100,9 +137,34 @@ class LTSS(Extractor, object):
if numpy.isnan(numpy.sum(signal[:, i])): if numpy.isnan(numpy.sum(signal[:, i])):
return return
# truncate the signal according to time
if self.time > 0:
number_of_frames = self.time * self.framerate
# check that the truncated signal is not longer
# than the original one
if number_of_frames < signal.shape[0]:
if signal.ndim == 1:
signal = signal[:number_of_frames]
if signal.ndim == 2 and (signal.shape[1] == 3):
new_signal = numpy.zeros((number_of_frames, 3))
for i in range(signal.shape[1]):
new_signal[:, i] = signal[:number_of_frames, i]
signal = new_signal
else:
logger.warning("Sequence should be truncated to {}, but only contains {} => keeping original one".format(number_of_frames, signal.shape[0]))
# also, be sure that the window_size is not bigger that the signal
if self.window_size > int(signal.shape[0] / 2):
self.window_size = int(signal.shape[0] / 2)
logger.warning("Window size reduced to {}".format(self.window_size))
# we have a single pulse
if signal.ndim == 1: if signal.ndim == 1:
feature = self._get_ltss(signal) feature = self._get_ltss(signal)
# pulse for the 3 color channels
if signal.ndim == 2 and (signal.shape[1] == 3): if signal.ndim == 2 and (signal.shape[1] == 3):
if not self.concat: if not self.concat:
......
...@@ -68,6 +68,7 @@ class LiFeatures(Extractor, object): ...@@ -68,6 +68,7 @@ class LiFeatures(Extractor, object):
------- -------
feature: numpy.ndarray feature: numpy.ndarray
the computed features the computed features
""" """
# sanity check # sanity check
assert signal.ndim == 2 and signal.shape[1] == 3, "You should provide 3 pulse signals" assert signal.ndim == 2 and signal.shape[1] == 3, "You should provide 3 pulse signals"
......
...@@ -5,13 +5,12 @@ import numpy ...@@ -5,13 +5,12 @@ import numpy
from bob.bio.base.extractor import Extractor from bob.bio.base.extractor import Extractor
import logging from bob.core.log import setup
logger = logging.getLogger("bob.pad.face") logger = setup("bob.pad.face")
class PPGSecure(Extractor, object): class PPGSecure(Extractor, object):
""" """Extract frequency spectra from pulse signals.
This class extract the frequency features from pulse signals.
The feature are extracted according to what is described in The feature are extracted according to what is described in
the following article: the following article:
...@@ -30,39 +29,48 @@ class PPGSecure(Extractor, object): ...@@ -30,39 +29,48 @@ class PPGSecure(Extractor, object):
year = 2017 year = 2017
} }
**Parameters:** Attributes
----------
framerate: int framerate: int
The sampling frequency of the signal (i.e the framerate ...) The sampling frequency of the signal (i.e the framerate ...)
nfft: int nfft: int
Number of points to compute the FFT Number of points to compute the FFT
debug: bool
debug: boolean
Plot stuff Plot stuff
""" """
def __init__(self, framerate=25, nfft=32, debug=False, **kwargs): def __init__(self, framerate=25, nfft=32, debug=False, **kwargs):
"""Init function
super(PPGSecure, self).__init__(**kwargs)
Parameters
----------
framerate: int
The sampling frequency of the signal (i.e the framerate ...)
nfft: int
Number of points to compute the FFT
debug: bool
Plot stuff
"""
super(PPGSecure, self).__init__(**kwargs)
self.framerate = framerate self.framerate = framerate
self.nfft = nfft self.nfft = nfft
self.debug = debug self.debug = debug
def __call__(self, signal): def __call__(self, signal):
""" """Compute and concatenate frequency spectra for the given signals.
Compute the frequency spectrum for the given signal.
**Parameters:**
signal: numpy.array Parameters
----------
signal: numpy.ndarray
The signal The signal
**Returns:** Returns
-------
freq: numpy.array fft: numpy.ndarray
the frequency spectrum the computed FFT features
""" """
# sanity check # sanity check
assert signal.shape[1] == 5, "You should provide 5 pulses" assert signal.shape[1] == 5, "You should provide 5 pulses"
...@@ -74,7 +82,7 @@ class PPGSecure(Extractor, object): ...@@ -74,7 +82,7 @@ class PPGSecure(Extractor, object):
# get the frequencies # get the frequencies
f = numpy.fft.fftfreq(self.nfft) * self.framerate f = numpy.fft.fftfreq(self.nfft) * self.framerate
# we have 5 pulse signal (Li's preprocessing) # we have 5x3 pulse signals, in different regions across 3 channels
ffts = numpy.zeros((5, output_dim)) ffts = numpy.zeros((5, output_dim))
for i in range(5): for i in range(5):
ffts[i] = abs(numpy.fft.rfft(signal[:, i], n=self.nfft)) ffts[i] = abs(numpy.fft.rfft(signal[:, i], n=self.nfft))
...@@ -94,5 +102,4 @@ class PPGSecure(Extractor, object): ...@@ -94,5 +102,4 @@ class PPGSecure(Extractor, object):
logger.warn("Feature not extracted") logger.warn("Feature not extracted")
return return
return fft return fft
...@@ -5,10 +5,7 @@ from .VideoDataLoader import VideoDataLoader ...@@ -5,10 +5,7 @@ from .VideoDataLoader import VideoDataLoader
from .VideoQualityMeasure import VideoQualityMeasure from .VideoQualityMeasure import VideoQualityMeasure
from .FrameDiffFeatures import FrameDiffFeatures from .FrameDiffFeatures import FrameDiffFeatures
from .FrequencySpectrum import FrequencySpectrum from .LiFeatures import LiFeatures
from .FreqFeatures import FreqFeatures
from .NormalizeLength import NormalizeLength
from .FFTFeatures import FFTFeatures
from .LTSS import LTSS from .LTSS import LTSS
from .PPGSecure import PPGSecure from .PPGSecure import PPGSecure
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment