Commit 1a2392d8 authored by Amir MOHAMMADI's avatar Amir MOHAMMADI

Add more logging to the predict commands

parent d46cb5c1
......@@ -9,8 +9,11 @@ import sys
import logging
import click
from bob.extension.scripts.click_helper import (
verbosity_option, ConfigCommand, ResourceOption, log_parameters)
from multiprocessing import Pool
verbosity_option,
ConfigCommand,
ResourceOption,
log_parameters,
)
from collections import defaultdict
import numpy as np
import tensorflow as tf
......@@ -18,6 +21,7 @@ from bob.io.base import create_directories_safe
from bob.bio.base.utils import save
from bob.bio.base.tools.grid import indices
from bob.learn.tensorflow.dataset.bio import BioGenerator
try:
import bob.bio.video
except ModuleNotFoundError:
......@@ -42,7 +46,7 @@ def make_output_path(output_dir, key):
str
The path for the provided key.
"""
return os.path.join(output_dir, key + '.hdf5')
return os.path.join(output_dir, key + ".hdf5")
def non_existing_files(paths, force=False):
......@@ -69,110 +73,133 @@ def save_predictions(output_dir, key, pred_buffer, video_container):
save(data, outpath)
@click.command(
entry_point_group='bob.learn.tensorflow.config', cls=ConfigCommand)
@click.command(entry_point_group="bob.learn.tensorflow.config", cls=ConfigCommand)
@click.option(
'--estimator',
'-e',
"--estimator",
"-e",
required=True,
cls=ResourceOption,
entry_point_group='bob.learn.tensorflow.estimator',
help='The estimator that will be evaluated.')
entry_point_group="bob.learn.tensorflow.estimator",
help="The estimator that will be evaluated.",
)
@click.option(
'--database',
'-d',
"--database",
"-d",
required=True,
cls=ResourceOption,
entry_point_group='bob.bio.database',
help='A bio database. Its original_directory must point to the correct '
'path.')
entry_point_group="bob.bio.database",
help="A bio database. Its original_directory must point to the correct " "path.",
)
@click.option(
'--biofiles',
"--biofiles",
required=True,
cls=ResourceOption,
help='The list of the bio files. You can only provide this through config '
'files.')
help="The list of the bio files. You can only provide this through config "
"files.",
)
@click.option(
'--bio-predict-input-fn',
"--bio-predict-input-fn",
required=True,
cls=ResourceOption,
entry_point_group='bob.learn.tensorflow.biogenerator_input',
help='A callable with the signature of '
'`input_fn = bio_predict_input_fn(generator, output_types, output_shapes)`'
' The inputs are documented in :any:`tf.data.Dataset.from_generator`'
' and the output should be a function with no arguments and is passed'
' to :any:`tf.estimator.Estimator.predict`.')
entry_point_group="bob.learn.tensorflow.biogenerator_input",
help="A callable with the signature of "
"`input_fn = bio_predict_input_fn(generator, output_types, output_shapes)`"
" The inputs are documented in :any:`tf.data.Dataset.from_generator`"
" and the output should be a function with no arguments and is passed"
" to :any:`tf.estimator.Estimator.predict`.",
)
@click.option(
'--output-dir',
'-o',
"--output-dir",
"-o",
required=True,
cls=ResourceOption,
help='The directory to save the predictions.')
help="The directory to save the predictions.",
)
@click.option(
'--load-data',
"--load-data",
cls=ResourceOption,
entry_point_group='bob.learn.tensorflow.load_data',
help='A callable with the signature of '
'``data = load_data(database, biofile)``. '
':any:`bob.bio.base.read_original_data` is used by default.')
entry_point_group="bob.learn.tensorflow.load_data",
help="A callable with the signature of "
"``data = load_data(database, biofile)``. "
":any:`bob.bio.base.read_original_data` is used by default.",
)
@click.option(
'--hooks',
"--hooks",
cls=ResourceOption,
multiple=True,
entry_point_group='bob.learn.tensorflow.hook',
help='List of SessionRunHook subclass instances.')
entry_point_group="bob.learn.tensorflow.hook",
help="List of SessionRunHook subclass instances.",
)
@click.option(
'--predict-keys',
'-k',
"--predict-keys",
"-k",
multiple=True,
default=None,
cls=ResourceOption,
help='List of `str`, name of the keys to predict. It is used if the '
'`EstimatorSpec.predictions` is a `dict`. If `predict_keys` is used '
'then rest of the predictions will be filtered from the dictionary. '
'If `None`, returns all.')
help="List of `str`, name of the keys to predict. It is used if the "
"`EstimatorSpec.predictions` is a `dict`. If `predict_keys` is used "
"then rest of the predictions will be filtered from the dictionary. "
"If `None`, returns all.",
)
@click.option(
'--checkpoint-path',
'-c',
"--checkpoint-path",
"-c",
cls=ResourceOption,
help='Path of a specific checkpoint to predict. If `None`, the '
'latest checkpoint in `model_dir` is used. This can also '
help="Path of a specific checkpoint to predict. If `None`, the "
"latest checkpoint in `model_dir` is used. This can also "
'be a folder which contains a "checkpoint" file where the '
'latest checkpoint from inside this file will be used as '
'checkpoint_path.')
"latest checkpoint from inside this file will be used as "
"checkpoint_path.",
)
@click.option(
'--multiple-samples',
'-m',
"--multiple-samples",
"-m",
is_flag=True,
cls=ResourceOption,
help='If provided, it assumes that the db interface returns '
'several samples from a biofile. This option can be used '
'when you are working with videos.')
help="If provided, it assumes that the db interface returns "
"several samples from a biofile. This option can be used "
"when you are working with videos.",
)
@click.option(
'--array',
'-t',
"--array",
"-t",
type=click.INT,
default=1,
cls=ResourceOption,
help='Use this option alongside gridtk to submit this script as '
'an array job.')
help="Use this option alongside gridtk to submit this script as " "an array job.",
)
@click.option(
'--force',
'-f',
"--force",
"-f",
is_flag=True,
cls=ResourceOption,
help='Whether to overwrite existing predictions.')
help="Whether to overwrite existing predictions.",
)
@click.option(
'--video-container',
'-vc',
"--video-container",
"-vc",
is_flag=True,
cls=ResourceOption,
help='If provided, the predictions will be written in FrameContainers from'
' bob.bio.video. You need to install bob.bio.video as well.')
help="If provided, the predictions will be written in FrameContainers from"
" bob.bio.video. You need to install bob.bio.video as well.",
)
@verbosity_option(cls=ResourceOption)
def predict_bio(estimator, database, biofiles, bio_predict_input_fn,
output_dir, load_data, hooks, predict_keys, checkpoint_path,
multiple_samples, array, force, video_container, **kwargs):
def predict_bio(
estimator,
database,
biofiles,
bio_predict_input_fn,
output_dir,
load_data,
hooks,
predict_keys,
checkpoint_path,
multiple_samples,
array,
force,
video_container,
**kwargs
):
"""Saves predictions or embeddings of tf.estimators.
This script works with bob.bio.base databases. This script works with
......@@ -209,7 +236,7 @@ def predict_bio(estimator, database, biofiles, bio_predict_input_fn,
return {'data': images, 'key': keys}, labels
return input_fn
"""
log_parameters(logger, ignore=('biofiles', ))
log_parameters(logger, ignore=("biofiles",))
logger.debug("len(biofiles): %d", len(biofiles))
assert len(biofiles), "biofiles are empty!"
......@@ -219,98 +246,125 @@ def predict_bio(estimator, database, biofiles, bio_predict_input_fn,
biofiles = biofiles[start:end]
# filter the existing files
paths = [
make_output_path(output_dir, f.make_path("", "")) for f in biofiles
]
paths = [make_output_path(output_dir, f.make_path("", "")) for f in biofiles]
indexes = non_existing_files(paths, force)
biofiles = [biofiles[i] for i in indexes]
if len(biofiles) == 0:
logger.warning(
"The biofiles are empty after checking for existing files.")
logger.warning("The biofiles are empty after checking for existing files.")
return
generator = BioGenerator(
database,
biofiles,
load_data=load_data,
multiple_samples=multiple_samples)
database, biofiles, load_data=load_data, multiple_samples=multiple_samples
)
predict_input_fn = bio_predict_input_fn(generator, generator.output_types,
generator.output_shapes)
predict_input_fn = bio_predict_input_fn(
generator, generator.output_types, generator.output_shapes
)
logger.info("Saving the predictions of %d files in %s", len(generator),
output_dir)
logger.info("Saving the predictions of %d files in %s", len(generator), output_dir)
generic_predict(
estimator, predict_input_fn, output_dir, predict_keys=predict_keys,
checkpoint_path=checkpoint_path, hooks=hooks,
video_container=video_container)
estimator=estimator,
predict_input_fn=predict_input_fn,
output_dir=output_dir,
predict_keys=predict_keys,
checkpoint_path=checkpoint_path,
hooks=hooks,
video_container=video_container,
)
@click.command(
entry_point_group='bob.learn.tensorflow.config', cls=ConfigCommand)
@click.command(entry_point_group="bob.learn.tensorflow.config", cls=ConfigCommand)
@click.option(
'--estimator',
'-e',
"--estimator",
"-e",
required=True,
cls=ResourceOption,
entry_point_group='bob.learn.tensorflow.estimator',
help='The estimator that will be evaluated.')
entry_point_group="bob.learn.tensorflow.estimator",
help="The estimator that will be evaluated.",
)
@click.option(
'--predict-input-fn',
"--predict-input-fn",
required=True,
cls=ResourceOption,
entry_point_group='bob.learn.tensorflow.input_fn',
help='A callable with no arguments which will be used in estimator.predict.')
entry_point_group="bob.learn.tensorflow.input_fn",
help="A callable with no arguments which will be used in estimator.predict.",
)
@click.option(
'--output-dir',
'-o',
"--output-dir",
"-o",
required=True,
cls=ResourceOption,
help='The directory to save the predictions.')
help="The directory to save the predictions.",
)
@click.option(
'--predict-keys',
'-k',
"--predict-keys",
"-k",
multiple=True,
default=None,
cls=ResourceOption,
help='List of `str`, name of the keys to predict. It is used if the '
'`EstimatorSpec.predictions` is a `dict`. If `predict_keys` is used '
'then rest of the predictions will be filtered from the dictionary. '
'If `None`, returns all.')
help="List of `str`, name of the keys to predict. It is used if the "
"`EstimatorSpec.predictions` is a `dict`. If `predict_keys` is used "
"then rest of the predictions will be filtered from the dictionary. "
"If `None`, returns all.",
)
@click.option(
'--checkpoint-path',
'-c',
"--checkpoint-path",
"-c",
cls=ResourceOption,
help='Path of a specific checkpoint to predict. If `None`, the '
'latest checkpoint in `model_dir` is used. This can also '
help="Path of a specific checkpoint to predict. If `None`, the "
"latest checkpoint in `model_dir` is used. This can also "
'be a folder which contains a "checkpoint" file where the '
'latest checkpoint from inside this file will be used as '
'checkpoint_path.')
"latest checkpoint from inside this file will be used as "
"checkpoint_path.",
)
@click.option(
'--hooks',
"--hooks",
cls=ResourceOption,
multiple=True,
entry_point_group='bob.learn.tensorflow.hook',
help='List of SessionRunHook subclass instances.')
entry_point_group="bob.learn.tensorflow.hook",
help="List of SessionRunHook subclass instances.",
)
@click.option(
'--video-container',
'-vc',
"--video-container",
"-vc",
is_flag=True,
cls=ResourceOption,
help='If provided, the predictions will be written in FrameContainers from'
' bob.bio.video. You need to install bob.bio.video as well.')
help="If provided, the predictions will be written in FrameContainers from"
" bob.bio.video. You need to install bob.bio.video as well.",
)
@verbosity_option(cls=ResourceOption)
def predict(estimator, predict_input_fn, output_dir, predict_keys,
checkpoint_path, hooks, video_container, **kwargs):
def predict(
estimator,
predict_input_fn,
output_dir,
predict_keys,
checkpoint_path,
hooks,
video_container,
**kwargs
):
generic_predict(
estimator, predict_input_fn, output_dir, predict_keys,
checkpoint_path, hooks, video_container)
estimator=estimator,
predict_input_fn=predict_input_fn,
output_dir=output_dir,
predict_keys=predict_keys,
checkpoint_path=checkpoint_path,
hooks=hooks,
video_container=video_container,
)
def generic_predict(estimator, predict_input_fn, output_dir, predict_keys=None,
checkpoint_path=None, hooks=None, video_container=False):
def generic_predict(
estimator,
predict_input_fn,
output_dir,
predict_keys=None,
checkpoint_path=None,
hooks=None,
video_container=False,
):
# if the checkpoint_path is a directory, pick the latest checkpoint from
# that directory
if checkpoint_path:
......@@ -333,16 +387,21 @@ def generic_predict(estimator, predict_input_fn, output_dir, predict_keys=None,
import bob.bio.video
except ModuleNotFoundError:
raise click.ClickException(
'Could not import bob.bio.video. Have you installed it?')
"Could not import bob.bio.video. Have you installed it?"
)
pred_buffer = defaultdict(list)
for i, pred in enumerate(predictions):
key = pred['key']
key = pred["key"]
# key is in bytes format in Python 3
if sys.version_info >= (3, ):
key = key.decode(errors='replace')
prob = pred.get('probabilities', pred.get(
'embeddings', pred.get('predictions')))
if sys.version_info >= (3,):
key = key.decode(errors="replace")
if predict_keys:
prob = pred[predict_keys[0]]
else:
prob = pred.get(
"probabilities", pred.get("embeddings", pred.get("predictions"))
)
assert prob is not None
pred_buffer[key].append(prob)
if i == 0:
......@@ -350,13 +409,17 @@ def generic_predict(estimator, predict_input_fn, output_dir, predict_keys=None,
if last_key == key:
continue
else:
save_predictions(
output_dir, last_key, pred_buffer, video_container)
save_predictions(output_dir, last_key, pred_buffer, video_container)
# delete saved data so we don't run out of RAM
del pred_buffer[last_key]
# start saving this new key
last_key = key
try:
key
# save the final returned key as well:
save_predictions(output_dir, key, pred_buffer, video_container)
except UnboundLocalError:
# if the input_fn was empty and hence key is not defined
click.echo("predict_input_fn returned no samples.")
pass
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment