Skip to content
Snippets Groups Projects
Commit 0b09f933 authored by Guillaume HEUSCH's avatar Guillaume HEUSCH
Browse files

[extractor] fixed the CHROM and SSR extractor when dealing with either filenames or FrameContainer

parent a746fcfc
Branches
Tags
1 merge request!53WIP: rPPG as features for PAD
...@@ -15,8 +15,6 @@ import bob.ip.skincolorfilter ...@@ -15,8 +15,6 @@ import bob.ip.skincolorfilter
import logging import logging
logger = logging.getLogger("bob.pad.face") logger = logging.getLogger("bob.pad.face")
from bob.rppg.base.utils import crop_face from bob.rppg.base.utils import crop_face
from bob.rppg.base.utils import build_bandpass_filter from bob.rppg.base.utils import build_bandpass_filter
...@@ -26,7 +24,6 @@ from bob.rppg.chrom.extract_utils import compute_gray_diff ...@@ -26,7 +24,6 @@ from bob.rppg.chrom.extract_utils import compute_gray_diff
from bob.rppg.chrom.extract_utils import select_stable_frames from bob.rppg.chrom.extract_utils import select_stable_frames
class Chrom(Extractor, object): class Chrom(Extractor, object):
""" """
Extract pulse signal according to the CHROM algorithm Extract pulse signal according to the CHROM algorithm
...@@ -52,8 +49,11 @@ class Chrom(Extractor, object): ...@@ -52,8 +49,11 @@ class Chrom(Extractor, object):
The percentage of frames you want to select where the The percentage of frames you want to select where the
signal is "stable". 0 mean all the sequence. signal is "stable". 0 mean all the sequence.
debug: boolean
Plot some stuff
""" """
def __init__(self, skin_threshold=0.5, skin_init=False, framerate=25, bp_order=32, window_size=0, motion=0.0, **kwargs): def __init__(self, skin_threshold=0.5, skin_init=False, framerate=25, bp_order=32, window_size=0, motion=0.0, debug=False, **kwargs):
super(Chrom, self).__init__() super(Chrom, self).__init__()
...@@ -63,6 +63,7 @@ class Chrom(Extractor, object): ...@@ -63,6 +63,7 @@ class Chrom(Extractor, object):
self.bp_order = bp_order self.bp_order = bp_order
self.window_size = window_size self.window_size = window_size
self.motion = motion self.motion = motion
self.debug = debug
self.skin_filter = bob.ip.skincolorfilter.SkinColorFilter() self.skin_filter = bob.ip.skincolorfilter.SkinColorFilter()
...@@ -78,24 +79,23 @@ class Chrom(Extractor, object): ...@@ -78,24 +79,23 @@ class Chrom(Extractor, object):
see ``bob.bio.video.utils.FrameContainer`` for further details. see ``bob.bio.video.utils.FrameContainer`` for further details.
If string, the name of the file to load the video data from is If string, the name of the file to load the video data from is
defined in it. String is possible only when empty preprocessor is defined in it. String is possible only when empty preprocessor is
used. In this case video data is loaded directly from the database. used. In this case video data is loaded directly from the database
and not using any high or low-level db packages (so beware).
**Returns:** **Returns:**
pulse: FrameContainer pulse: numpy.array
Quality Measures for each frame stored in the FrameContainer. The pulse signal
""" """
# load video based on the filename
assert isinstance(frames, six.string_types)
if isinstance(frames, six.string_types): if isinstance(frames, six.string_types):
video_loader = VideoDataLoader() video_loader = VideoDataLoader()
video = video_loader(frames) video = video_loader(frames)
else:
video = video.as_array() video = frames
video = video.as_array()
nb_frames = video.shape[0] nb_frames = video.shape[0]
output_data = numpy.zeros(nb_frames, dtype='float64')
chrom = numpy.zeros((nb_frames, 2), dtype='float64') chrom = numpy.zeros((nb_frames, 2), dtype='float64')
# build the bandpass filter one and for all # build the bandpass filter one and for all
...@@ -104,7 +104,14 @@ class Chrom(Extractor, object): ...@@ -104,7 +104,14 @@ class Chrom(Extractor, object):
counter = 0 counter = 0
previous_bbox = None previous_bbox = None
for i, frame in enumerate(video): for i, frame in enumerate(video):
logger.debug("Processing frame {}/{}".format(counter, nb_frames))
if self.debug:
from matplotlib import pyplot
pyplot.imshow(numpy.rollaxis(numpy.rollaxis(frame, 2),2))
pyplot.show()
try: try:
bbox, quality = bob.ip.facedetect.detect_single_face(frame) bbox, quality = bob.ip.facedetect.detect_single_face(frame)
except: except:
...@@ -118,9 +125,10 @@ class Chrom(Extractor, object): ...@@ -118,9 +125,10 @@ class Chrom(Extractor, object):
face = crop_face(frame, bbox, bbox.size[1]) face = crop_face(frame, bbox, bbox.size[1])
#from matplotlib import pyplot if self.debug:
#pyplot.imshow(numpy.rollaxis(numpy.rollaxis(face, 2),2)) from matplotlib import pyplot
#pyplot.show() pyplot.imshow(numpy.rollaxis(numpy.rollaxis(face, 2),2))
pyplot.show()
# skin filter # skin filter
if counter == 0 or self.skin_init: if counter == 0 or self.skin_init:
...@@ -128,13 +136,13 @@ class Chrom(Extractor, object): ...@@ -128,13 +136,13 @@ class Chrom(Extractor, object):
logger.debug("Skin color parameters:\nmean\n{0}\ncovariance\n{1}".format(self.skin_filter.mean, self.skin_filter.covariance)) logger.debug("Skin color parameters:\nmean\n{0}\ncovariance\n{1}".format(self.skin_filter.mean, self.skin_filter.covariance))
skin_mask = self.skin_filter.get_skin_mask(face, self.skin_threshold) skin_mask = self.skin_filter.get_skin_mask(face, self.skin_threshold)
#from matplotlib import pyplot if self.debug:
#skin_mask_image = numpy.copy(face) from matplotlib import pyplot
#skin_mask_image[:, skin_mask] = 255 skin_mask_image = numpy.copy(face)
#pyplot.imshow(numpy.rollaxis(numpy.rollaxis(skin_mask_image, 2),2)) skin_mask_image[:, skin_mask] = 255
#pyplot.show() pyplot.imshow(numpy.rollaxis(numpy.rollaxis(skin_mask_image, 2),2))
pyplot.show()
logger.debug("Processing frame {}/{}".format(counter, nb_frames))
# sometimes skin is not detected ! # sometimes skin is not detected !
if numpy.count_nonzero(skin_mask) != 0: if numpy.count_nonzero(skin_mask) != 0:
...@@ -167,13 +175,14 @@ class Chrom(Extractor, object): ...@@ -167,13 +175,14 @@ class Chrom(Extractor, object):
logger.info("Stable segment -> {0} - {1}".format(index, index + n_stable_frames_to_keep)) logger.info("Stable segment -> {0} - {1}".format(index, index + n_stable_frames_to_keep))
chrom = chrom[index:(index + n_stable_frames_to_keep),:] chrom = chrom[index:(index + n_stable_frames_to_keep),:]
#from matplotlib import pyplot if self.debug:
#f, axarr = pyplot.subplots(2, sharex=True) from matplotlib import pyplot
#axarr[0].plot(range(chrom.shape[0]), chrom[:, 0], 'k') f, axarr = pyplot.subplots(2, sharex=True)
#axarr[0].set_title("X value in the chrominance subspace") axarr[0].plot(range(chrom.shape[0]), chrom[:, 0], 'k')
#axarr[1].plot(range(chrom.shape[0]), chrom[:, 1], 'k') axarr[0].set_title("X value in the chrominance subspace")
#axarr[1].set_title("Y value in the chrominance subspace") axarr[1].plot(range(chrom.shape[0]), chrom[:, 1], 'k')
#pyplot.show() axarr[1].set_title("Y value in the chrominance subspace")
pyplot.show()
# now that we have the chrominance signals, apply bandpass # now that we have the chrominance signals, apply bandpass
from scipy.signal import filtfilt from scipy.signal import filtfilt
...@@ -182,13 +191,14 @@ class Chrom(Extractor, object): ...@@ -182,13 +191,14 @@ class Chrom(Extractor, object):
x_bandpassed = filtfilt(bandpass_filter, numpy.array([1]), chrom[:, 0]) x_bandpassed = filtfilt(bandpass_filter, numpy.array([1]), chrom[:, 0])
y_bandpassed = filtfilt(bandpass_filter, numpy.array([1]), chrom[:, 1]) y_bandpassed = filtfilt(bandpass_filter, numpy.array([1]), chrom[:, 1])
#from matplotlib import pyplot if self.debug:
#f, axarr = pyplot.subplots(2, sharex=True) from matplotlib import pyplot
#axarr[0].plot(range(x_bandpassed.shape[0]), x_bandpassed, 'k') f, axarr = pyplot.subplots(2, sharex=True)
#axarr[0].set_title("X bandpassed") axarr[0].plot(range(x_bandpassed.shape[0]), x_bandpassed, 'k')
#axarr[1].plot(range(y_bandpassed.shape[0]), y_bandpassed, 'k') axarr[0].set_title("X bandpassed")
#axarr[1].set_title("Y bandpassed") axarr[1].plot(range(y_bandpassed.shape[0]), y_bandpassed, 'k')
#pyplot.show() axarr[1].set_title("Y bandpassed")
pyplot.show()
# build the final pulse signal # build the final pulse signal
alpha = numpy.std(x_bandpassed) / numpy.std(y_bandpassed) alpha = numpy.std(x_bandpassed) / numpy.std(y_bandpassed)
...@@ -206,11 +216,11 @@ class Chrom(Extractor, object): ...@@ -206,11 +216,11 @@ class Chrom(Extractor, object):
sw *= numpy.hanning(window_size) sw *= numpy.hanning(window_size)
pulse[w:w+window_size] += sw pulse[w:w+window_size] += sw
#from matplotlib import pyplot if self.debug:
#f, axarr = pyplot.subplots(1) from matplotlib import pyplot
#pyplot.plot(range(pulse.shape[0]), pulse, 'k') f, axarr = pyplot.subplots(1)
#pyplot.title("Pulse signal") pyplot.plot(range(pulse.shape[0]), pulse, 'k')
#pyplot.show() pyplot.title("Pulse signal")
pyplot.show()
#output_data = pulse
return pulse return pulse
...@@ -15,7 +15,6 @@ import bob.ip.skincolorfilter ...@@ -15,7 +15,6 @@ import bob.ip.skincolorfilter
import logging import logging
logger = logging.getLogger("bob.pad.face") logger = logging.getLogger("bob.pad.face")
from bob.rppg.base.utils import crop_face from bob.rppg.base.utils import crop_face
from bob.rppg.ssr.ssr_utils import get_eigen from bob.rppg.ssr.ssr_utils import get_eigen
...@@ -23,7 +22,6 @@ from bob.rppg.ssr.ssr_utils import plot_eigenvectors ...@@ -23,7 +22,6 @@ from bob.rppg.ssr.ssr_utils import plot_eigenvectors
from bob.rppg.ssr.ssr_utils import build_P from bob.rppg.ssr.ssr_utils import build_P
class SSR(Extractor, object): class SSR(Extractor, object):
""" """
Extract pulse signal according to the SSR algorithm Extract pulse signal according to the SSR algorithm
...@@ -39,8 +37,11 @@ class SSR(Extractor, object): ...@@ -39,8 +37,11 @@ class SSR(Extractor, object):
stride: int stride: int
The temporal stride. The temporal stride.
debug: boolean
Plot some stuff
""" """
def __init__(self, skin_threshold=0.5, skin_init=False, stride=25, **kwargs): def __init__(self, skin_threshold=0.5, skin_init=False, stride=25, debug=False, **kwargs):
super(SSR, self).__init__() super(SSR, self).__init__()
...@@ -62,20 +63,20 @@ class SSR(Extractor, object): ...@@ -62,20 +63,20 @@ class SSR(Extractor, object):
see ``bob.bio.video.utils.FrameContainer`` for further details. see ``bob.bio.video.utils.FrameContainer`` for further details.
If string, the name of the file to load the video data from is If string, the name of the file to load the video data from is
defined in it. String is possible only when empty preprocessor is defined in it. String is possible only when empty preprocessor is
used. In this case video data is loaded directly from the database. used. In this case video data is loaded directly from the database
and not using any high or low-level db packages (so beware).
**Returns:** **Returns:**
pulse: FrameContainer pulse: numpy.array
Quality Measures for each frame stored in the FrameContainer. The pulse signal
""" """
# load video based on the filename
assert isinstance(frames, six.string_types)
if isinstance(frames, six.string_types): if isinstance(frames, six.string_types):
video_loader = VideoDataLoader() video_loader = VideoDataLoader()
video = video_loader(frames) video = video_loader(frames)
else:
video = frames
video = video.as_array() video = video.as_array()
nb_frames = video.shape[0] nb_frames = video.shape[0]
...@@ -86,16 +87,19 @@ class SSR(Extractor, object): ...@@ -86,16 +87,19 @@ class SSR(Extractor, object):
eigenvalues = numpy.zeros((3, nb_frames), dtype='float64') eigenvalues = numpy.zeros((3, nb_frames), dtype='float64')
eigenvectors = numpy.zeros((3, 3, nb_frames), dtype='float64') eigenvectors = numpy.zeros((3, 3, nb_frames), dtype='float64')
### LET'S GO
#XXX
plot = True
counter = 0 counter = 0
previous_bbox = None previous_bbox = None
previous_skin_pixels = None
for i, frame in enumerate(video): for i, frame in enumerate(video):
logger.debug("Processing frame %d/%d...", i, nb_frames) logger.debug("Processing frame %d/%d...", i, nb_frames)
if self.debug:
from matplotlib import pyplot
pyplot.imshow(numpy.rollaxis(numpy.rollaxis(frame, 2),2))
pyplot.show()
try: try:
bbox, quality = bob.ip.facedetect.detect_single_face(frame) bbox, quality = bob.ip.facedetect.detect_single_face(frame)
except: except:
...@@ -104,43 +108,54 @@ class SSR(Extractor, object): ...@@ -104,43 +108,54 @@ class SSR(Extractor, object):
face = crop_face(frame, bbox, bbox.size[1]) face = crop_face(frame, bbox, bbox.size[1])
#from matplotlib import pyplot if self.debug:
#pyplot.imshow(numpy.rollaxis(numpy.rollaxis(face, 2),2)) from matplotlib import pyplot
#pyplot.show() pyplot.imshow(numpy.rollaxis(numpy.rollaxis(face, 2),2))
pyplot.show()
# skin filter # skin filter
if counter == 0 or self.skin_init: if counter == 0 or self.skin_init:
self.skin_filter.estimate_gaussian_parameters(face) self.skin_filter.estimate_gaussian_parameters(face)
logger.debug("Skin color parameters:\nmean\n{0}\ncovariance\n{1}".format(self.skin_filter.mean, self.skin_filter.covariance)) logger.debug("Skin color parameters:\nmean\n{0}\ncovariance\n{1}".format(self.skin_filter.mean, self.skin_filter.covariance))
skin_mask = self.skin_filter.get_skin_mask(face, self.skin_threshold) skin_mask = self.skin_filter.get_skin_mask(face, self.skin_threshold)
skin_pixels = face[:, skin_mask] skin_pixels = face[:, skin_mask]
#from matplotlib import pyplot
#skin_mask_image = numpy.copy(face)
#skin_mask_image[:, skin_mask] = 255
#pyplot.title("skin pixels in frame {0}".format(i))
#pyplot.imshow(numpy.rollaxis(numpy.rollaxis(skin_mask_image, 2),2))
#pyplot.show()
skin_pixels = skin_pixels.astype('float64') / 255.0 skin_pixels = skin_pixels.astype('float64') / 255.0
if self.debug:
from matplotlib import pyplot
skin_mask_image = numpy.copy(face)
skin_mask_image[:, skin_mask] = 255
pyplot.title("skin pixels in frame {0}".format(i))
pyplot.imshow(numpy.rollaxis(numpy.rollaxis(skin_mask_image, 2),2))
pyplot.show()
# nos skin pixels have ben detected ... using the previous ones
if skin_pixels.shape[1] == 0:
skin_pixels = previous_skin_pixels
logger.warn("No skin pixels detected, using the previous ones")
# build c matrix and get eigenvectors and eigenvalues # build c matrix and get eigenvectors and eigenvalues
eigenvalues[:, counter], eigenvectors[:, :, counter] = get_eigen(skin_pixels) eigenvalues[:, counter], eigenvectors[:, :, counter] = get_eigen(skin_pixels)
#plot_eigenvectors(skin_pixels, eigenvectors[:, :, counter])
if self.debug:
plot_eigenvectors(skin_pixels, eigenvectors[:, :, counter])
# build P and add it to the pulse signal # build P and add it to the pulse signal
if counter >= self.stride: if counter >= self.stride:
tau = counter - self.stride tau = counter - self.stride
p = build_P(counter, self.stride, eigenvectors, eigenvalues) p = build_P(counter, self.stride, eigenvectors, eigenvalues)
output_data[tau:counter] += (p - numpy.mean(p)) output_data[tau:counter] += (p - numpy.mean(p))
previous_bbox = bbox
previous_skin_pixels = skin_pixels
counter += 1 counter += 1
# plot the pulse signal if self.debug:
#import matplotlib.pyplot as plt import matplotlib.pyplot as plt
#fig = plt.figure() fig = plt.figure()
#ax = fig.add_subplot(111) ax = fig.add_subplot(111)
#ax.plot(range(nb_frames), output_data) ax.plot(range(nb_frames), output_data)
#plt.show() plt.show()
return output_data return output_data
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment