Skip to content
Snippets Groups Projects
Commit b9cc5a9a authored by Amir MOHAMMADI's avatar Amir MOHAMMADI
Browse files

handle nans in network predictions

parent 6e6addea
No related branches found
No related tags found
1 merge request!79Add keras-based models, add pixel-wise loss, other improvements
......@@ -7,7 +7,6 @@ from __future__ import print_function
import click
import logging
import os
import six
import shutil
import sys
import tensorflow as tf
......@@ -117,7 +116,7 @@ def read_evaluated_file(path):
def append_evaluated_file(path, evaluations):
str_evaluations = ', '.join(
'%s = %s' % (k, v) for k, v in sorted(six.iteritems(evaluations)))
'%s = %s' % (k, v) for k, v in sorted(evaluations.items()))
with open(path, 'a') as f:
f.write('{} {}\n'.format(evaluations['global_step'], str_evaluations))
return str_evaluations
......@@ -227,6 +226,7 @@ def eval(estimator, eval_input_fn, hooks, run_once, eval_interval_secs, name,
if new_evaluated_count == evaluated_steps_count:
wait_interval_count += 1
if wait_interval_count > max_wait_intervals:
click.echo("Reached maximum wait interval!")
break
else:
evaluated_steps_count = new_evaluated_count
......
......@@ -59,16 +59,21 @@ def non_existing_files(paths, force=False):
yield i
def save_predictions(output_dir, key, pred_buffer, video_container):
def save_predictions(output_dir, key, pred_buffer, video_container, remove_nan=False):
outpath = make_output_path(output_dir, key)
create_directories_safe(os.path.dirname(outpath))
logger.debug("Saving predictions for %s", key)
if video_container:
fc = bob.bio.video.FrameContainer()
for i, v in enumerate(pred_buffer[key]):
if remove_nan and np.isnan(v):
continue
fc.add(i, v)
data = fc
else:
if remove_nan:
pred_buffer[key] = np.array(pred_buffer[key])
pred_buffer[key] = pred_buffer[key][~np.isnan(pred_buffer[key])]
data = np.mean(pred_buffer[key], axis=0)
save(data, outpath)
......@@ -183,6 +188,14 @@ def save_predictions(output_dir, key, pred_buffer, video_container):
help="If provided, the predictions will be written in FrameContainers from"
" bob.bio.video. You need to install bob.bio.video as well.",
)
@click.option(
"--remove-nan",
"-rn",
is_flag=True,
cls=ResourceOption,
help="If provided, will remove nans before computing the mean or remove nans "
"from the frame container.",
)
@verbosity_option(cls=ResourceOption)
def predict_bio(
estimator,
......@@ -198,6 +211,7 @@ def predict_bio(
array,
force,
video_container,
remove_nan,
**kwargs
):
"""Saves predictions or embeddings of tf.estimators.
......@@ -271,6 +285,7 @@ def predict_bio(
checkpoint_path=checkpoint_path,
hooks=hooks,
video_container=video_container,
remove_nan=remove_nan,
)
......@@ -333,6 +348,14 @@ def predict_bio(
help="If provided, the predictions will be written in FrameContainers from"
" bob.bio.video. You need to install bob.bio.video as well.",
)
@click.option(
"--remove-nan",
"-rn",
is_flag=True,
cls=ResourceOption,
help="If provided, will remove nans before computing the mean or remove nans "
"from the frame container.",
)
@verbosity_option(cls=ResourceOption)
def predict(
estimator,
......@@ -342,6 +365,7 @@ def predict(
checkpoint_path,
hooks,
video_container,
remove_nan,
**kwargs
):
......@@ -353,6 +377,7 @@ def predict(
checkpoint_path=checkpoint_path,
hooks=hooks,
video_container=video_container,
remove_nan=remove_nan,
)
......@@ -364,6 +389,7 @@ def generic_predict(
checkpoint_path=None,
hooks=None,
video_container=False,
remove_nan=False,
):
# if the checkpoint_path is a directory, pick the latest checkpoint from
# that directory
......@@ -409,7 +435,7 @@ def generic_predict(
if last_key == key:
continue
else:
save_predictions(output_dir, last_key, pred_buffer, video_container)
save_predictions(output_dir, last_key, pred_buffer, video_container, remove_nan)
# delete saved data so we don't run out of RAM
del pred_buffer[last_key]
# start saving this new key
......@@ -418,7 +444,7 @@ def generic_predict(
try:
key
# save the final returned key as well:
save_predictions(output_dir, key, pred_buffer, video_container)
save_predictions(output_dir, key, pred_buffer, video_container, remove_nan)
except UnboundLocalError:
# if the input_fn was empty and hence key is not defined
click.echo("predict_input_fn returned no samples.")
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment