Commit b9cc5a9a authored by Amir MOHAMMADI's avatar Amir MOHAMMADI

handle nans in network predictions

parent 6e6addea
......@@ -7,7 +7,6 @@ from __future__ import print_function
import click
import logging
import os
import six
import shutil
import sys
import tensorflow as tf
......@@ -117,7 +116,7 @@ def read_evaluated_file(path):
def append_evaluated_file(path, evaluations):
str_evaluations = ', '.join(
'%s = %s' % (k, v) for k, v in sorted(six.iteritems(evaluations)))
'%s = %s' % (k, v) for k, v in sorted(evaluations.items()))
with open(path, 'a') as f:
f.write('{} {}\n'.format(evaluations['global_step'], str_evaluations))
return str_evaluations
......@@ -227,6 +226,7 @@ def eval(estimator, eval_input_fn, hooks, run_once, eval_interval_secs, name,
if new_evaluated_count == evaluated_steps_count:
wait_interval_count += 1
if wait_interval_count > max_wait_intervals:
click.echo("Reached maximum wait interval!")
break
else:
evaluated_steps_count = new_evaluated_count
......
......@@ -59,16 +59,21 @@ def non_existing_files(paths, force=False):
yield i
def save_predictions(output_dir, key, pred_buffer, video_container):
def save_predictions(output_dir, key, pred_buffer, video_container, remove_nan=False):
outpath = make_output_path(output_dir, key)
create_directories_safe(os.path.dirname(outpath))
logger.debug("Saving predictions for %s", key)
if video_container:
fc = bob.bio.video.FrameContainer()
for i, v in enumerate(pred_buffer[key]):
if remove_nan and np.isnan(v):
continue
fc.add(i, v)
data = fc
else:
if remove_nan:
pred_buffer[key] = np.array(pred_buffer[key])
pred_buffer[key] = pred_buffer[key][~np.isnan(pred_buffer[key])]
data = np.mean(pred_buffer[key], axis=0)
save(data, outpath)
......@@ -183,6 +188,14 @@ def save_predictions(output_dir, key, pred_buffer, video_container):
help="If provided, the predictions will be written in FrameContainers from"
" bob.bio.video. You need to install bob.bio.video as well.",
)
@click.option(
"--remove-nan",
"-rn",
is_flag=True,
cls=ResourceOption,
help="If provided, will remove nans before computing the mean or remove nans "
"from the frame container.",
)
@verbosity_option(cls=ResourceOption)
def predict_bio(
estimator,
......@@ -198,6 +211,7 @@ def predict_bio(
array,
force,
video_container,
remove_nan,
**kwargs
):
"""Saves predictions or embeddings of tf.estimators.
......@@ -271,6 +285,7 @@ def predict_bio(
checkpoint_path=checkpoint_path,
hooks=hooks,
video_container=video_container,
remove_nan=remove_nan,
)
......@@ -333,6 +348,14 @@ def predict_bio(
help="If provided, the predictions will be written in FrameContainers from"
" bob.bio.video. You need to install bob.bio.video as well.",
)
@click.option(
"--remove-nan",
"-rn",
is_flag=True,
cls=ResourceOption,
help="If provided, will remove nans before computing the mean or remove nans "
"from the frame container.",
)
@verbosity_option(cls=ResourceOption)
def predict(
estimator,
......@@ -342,6 +365,7 @@ def predict(
checkpoint_path,
hooks,
video_container,
remove_nan,
**kwargs
):
......@@ -353,6 +377,7 @@ def predict(
checkpoint_path=checkpoint_path,
hooks=hooks,
video_container=video_container,
remove_nan=remove_nan,
)
......@@ -364,6 +389,7 @@ def generic_predict(
checkpoint_path=None,
hooks=None,
video_container=False,
remove_nan=False,
):
# if the checkpoint_path is a directory, pick the latest checkpoint from
# that directory
......@@ -409,7 +435,7 @@ def generic_predict(
if last_key == key:
continue
else:
save_predictions(output_dir, last_key, pred_buffer, video_container)
save_predictions(output_dir, last_key, pred_buffer, video_container, remove_nan)
# delete saved data so we don't run out of RAM
del pred_buffer[last_key]
# start saving this new key
......@@ -418,7 +444,7 @@ def generic_predict(
try:
key
# save the final returned key as well:
save_predictions(output_dir, key, pred_buffer, video_container)
save_predictions(output_dir, key, pred_buffer, video_container, remove_nan)
except UnboundLocalError:
# if the input_fn was empty and hence key is not defined
click.echo("predict_input_fn returned no samples.")
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment