Skip to content
Snippets Groups Projects
Commit bdd6bf10 authored by Guillaume HEUSCH's avatar Guillaume HEUSCH
Browse files

[extractor] fixed the extractor for Li feature, using FFT

parent d1df496f
No related branches found
No related tags found
1 merge request!53WIP: rPPG as features for PAD
...@@ -33,21 +33,17 @@ class FreqFeatures(Extractor, object): ...@@ -33,21 +33,17 @@ class FreqFeatures(Extractor, object):
framerate: int framerate: int
The sampling frequency of the signal (i.e the framerate ...) The sampling frequency of the signal (i.e the framerate ...)
nsegments: int
Number of overlapping segments in Welch procedure
nfft: int nfft: int
Number of points to compute the FFT Number of points to compute the FFT
debug: boolean debug: boolean
Plot stuff Plot stuff
""" """
def __init__(self, framerate=25, nsegments=12, nfft=128, debug=False, **kwargs): def __init__(self, framerate=25, nfft=512, debug=False, **kwargs):
super(FreqFeatures, self).__init__() super(FreqFeatures, self).__init__()
self.framerate = framerate self.framerate = framerate
self.nsegments = nsegments
self.nfft = nfft self.nfft = nfft
self.debug = debug self.debug = debug
...@@ -65,42 +61,67 @@ class FreqFeatures(Extractor, object): ...@@ -65,42 +61,67 @@ class FreqFeatures(Extractor, object):
freq: numpy.array freq: numpy.array
the frequency spectrum the frequency spectrum
""" """
# sanity check
assert signal.ndim == 2 and signal.shape[1] == 3, "You should provide 3 pulse signals" assert signal.ndim == 2 and signal.shape[1] == 3, "You should provide 3 pulse signals"
for i in range(3):
if numpy.isnan(numpy.sum(signal[:, i])):
return
feature = numpy.zeros(6) feature = numpy.zeros(6)
# when keypoints have not been detected, the pulse is zero everywhere
# hence, no psd and no features
zero_counter = 0
for i in range(3):
if numpy.sum(signal[:, i]) == 0:
zero_counter += 1
if zero_counter == 3:
logger.warn("Feature is all zeros")
return feature
# get the frequency spectrum
spectrum_dim = int((self.nfft / 2) + 1) spectrum_dim = int((self.nfft / 2) + 1)
psds = numpy.zeros((3, spectrum_dim)) ffts = numpy.zeros((3, spectrum_dim))
f = numpy.fft.fftfreq(self.nfft) * self.framerate
f = abs(f[:spectrum_dim])
for i in range(3): for i in range(3):
f, psds[i] = welch(signal[:, i], self.framerate, nperseg=self.nsegments, nfft=self.nfft) ffts[i] = abs(numpy.fft.rfft(signal[:, i], n=self.nfft))
# find the max of the frequency spectrum in the range of interest # find the max of the frequency spectrum in the range of interest
first = numpy.where(f > 0.7)[0] first = numpy.where(f > 0.7)[0]
last = numpy.where(f < 4)[0] last = numpy.where(f < 4)[0]
first_index = first[0] first_index = first[0]
last_index = last[-1] last_index = last[-1]
range_of_interest = range(first_index, last_index + 1, 1) range_of_interest = range(first_index, last_index + 1, 1)
# build the feature vector
for i in range(3): for i in range(3):
total_power = numpy.sum(psds[i, range_of_interest]) total_power = numpy.sum(ffts[i, range_of_interest])
max_power = numpy.max(psds[i, range_of_interest]) max_power = numpy.max(ffts[i, range_of_interest])
feature[i] = max_power feature[i] = max_power
feature[i+3] = max_power / total_power if total_power == 0:
print (max_power) feature[i+3] = 0
print (max_power / total_power) else:
feature[i+3] = max_power / total_power
if self.debug:
# plot stuff, if asked for
if self.debug:
from matplotlib import pyplot
for i in range(3):
max_idx = numpy.argmax(ffts[i, range_of_interest])
f_max = f[range_of_interest[max_idx]] f_max = f[range_of_interest[max_idx]]
max_idx = numpy.argmax(psds[i, range_of_interest]) logger.debug("Inferred HR = {}".format(f_max*60))
from matplotlib import pyplot pyplot.plot(f, ffts[i], 'k')
pyplot.semilogy(f, psds[i], 'k')
xmax, xmin, ymax, ymin = pyplot.axis() xmax, xmin, ymax, ymin = pyplot.axis()
pyplot.vlines(f[range_of_interest[max_idx]], ymin, ymax, color='red') pyplot.vlines(f[range_of_interest[max_idx]], ymin, ymax, color='red')
pyplot.title('Power spectrum of the signal') pyplot.vlines(f[first_index], ymin, ymax, color='green')
pyplot.vlines(f[last_index], ymin, ymax, color='green')
pyplot.show() pyplot.show()
print(feature) if numpy.isnan(numpy.sum(feature)):
import sys logger.warn("Feature not extracted")
sys.exit() return
return feature return feature
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment