diff --git a/bob/rppg/chrom/script/extract_pulse.py b/bob/rppg/chrom/script/extract_pulse.py index 45fd04963c5a3b0fb563a7dd170030a1b0f376a4..74d74c32338ea98ddb824e15abd76bda8661c29f 100644 --- a/bob/rppg/chrom/script/extract_pulse.py +++ b/bob/rppg/chrom/script/extract_pulse.py @@ -119,10 +119,7 @@ def main(user_input=None): from bob.core.log import set_verbosity_level set_verbosity_level(logger, verbosity_level) - print(configuration.database) if hasattr(configuration, 'database'): - print(protocol) - print(subset) objects = configuration.database.objects(protocol, subset) else: logger.error("Please provide a database in your configuration file !") diff --git a/bob/rppg/chrom/script/extract_pulse_from_mask.py b/bob/rppg/chrom/script/extract_pulse_from_mask.py index 0ae3748474c4ef850b6f31ce86614b474223cadd..6a25c4e4e1eebcfcb57d65aef503396c78a0e613 100644 --- a/bob/rppg/chrom/script/extract_pulse_from_mask.py +++ b/bob/rppg/chrom/script/extract_pulse_from_mask.py @@ -4,8 +4,9 @@ """Pulse extraction using CHROM algorithm (%(version)s) Usage: - %(prog)s (cohface | hci) [--protocol=<string>] [--subset=<string> ...] - [--dbdir=<path>] [--pulsedir=<path>] + %(prog)s <configuration> + [--protocol=<string>] [--subset=<string> ...] + [--pulsedir=<path>] [--npoints=<int>] [--indent=<int>] [--quality=<float>] [--distance=<int>] [--framerate=<int>] [--order=<int>] [--window=<int>] [--overwrite] [--verbose ...] [--plot] [--gridcount] @@ -17,13 +18,11 @@ Usage: Options: -h, --help Show this screen -V, --version Show version - -p, --protocol=<string> Protocol [default: all]. + -p, --protocol=<string> Protocol. -s, --subset=<string> Data subset to load. If nothing is provided all the data sets will be loaded. - -d, --dbdir=<path> The path to the database on your disk. If not set, - defaults to Idiap standard locations. - -f, --pulsedir=<path> The path to the directory where signal extracted - from the face area will be stored [default: face] + -o, --pulsedir=<path> The path to the directory where signal extracted + from the face area will be stored [default: pulse]. -n, --npoints=<int> Number of good features to track [default: 40] -i, --indent=<int> Indent (in percent of the face width) to apply to keypoints to get the mask [default: 10] @@ -46,9 +45,9 @@ Options: Example: - To run the pulse extractor for the cohface database + To run the pulse extraction - $ %(prog)s cohface -v + $ %(prog)s config.py -v See '%(prog)s --help' for more information. @@ -64,6 +63,9 @@ logger = setup("bob.rppg.base") from docopt import docopt +from bob.extension.config import load +from ...base.utils import get_parameter + version = pkg_resources.require('bob.rppg.base')[0].version import numpy @@ -97,46 +99,35 @@ def main(user_input=None): completions = dict(prog=prog, version=version,) args = docopt(__doc__ % completions, argv=arguments, version='Signal extractor for videos (%s)' % version,) + # load configuration file + configuration = load([os.path.join(args['<configuration>'])]) + + # get various parameters, either from config file or command-line + protocol = get_parameter(args, configuration, 'protocol', 'None') + subset = get_parameter(args, configuration, 'subset', '') + pulsedir = get_parameter(args, configuration, 'pulsedir', 'pulse') + npoints = get_parameter(args, configuration, 'npoints', 40) + indent = get_parameter(args, configuration, 'indent', 10) + quality = get_parameter(args, configuration, 'quality', 0.01) + distance = get_parameter(args, configuration, 'distance', 10) + framerate = get_parameter(args, configuration, 'framerate', 61) + order = get_parameter(args, configuration, 'order', 128) + window = get_parameter(args, configuration, 'window', 0) + overwrite = get_parameter(args, configuration, 'overwrite', False) + plot = get_parameter(args, configuration, 'plot', False) + gridcount = get_parameter(args, configuration, 'gridcount', False) + verbosity_level = get_parameter(args, configuration, 'verbose', 0) + # if the user wants more verbosity, lowers the logging level from bob.core.log import set_verbosity_level - set_verbosity_level(logger, args['--verbose']) - - # chooses the database driver to use - if args['cohface']: - import bob.db.cohface - if os.path.isdir(bob.db.cohface.DATABASE_LOCATION): - logger.debug("Using Idiap default location for the DB") - dbdir = bob.db.cohface.DATABASE_LOCATION - elif args['--indir'] is not None: - logger.debug("Using provided location for the DB") - dbdir = args['--indir'] - else: - logger.warn("Could not find the database directory, please provide one") - sys.exit() - db = bob.db.cohface.Database(dbdir) - if not((args['--protocol'] == 'all') or (args['--protocol'] == 'clean') or (args['--protocol'] == 'natural')): - logger.warning("Protocol should be either 'clean', 'natural' or 'all' (and not {0})".format(args['--protocol'])) - sys.exit() - objects = db.objects(args['--protocol'], args['--subset']) - - elif args['hci']: - import bob.db.hci_tagging - import bob.db.hci_tagging.driver - if os.path.isdir(bob.db.hci_tagging.driver.DATABASE_LOCATION): - logger.debug("Using Idiap default location for the DB") - dbdir = bob.db.hci_tagging.driver.DATABASE_LOCATION - elif args['--indir'] is not None: - logger.debug("Using provided location for the DB") - dbdir = args['--indir'] - else: - logger.warn("Could not find the database directory, please provide one") - sys.exit() - db = bob.db.hci_tagging.Database() - if not((args['--protocol'] == 'all') or (args['--protocol'] == 'cvpr14')): - logger.warning("Protocol should be either 'all' or 'cvpr14' (and not {0})".format(args['--protocol'])) - sys.exit() - objects = db.objects(args['--protocol'], args['--subset']) + set_verbosity_level(logger, verbosity_level) + if hasattr(configuration, 'database'): + objects = configuration.database.objects(protocol, subset) + else: + logger.error("Please provide a database in your configuration file !") + sys.exit() + # if we are on a grid environment, just find what I have to process. sge = False try: @@ -150,27 +141,27 @@ def main(user_input=None): raise RuntimeError("Grid request for job {} on a setup with {} jobs".format(pos, len(objects))) objects = [objects[pos]] - if args['--gridcount']: + if gridcount: print(len(objects)) sys.exit() # build the bandpass filter one and for all - bandpass_filter = build_bandpass_filter(float(args['--framerate']), int(args['--order']), bool(args['--plot'])) + bandpass_filter = build_bandpass_filter(framerate, order, plot) # does the actual work - for every video in the available dataset, # extract the signals and dumps the results to the corresponding directory for obj in objects: # expected output file - output = obj.make_path(args['--pulsedir'], '.hdf5') + output = obj.make_path(pulsedir, '.hdf5') # if output exists and not overwriting, skip this file - if (os.path.exists(output)) and not args['--overwrite']: + if (os.path.exists(output)) and not overwrite: logger.info("Skipping output file `%s': already exists, use --overwrite to force an overwrite", output) continue # load video - video = obj.load_video(dbdir) + video = obj.load_video(configuration.dbdir) logger.info("Processing input video from `%s'...", video.filename) # number of frames @@ -194,7 +185,7 @@ def main(user_input=None): # -> detect the face # -> get "good features" inside the face kpts = obj.load_drmf_keypoints() - mask_points, mask = kp66_to_mask(frame, kpts, int(args['--indent']), bool(args['--plot'])) + mask_points, mask = kp66_to_mask(frame, kpts, int(indent), plot) try: bbox = bounding_boxes[i] @@ -204,8 +195,7 @@ def main(user_input=None): # define the face width for the whole sequence facewidth = bbox.size[1] face = crop_face(frame, bbox, facewidth) - good_features = get_good_features_to_track(face,int(args['--npoints']), - float(args['--quality']), int(args['--distance']), bool(args['--plot'])) + good_features = get_good_features_to_track(face,npoints, quality, distance, plot) else: # subsequent frames: # -> crop the face with the bounding_boxes of the previous frame (so @@ -215,7 +205,7 @@ def main(user_input=None): # current corners # -> apply this transformation to the mask face = crop_face(frame, prev_bb, facewidth) - good_features = track_features(prev_face, face, prev_features, bool(args['--plot'])) + good_features = track_features(prev_face, face, prev_features, plot) project = find_transformation(prev_features, good_features) if project is None: logger.warn("Sequence {0}, frame {1} : No projection was found" @@ -235,9 +225,7 @@ def main(user_input=None): prev_bb = bb prev_face = crop_face(frame, prev_bb, facewidth) - prev_features = get_good_features_to_track(face, int(args['--npoints']), - float(args['--quality']), int(args['--distance']), - bool(args['--plot'])) + prev_features = get_good_features_to_track(face, npoints, quality, distance, plot) if prev_features is None: logger.warn("Sequence {0}, frame {1} No features to track" " detected in the current frame, using the previous ones" @@ -246,7 +234,7 @@ def main(user_input=None): # get the mask face_mask = get_mask(frame, mask_points) - if bool(args['--plot']) and args['--verbose'] >= 2: + if plot and verbosity_level >= 2: from matplotlib import pyplot mask_image = numpy.copy(frame) mask_image[:, face_mask] = 255 @@ -267,7 +255,7 @@ def main(user_input=None): x_bandpassed = filtfilt(bandpass_filter, numpy.array([1]), chrom[:, 0]) y_bandpassed = filtfilt(bandpass_filter, numpy.array([1]), chrom[:, 1]) - if bool(args['--plot']): + if plot: from matplotlib import pyplot f, axarr = pyplot.subplots(2, sharex=True) axarr[0].plot(range(x_bandpassed.shape[0]), x_bandpassed, 'k') @@ -281,8 +269,8 @@ def main(user_input=None): pulse = x_bandpassed - alpha * y_bandpassed # overlap-add if window_size != 0 - if int(args['--window']) > 0: - window_size = int(args['--window']) + if int(window) > 0: + window_size = int(window) window_stride = window_size / 2 for w in range(0, (len(pulse)-window_size), window_stride): pulse[w:w+window_size] = 0.0 @@ -293,7 +281,7 @@ def main(user_input=None): sw *= numpy.hanning(window_size) pulse[w:w+window_size] += sw - if bool(args['--plot']): + if plot: from matplotlib import pyplot f, axarr = pyplot.subplots(1) pyplot.plot(range(pulse.shape[0]), pulse, 'k')