Skip to content
Snippets Groups Projects
Commit 76160d43 authored by Guillaume HEUSCH's avatar Guillaume HEUSCH
Browse files

[chrom] added configuration file to extract pulse from mask with CHROM

parent 4479a260
No related branches found
No related tags found
1 merge request!6Resolve "Package depends on both bob.db.cohface and bob.db.hci_tagging"
...@@ -119,10 +119,7 @@ def main(user_input=None): ...@@ -119,10 +119,7 @@ def main(user_input=None):
from bob.core.log import set_verbosity_level from bob.core.log import set_verbosity_level
set_verbosity_level(logger, verbosity_level) set_verbosity_level(logger, verbosity_level)
print(configuration.database)
if hasattr(configuration, 'database'): if hasattr(configuration, 'database'):
print(protocol)
print(subset)
objects = configuration.database.objects(protocol, subset) objects = configuration.database.objects(protocol, subset)
else: else:
logger.error("Please provide a database in your configuration file !") logger.error("Please provide a database in your configuration file !")
......
...@@ -4,8 +4,9 @@ ...@@ -4,8 +4,9 @@
"""Pulse extraction using CHROM algorithm (%(version)s) """Pulse extraction using CHROM algorithm (%(version)s)
Usage: Usage:
%(prog)s (cohface | hci) [--protocol=<string>] [--subset=<string> ...] %(prog)s <configuration>
[--dbdir=<path>] [--pulsedir=<path>] [--protocol=<string>] [--subset=<string> ...]
[--pulsedir=<path>]
[--npoints=<int>] [--indent=<int>] [--quality=<float>] [--distance=<int>] [--npoints=<int>] [--indent=<int>] [--quality=<float>] [--distance=<int>]
[--framerate=<int>] [--order=<int>] [--window=<int>] [--framerate=<int>] [--order=<int>] [--window=<int>]
[--overwrite] [--verbose ...] [--plot] [--gridcount] [--overwrite] [--verbose ...] [--plot] [--gridcount]
...@@ -17,13 +18,11 @@ Usage: ...@@ -17,13 +18,11 @@ Usage:
Options: Options:
-h, --help Show this screen -h, --help Show this screen
-V, --version Show version -V, --version Show version
-p, --protocol=<string> Protocol [default: all]. -p, --protocol=<string> Protocol.
-s, --subset=<string> Data subset to load. If nothing is provided -s, --subset=<string> Data subset to load. If nothing is provided
all the data sets will be loaded. all the data sets will be loaded.
-d, --dbdir=<path> The path to the database on your disk. If not set, -o, --pulsedir=<path> The path to the directory where signal extracted
defaults to Idiap standard locations. from the face area will be stored [default: pulse].
-f, --pulsedir=<path> The path to the directory where signal extracted
from the face area will be stored [default: face]
-n, --npoints=<int> Number of good features to track [default: 40] -n, --npoints=<int> Number of good features to track [default: 40]
-i, --indent=<int> Indent (in percent of the face width) to apply to -i, --indent=<int> Indent (in percent of the face width) to apply to
keypoints to get the mask [default: 10] keypoints to get the mask [default: 10]
...@@ -46,9 +45,9 @@ Options: ...@@ -46,9 +45,9 @@ Options:
Example: Example:
To run the pulse extractor for the cohface database To run the pulse extraction
$ %(prog)s cohface -v $ %(prog)s config.py -v
See '%(prog)s --help' for more information. See '%(prog)s --help' for more information.
...@@ -64,6 +63,9 @@ logger = setup("bob.rppg.base") ...@@ -64,6 +63,9 @@ logger = setup("bob.rppg.base")
from docopt import docopt from docopt import docopt
from bob.extension.config import load
from ...base.utils import get_parameter
version = pkg_resources.require('bob.rppg.base')[0].version version = pkg_resources.require('bob.rppg.base')[0].version
import numpy import numpy
...@@ -97,46 +99,35 @@ def main(user_input=None): ...@@ -97,46 +99,35 @@ def main(user_input=None):
completions = dict(prog=prog, version=version,) completions = dict(prog=prog, version=version,)
args = docopt(__doc__ % completions, argv=arguments, version='Signal extractor for videos (%s)' % version,) args = docopt(__doc__ % completions, argv=arguments, version='Signal extractor for videos (%s)' % version,)
# load configuration file
configuration = load([os.path.join(args['<configuration>'])])
# get various parameters, either from config file or command-line
protocol = get_parameter(args, configuration, 'protocol', 'None')
subset = get_parameter(args, configuration, 'subset', '')
pulsedir = get_parameter(args, configuration, 'pulsedir', 'pulse')
npoints = get_parameter(args, configuration, 'npoints', 40)
indent = get_parameter(args, configuration, 'indent', 10)
quality = get_parameter(args, configuration, 'quality', 0.01)
distance = get_parameter(args, configuration, 'distance', 10)
framerate = get_parameter(args, configuration, 'framerate', 61)
order = get_parameter(args, configuration, 'order', 128)
window = get_parameter(args, configuration, 'window', 0)
overwrite = get_parameter(args, configuration, 'overwrite', False)
plot = get_parameter(args, configuration, 'plot', False)
gridcount = get_parameter(args, configuration, 'gridcount', False)
verbosity_level = get_parameter(args, configuration, 'verbose', 0)
# if the user wants more verbosity, lowers the logging level # if the user wants more verbosity, lowers the logging level
from bob.core.log import set_verbosity_level from bob.core.log import set_verbosity_level
set_verbosity_level(logger, args['--verbose']) set_verbosity_level(logger, verbosity_level)
# chooses the database driver to use
if args['cohface']:
import bob.db.cohface
if os.path.isdir(bob.db.cohface.DATABASE_LOCATION):
logger.debug("Using Idiap default location for the DB")
dbdir = bob.db.cohface.DATABASE_LOCATION
elif args['--indir'] is not None:
logger.debug("Using provided location for the DB")
dbdir = args['--indir']
else:
logger.warn("Could not find the database directory, please provide one")
sys.exit()
db = bob.db.cohface.Database(dbdir)
if not((args['--protocol'] == 'all') or (args['--protocol'] == 'clean') or (args['--protocol'] == 'natural')):
logger.warning("Protocol should be either 'clean', 'natural' or 'all' (and not {0})".format(args['--protocol']))
sys.exit()
objects = db.objects(args['--protocol'], args['--subset'])
elif args['hci']:
import bob.db.hci_tagging
import bob.db.hci_tagging.driver
if os.path.isdir(bob.db.hci_tagging.driver.DATABASE_LOCATION):
logger.debug("Using Idiap default location for the DB")
dbdir = bob.db.hci_tagging.driver.DATABASE_LOCATION
elif args['--indir'] is not None:
logger.debug("Using provided location for the DB")
dbdir = args['--indir']
else:
logger.warn("Could not find the database directory, please provide one")
sys.exit()
db = bob.db.hci_tagging.Database()
if not((args['--protocol'] == 'all') or (args['--protocol'] == 'cvpr14')):
logger.warning("Protocol should be either 'all' or 'cvpr14' (and not {0})".format(args['--protocol']))
sys.exit()
objects = db.objects(args['--protocol'], args['--subset'])
if hasattr(configuration, 'database'):
objects = configuration.database.objects(protocol, subset)
else:
logger.error("Please provide a database in your configuration file !")
sys.exit()
# if we are on a grid environment, just find what I have to process. # if we are on a grid environment, just find what I have to process.
sge = False sge = False
try: try:
...@@ -150,27 +141,27 @@ def main(user_input=None): ...@@ -150,27 +141,27 @@ def main(user_input=None):
raise RuntimeError("Grid request for job {} on a setup with {} jobs".format(pos, len(objects))) raise RuntimeError("Grid request for job {} on a setup with {} jobs".format(pos, len(objects)))
objects = [objects[pos]] objects = [objects[pos]]
if args['--gridcount']: if gridcount:
print(len(objects)) print(len(objects))
sys.exit() sys.exit()
# build the bandpass filter one and for all # build the bandpass filter one and for all
bandpass_filter = build_bandpass_filter(float(args['--framerate']), int(args['--order']), bool(args['--plot'])) bandpass_filter = build_bandpass_filter(framerate, order, plot)
# does the actual work - for every video in the available dataset, # does the actual work - for every video in the available dataset,
# extract the signals and dumps the results to the corresponding directory # extract the signals and dumps the results to the corresponding directory
for obj in objects: for obj in objects:
# expected output file # expected output file
output = obj.make_path(args['--pulsedir'], '.hdf5') output = obj.make_path(pulsedir, '.hdf5')
# if output exists and not overwriting, skip this file # if output exists and not overwriting, skip this file
if (os.path.exists(output)) and not args['--overwrite']: if (os.path.exists(output)) and not overwrite:
logger.info("Skipping output file `%s': already exists, use --overwrite to force an overwrite", output) logger.info("Skipping output file `%s': already exists, use --overwrite to force an overwrite", output)
continue continue
# load video # load video
video = obj.load_video(dbdir) video = obj.load_video(configuration.dbdir)
logger.info("Processing input video from `%s'...", video.filename) logger.info("Processing input video from `%s'...", video.filename)
# number of frames # number of frames
...@@ -194,7 +185,7 @@ def main(user_input=None): ...@@ -194,7 +185,7 @@ def main(user_input=None):
# -> detect the face # -> detect the face
# -> get "good features" inside the face # -> get "good features" inside the face
kpts = obj.load_drmf_keypoints() kpts = obj.load_drmf_keypoints()
mask_points, mask = kp66_to_mask(frame, kpts, int(args['--indent']), bool(args['--plot'])) mask_points, mask = kp66_to_mask(frame, kpts, int(indent), plot)
try: try:
bbox = bounding_boxes[i] bbox = bounding_boxes[i]
...@@ -204,8 +195,7 @@ def main(user_input=None): ...@@ -204,8 +195,7 @@ def main(user_input=None):
# define the face width for the whole sequence # define the face width for the whole sequence
facewidth = bbox.size[1] facewidth = bbox.size[1]
face = crop_face(frame, bbox, facewidth) face = crop_face(frame, bbox, facewidth)
good_features = get_good_features_to_track(face,int(args['--npoints']), good_features = get_good_features_to_track(face,npoints, quality, distance, plot)
float(args['--quality']), int(args['--distance']), bool(args['--plot']))
else: else:
# subsequent frames: # subsequent frames:
# -> crop the face with the bounding_boxes of the previous frame (so # -> crop the face with the bounding_boxes of the previous frame (so
...@@ -215,7 +205,7 @@ def main(user_input=None): ...@@ -215,7 +205,7 @@ def main(user_input=None):
# current corners # current corners
# -> apply this transformation to the mask # -> apply this transformation to the mask
face = crop_face(frame, prev_bb, facewidth) face = crop_face(frame, prev_bb, facewidth)
good_features = track_features(prev_face, face, prev_features, bool(args['--plot'])) good_features = track_features(prev_face, face, prev_features, plot)
project = find_transformation(prev_features, good_features) project = find_transformation(prev_features, good_features)
if project is None: if project is None:
logger.warn("Sequence {0}, frame {1} : No projection was found" logger.warn("Sequence {0}, frame {1} : No projection was found"
...@@ -235,9 +225,7 @@ def main(user_input=None): ...@@ -235,9 +225,7 @@ def main(user_input=None):
prev_bb = bb prev_bb = bb
prev_face = crop_face(frame, prev_bb, facewidth) prev_face = crop_face(frame, prev_bb, facewidth)
prev_features = get_good_features_to_track(face, int(args['--npoints']), prev_features = get_good_features_to_track(face, npoints, quality, distance, plot)
float(args['--quality']), int(args['--distance']),
bool(args['--plot']))
if prev_features is None: if prev_features is None:
logger.warn("Sequence {0}, frame {1} No features to track" logger.warn("Sequence {0}, frame {1} No features to track"
" detected in the current frame, using the previous ones" " detected in the current frame, using the previous ones"
...@@ -246,7 +234,7 @@ def main(user_input=None): ...@@ -246,7 +234,7 @@ def main(user_input=None):
# get the mask # get the mask
face_mask = get_mask(frame, mask_points) face_mask = get_mask(frame, mask_points)
if bool(args['--plot']) and args['--verbose'] >= 2: if plot and verbosity_level >= 2:
from matplotlib import pyplot from matplotlib import pyplot
mask_image = numpy.copy(frame) mask_image = numpy.copy(frame)
mask_image[:, face_mask] = 255 mask_image[:, face_mask] = 255
...@@ -267,7 +255,7 @@ def main(user_input=None): ...@@ -267,7 +255,7 @@ def main(user_input=None):
x_bandpassed = filtfilt(bandpass_filter, numpy.array([1]), chrom[:, 0]) x_bandpassed = filtfilt(bandpass_filter, numpy.array([1]), chrom[:, 0])
y_bandpassed = filtfilt(bandpass_filter, numpy.array([1]), chrom[:, 1]) y_bandpassed = filtfilt(bandpass_filter, numpy.array([1]), chrom[:, 1])
if bool(args['--plot']): if plot:
from matplotlib import pyplot from matplotlib import pyplot
f, axarr = pyplot.subplots(2, sharex=True) f, axarr = pyplot.subplots(2, sharex=True)
axarr[0].plot(range(x_bandpassed.shape[0]), x_bandpassed, 'k') axarr[0].plot(range(x_bandpassed.shape[0]), x_bandpassed, 'k')
...@@ -281,8 +269,8 @@ def main(user_input=None): ...@@ -281,8 +269,8 @@ def main(user_input=None):
pulse = x_bandpassed - alpha * y_bandpassed pulse = x_bandpassed - alpha * y_bandpassed
# overlap-add if window_size != 0 # overlap-add if window_size != 0
if int(args['--window']) > 0: if int(window) > 0:
window_size = int(args['--window']) window_size = int(window)
window_stride = window_size / 2 window_stride = window_size / 2
for w in range(0, (len(pulse)-window_size), window_stride): for w in range(0, (len(pulse)-window_size), window_stride):
pulse[w:w+window_size] = 0.0 pulse[w:w+window_size] = 0.0
...@@ -293,7 +281,7 @@ def main(user_input=None): ...@@ -293,7 +281,7 @@ def main(user_input=None):
sw *= numpy.hanning(window_size) sw *= numpy.hanning(window_size)
pulse[w:w+window_size] += sw pulse[w:w+window_size] += sw
if bool(args['--plot']): if plot:
from matplotlib import pyplot from matplotlib import pyplot
f, axarr = pyplot.subplots(1) f, axarr = pyplot.subplots(1)
pyplot.plot(range(pulse.shape[0]), pulse, 'k') pyplot.plot(range(pulse.shape[0]), pulse, 'k')
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment