Commit b9cc5a9a authored by Amir MOHAMMADI's avatar Amir MOHAMMADI

handle nans in network predictions

parent 6e6addea
...@@ -7,7 +7,6 @@ from __future__ import print_function ...@@ -7,7 +7,6 @@ from __future__ import print_function
import click import click
import logging import logging
import os import os
import six
import shutil import shutil
import sys import sys
import tensorflow as tf import tensorflow as tf
...@@ -117,7 +116,7 @@ def read_evaluated_file(path): ...@@ -117,7 +116,7 @@ def read_evaluated_file(path):
def append_evaluated_file(path, evaluations): def append_evaluated_file(path, evaluations):
str_evaluations = ', '.join( str_evaluations = ', '.join(
'%s = %s' % (k, v) for k, v in sorted(six.iteritems(evaluations))) '%s = %s' % (k, v) for k, v in sorted(evaluations.items()))
with open(path, 'a') as f: with open(path, 'a') as f:
f.write('{} {}\n'.format(evaluations['global_step'], str_evaluations)) f.write('{} {}\n'.format(evaluations['global_step'], str_evaluations))
return str_evaluations return str_evaluations
...@@ -227,6 +226,7 @@ def eval(estimator, eval_input_fn, hooks, run_once, eval_interval_secs, name, ...@@ -227,6 +226,7 @@ def eval(estimator, eval_input_fn, hooks, run_once, eval_interval_secs, name,
if new_evaluated_count == evaluated_steps_count: if new_evaluated_count == evaluated_steps_count:
wait_interval_count += 1 wait_interval_count += 1
if wait_interval_count > max_wait_intervals: if wait_interval_count > max_wait_intervals:
click.echo("Reached maximum wait interval!")
break break
else: else:
evaluated_steps_count = new_evaluated_count evaluated_steps_count = new_evaluated_count
......
...@@ -59,16 +59,21 @@ def non_existing_files(paths, force=False): ...@@ -59,16 +59,21 @@ def non_existing_files(paths, force=False):
yield i yield i
def save_predictions(output_dir, key, pred_buffer, video_container): def save_predictions(output_dir, key, pred_buffer, video_container, remove_nan=False):
outpath = make_output_path(output_dir, key) outpath = make_output_path(output_dir, key)
create_directories_safe(os.path.dirname(outpath)) create_directories_safe(os.path.dirname(outpath))
logger.debug("Saving predictions for %s", key) logger.debug("Saving predictions for %s", key)
if video_container: if video_container:
fc = bob.bio.video.FrameContainer() fc = bob.bio.video.FrameContainer()
for i, v in enumerate(pred_buffer[key]): for i, v in enumerate(pred_buffer[key]):
if remove_nan and np.isnan(v):
continue
fc.add(i, v) fc.add(i, v)
data = fc data = fc
else: else:
if remove_nan:
pred_buffer[key] = np.array(pred_buffer[key])
pred_buffer[key] = pred_buffer[key][~np.isnan(pred_buffer[key])]
data = np.mean(pred_buffer[key], axis=0) data = np.mean(pred_buffer[key], axis=0)
save(data, outpath) save(data, outpath)
...@@ -183,6 +188,14 @@ def save_predictions(output_dir, key, pred_buffer, video_container): ...@@ -183,6 +188,14 @@ def save_predictions(output_dir, key, pred_buffer, video_container):
help="If provided, the predictions will be written in FrameContainers from" help="If provided, the predictions will be written in FrameContainers from"
" bob.bio.video. You need to install bob.bio.video as well.", " bob.bio.video. You need to install bob.bio.video as well.",
) )
@click.option(
"--remove-nan",
"-rn",
is_flag=True,
cls=ResourceOption,
help="If provided, will remove nans before computing the mean or remove nans "
"from the frame container.",
)
@verbosity_option(cls=ResourceOption) @verbosity_option(cls=ResourceOption)
def predict_bio( def predict_bio(
estimator, estimator,
...@@ -198,6 +211,7 @@ def predict_bio( ...@@ -198,6 +211,7 @@ def predict_bio(
array, array,
force, force,
video_container, video_container,
remove_nan,
**kwargs **kwargs
): ):
"""Saves predictions or embeddings of tf.estimators. """Saves predictions or embeddings of tf.estimators.
...@@ -271,6 +285,7 @@ def predict_bio( ...@@ -271,6 +285,7 @@ def predict_bio(
checkpoint_path=checkpoint_path, checkpoint_path=checkpoint_path,
hooks=hooks, hooks=hooks,
video_container=video_container, video_container=video_container,
remove_nan=remove_nan,
) )
...@@ -333,6 +348,14 @@ def predict_bio( ...@@ -333,6 +348,14 @@ def predict_bio(
help="If provided, the predictions will be written in FrameContainers from" help="If provided, the predictions will be written in FrameContainers from"
" bob.bio.video. You need to install bob.bio.video as well.", " bob.bio.video. You need to install bob.bio.video as well.",
) )
@click.option(
"--remove-nan",
"-rn",
is_flag=True,
cls=ResourceOption,
help="If provided, will remove nans before computing the mean or remove nans "
"from the frame container.",
)
@verbosity_option(cls=ResourceOption) @verbosity_option(cls=ResourceOption)
def predict( def predict(
estimator, estimator,
...@@ -342,6 +365,7 @@ def predict( ...@@ -342,6 +365,7 @@ def predict(
checkpoint_path, checkpoint_path,
hooks, hooks,
video_container, video_container,
remove_nan,
**kwargs **kwargs
): ):
...@@ -353,6 +377,7 @@ def predict( ...@@ -353,6 +377,7 @@ def predict(
checkpoint_path=checkpoint_path, checkpoint_path=checkpoint_path,
hooks=hooks, hooks=hooks,
video_container=video_container, video_container=video_container,
remove_nan=remove_nan,
) )
...@@ -364,6 +389,7 @@ def generic_predict( ...@@ -364,6 +389,7 @@ def generic_predict(
checkpoint_path=None, checkpoint_path=None,
hooks=None, hooks=None,
video_container=False, video_container=False,
remove_nan=False,
): ):
# if the checkpoint_path is a directory, pick the latest checkpoint from # if the checkpoint_path is a directory, pick the latest checkpoint from
# that directory # that directory
...@@ -409,7 +435,7 @@ def generic_predict( ...@@ -409,7 +435,7 @@ def generic_predict(
if last_key == key: if last_key == key:
continue continue
else: else:
save_predictions(output_dir, last_key, pred_buffer, video_container) save_predictions(output_dir, last_key, pred_buffer, video_container, remove_nan)
# delete saved data so we don't run out of RAM # delete saved data so we don't run out of RAM
del pred_buffer[last_key] del pred_buffer[last_key]
# start saving this new key # start saving this new key
...@@ -418,7 +444,7 @@ def generic_predict( ...@@ -418,7 +444,7 @@ def generic_predict(
try: try:
key key
# save the final returned key as well: # save the final returned key as well:
save_predictions(output_dir, key, pred_buffer, video_container) save_predictions(output_dir, key, pred_buffer, video_container, remove_nan)
except UnboundLocalError: except UnboundLocalError:
# if the input_fn was empty and hence key is not defined # if the input_fn was empty and hence key is not defined
click.echo("predict_input_fn returned no samples.") click.echo("predict_input_fn returned no samples.")
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment