Skip to content
Snippets Groups Projects
Commit 0b09f933 authored by Guillaume HEUSCH's avatar Guillaume HEUSCH
Browse files

[extractor] fixed the CHROM and SSR extractor when dealing with either filenames or FrameContainer

parent a746fcfc
No related branches found
No related tags found
1 merge request!53WIP: rPPG as features for PAD
......@@ -15,8 +15,6 @@ import bob.ip.skincolorfilter
import logging
logger = logging.getLogger("bob.pad.face")
from bob.rppg.base.utils import crop_face
from bob.rppg.base.utils import build_bandpass_filter
......@@ -26,7 +24,6 @@ from bob.rppg.chrom.extract_utils import compute_gray_diff
from bob.rppg.chrom.extract_utils import select_stable_frames
class Chrom(Extractor, object):
"""
Extract pulse signal according to the CHROM algorithm
......@@ -52,8 +49,11 @@ class Chrom(Extractor, object):
The percentage of frames you want to select where the
signal is "stable". 0 mean all the sequence.
debug: boolean
Plot some stuff
"""
def __init__(self, skin_threshold=0.5, skin_init=False, framerate=25, bp_order=32, window_size=0, motion=0.0, **kwargs):
def __init__(self, skin_threshold=0.5, skin_init=False, framerate=25, bp_order=32, window_size=0, motion=0.0, debug=False, **kwargs):
super(Chrom, self).__init__()
......@@ -63,6 +63,7 @@ class Chrom(Extractor, object):
self.bp_order = bp_order
self.window_size = window_size
self.motion = motion
self.debug = debug
self.skin_filter = bob.ip.skincolorfilter.SkinColorFilter()
......@@ -78,24 +79,23 @@ class Chrom(Extractor, object):
see ``bob.bio.video.utils.FrameContainer`` for further details.
If string, the name of the file to load the video data from is
defined in it. String is possible only when empty preprocessor is
used. In this case video data is loaded directly from the database.
used. In this case video data is loaded directly from the database
and not using any high or low-level db packages (so beware).
**Returns:**
pulse: FrameContainer
Quality Measures for each frame stored in the FrameContainer.
pulse: numpy.array
The pulse signal
"""
# load video based on the filename
assert isinstance(frames, six.string_types)
if isinstance(frames, six.string_types):
video_loader = VideoDataLoader()
video = video_loader(frames)
video = video.as_array()
else:
video = frames
video = video.as_array()
nb_frames = video.shape[0]
output_data = numpy.zeros(nb_frames, dtype='float64')
chrom = numpy.zeros((nb_frames, 2), dtype='float64')
# build the bandpass filter one and for all
......@@ -104,7 +104,14 @@ class Chrom(Extractor, object):
counter = 0
previous_bbox = None
for i, frame in enumerate(video):
logger.debug("Processing frame {}/{}".format(counter, nb_frames))
if self.debug:
from matplotlib import pyplot
pyplot.imshow(numpy.rollaxis(numpy.rollaxis(frame, 2),2))
pyplot.show()
try:
bbox, quality = bob.ip.facedetect.detect_single_face(frame)
except:
......@@ -118,9 +125,10 @@ class Chrom(Extractor, object):
face = crop_face(frame, bbox, bbox.size[1])
#from matplotlib import pyplot
#pyplot.imshow(numpy.rollaxis(numpy.rollaxis(face, 2),2))
#pyplot.show()
if self.debug:
from matplotlib import pyplot
pyplot.imshow(numpy.rollaxis(numpy.rollaxis(face, 2),2))
pyplot.show()
# skin filter
if counter == 0 or self.skin_init:
......@@ -128,13 +136,13 @@ class Chrom(Extractor, object):
logger.debug("Skin color parameters:\nmean\n{0}\ncovariance\n{1}".format(self.skin_filter.mean, self.skin_filter.covariance))
skin_mask = self.skin_filter.get_skin_mask(face, self.skin_threshold)
#from matplotlib import pyplot
#skin_mask_image = numpy.copy(face)
#skin_mask_image[:, skin_mask] = 255
#pyplot.imshow(numpy.rollaxis(numpy.rollaxis(skin_mask_image, 2),2))
#pyplot.show()
if self.debug:
from matplotlib import pyplot
skin_mask_image = numpy.copy(face)
skin_mask_image[:, skin_mask] = 255
pyplot.imshow(numpy.rollaxis(numpy.rollaxis(skin_mask_image, 2),2))
pyplot.show()
logger.debug("Processing frame {}/{}".format(counter, nb_frames))
# sometimes skin is not detected !
if numpy.count_nonzero(skin_mask) != 0:
......@@ -167,13 +175,14 @@ class Chrom(Extractor, object):
logger.info("Stable segment -> {0} - {1}".format(index, index + n_stable_frames_to_keep))
chrom = chrom[index:(index + n_stable_frames_to_keep),:]
#from matplotlib import pyplot
#f, axarr = pyplot.subplots(2, sharex=True)
#axarr[0].plot(range(chrom.shape[0]), chrom[:, 0], 'k')
#axarr[0].set_title("X value in the chrominance subspace")
#axarr[1].plot(range(chrom.shape[0]), chrom[:, 1], 'k')
#axarr[1].set_title("Y value in the chrominance subspace")
#pyplot.show()
if self.debug:
from matplotlib import pyplot
f, axarr = pyplot.subplots(2, sharex=True)
axarr[0].plot(range(chrom.shape[0]), chrom[:, 0], 'k')
axarr[0].set_title("X value in the chrominance subspace")
axarr[1].plot(range(chrom.shape[0]), chrom[:, 1], 'k')
axarr[1].set_title("Y value in the chrominance subspace")
pyplot.show()
# now that we have the chrominance signals, apply bandpass
from scipy.signal import filtfilt
......@@ -182,13 +191,14 @@ class Chrom(Extractor, object):
x_bandpassed = filtfilt(bandpass_filter, numpy.array([1]), chrom[:, 0])
y_bandpassed = filtfilt(bandpass_filter, numpy.array([1]), chrom[:, 1])
#from matplotlib import pyplot
#f, axarr = pyplot.subplots(2, sharex=True)
#axarr[0].plot(range(x_bandpassed.shape[0]), x_bandpassed, 'k')
#axarr[0].set_title("X bandpassed")
#axarr[1].plot(range(y_bandpassed.shape[0]), y_bandpassed, 'k')
#axarr[1].set_title("Y bandpassed")
#pyplot.show()
if self.debug:
from matplotlib import pyplot
f, axarr = pyplot.subplots(2, sharex=True)
axarr[0].plot(range(x_bandpassed.shape[0]), x_bandpassed, 'k')
axarr[0].set_title("X bandpassed")
axarr[1].plot(range(y_bandpassed.shape[0]), y_bandpassed, 'k')
axarr[1].set_title("Y bandpassed")
pyplot.show()
# build the final pulse signal
alpha = numpy.std(x_bandpassed) / numpy.std(y_bandpassed)
......@@ -206,11 +216,11 @@ class Chrom(Extractor, object):
sw *= numpy.hanning(window_size)
pulse[w:w+window_size] += sw
#from matplotlib import pyplot
#f, axarr = pyplot.subplots(1)
#pyplot.plot(range(pulse.shape[0]), pulse, 'k')
#pyplot.title("Pulse signal")
#pyplot.show()
if self.debug:
from matplotlib import pyplot
f, axarr = pyplot.subplots(1)
pyplot.plot(range(pulse.shape[0]), pulse, 'k')
pyplot.title("Pulse signal")
pyplot.show()
#output_data = pulse
return pulse
......@@ -15,7 +15,6 @@ import bob.ip.skincolorfilter
import logging
logger = logging.getLogger("bob.pad.face")
from bob.rppg.base.utils import crop_face
from bob.rppg.ssr.ssr_utils import get_eigen
......@@ -23,7 +22,6 @@ from bob.rppg.ssr.ssr_utils import plot_eigenvectors
from bob.rppg.ssr.ssr_utils import build_P
class SSR(Extractor, object):
"""
Extract pulse signal according to the SSR algorithm
......@@ -39,8 +37,11 @@ class SSR(Extractor, object):
stride: int
The temporal stride.
debug: boolean
Plot some stuff
"""
def __init__(self, skin_threshold=0.5, skin_init=False, stride=25, **kwargs):
def __init__(self, skin_threshold=0.5, skin_init=False, stride=25, debug=False, **kwargs):
super(SSR, self).__init__()
......@@ -62,20 +63,20 @@ class SSR(Extractor, object):
see ``bob.bio.video.utils.FrameContainer`` for further details.
If string, the name of the file to load the video data from is
defined in it. String is possible only when empty preprocessor is
used. In this case video data is loaded directly from the database.
used. In this case video data is loaded directly from the database
and not using any high or low-level db packages (so beware).
**Returns:**
pulse: FrameContainer
Quality Measures for each frame stored in the FrameContainer.
pulse: numpy.array
The pulse signal
"""
# load video based on the filename
assert isinstance(frames, six.string_types)
if isinstance(frames, six.string_types):
video_loader = VideoDataLoader()
video = video_loader(frames)
else:
video = frames
video = video.as_array()
nb_frames = video.shape[0]
......@@ -86,16 +87,19 @@ class SSR(Extractor, object):
eigenvalues = numpy.zeros((3, nb_frames), dtype='float64')
eigenvectors = numpy.zeros((3, 3, nb_frames), dtype='float64')
### LET'S GO
#XXX
plot = True
counter = 0
previous_bbox = None
previous_skin_pixels = None
for i, frame in enumerate(video):
logger.debug("Processing frame %d/%d...", i, nb_frames)
if self.debug:
from matplotlib import pyplot
pyplot.imshow(numpy.rollaxis(numpy.rollaxis(frame, 2),2))
pyplot.show()
try:
bbox, quality = bob.ip.facedetect.detect_single_face(frame)
except:
......@@ -104,43 +108,54 @@ class SSR(Extractor, object):
face = crop_face(frame, bbox, bbox.size[1])
#from matplotlib import pyplot
#pyplot.imshow(numpy.rollaxis(numpy.rollaxis(face, 2),2))
#pyplot.show()
if self.debug:
from matplotlib import pyplot
pyplot.imshow(numpy.rollaxis(numpy.rollaxis(face, 2),2))
pyplot.show()
# skin filter
if counter == 0 or self.skin_init:
self.skin_filter.estimate_gaussian_parameters(face)
logger.debug("Skin color parameters:\nmean\n{0}\ncovariance\n{1}".format(self.skin_filter.mean, self.skin_filter.covariance))
skin_mask = self.skin_filter.get_skin_mask(face, self.skin_threshold)
skin_pixels = face[:, skin_mask]
#from matplotlib import pyplot
#skin_mask_image = numpy.copy(face)
#skin_mask_image[:, skin_mask] = 255
#pyplot.title("skin pixels in frame {0}".format(i))
#pyplot.imshow(numpy.rollaxis(numpy.rollaxis(skin_mask_image, 2),2))
#pyplot.show()
skin_pixels = skin_pixels.astype('float64') / 255.0
if self.debug:
from matplotlib import pyplot
skin_mask_image = numpy.copy(face)
skin_mask_image[:, skin_mask] = 255
pyplot.title("skin pixels in frame {0}".format(i))
pyplot.imshow(numpy.rollaxis(numpy.rollaxis(skin_mask_image, 2),2))
pyplot.show()
# nos skin pixels have ben detected ... using the previous ones
if skin_pixels.shape[1] == 0:
skin_pixels = previous_skin_pixels
logger.warn("No skin pixels detected, using the previous ones")
# build c matrix and get eigenvectors and eigenvalues
eigenvalues[:, counter], eigenvectors[:, :, counter] = get_eigen(skin_pixels)
#plot_eigenvectors(skin_pixels, eigenvectors[:, :, counter])
if self.debug:
plot_eigenvectors(skin_pixels, eigenvectors[:, :, counter])
# build P and add it to the pulse signal
if counter >= self.stride:
tau = counter - self.stride
p = build_P(counter, self.stride, eigenvectors, eigenvalues)
output_data[tau:counter] += (p - numpy.mean(p))
previous_bbox = bbox
previous_skin_pixels = skin_pixels
counter += 1
# plot the pulse signal
#import matplotlib.pyplot as plt
#fig = plt.figure()
#ax = fig.add_subplot(111)
#ax.plot(range(nb_frames), output_data)
#plt.show()
if self.debug:
import matplotlib.pyplot as plt
fig = plt.figure()
ax = fig.add_subplot(111)
ax.plot(range(nb_frames), output_data)
plt.show()
return output_data
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment