Skip to content
Snippets Groups Projects
Commit 281e1c26 authored by Amir MOHAMMADI
Browse files

Add more logging to the predict commands

parent c82154ef
No related branches found
No related tags found
No related merge requests found
......@@ -9,8 +9,11 @@ import sys
import logging
import click
from bob.extension.scripts.click_helper import (
verbosity_option, ConfigCommand, ResourceOption, log_parameters)
from multiprocessing import Pool
verbosity_option,
ConfigCommand,
ResourceOption,
log_parameters,
)
from collections import defaultdict
import numpy as np
import tensorflow as tf
......@@ -18,6 +21,7 @@ from bob.io.base import create_directories_safe
from bob.bio.base.utils import save
from bob.bio.base.tools.grid import indices
from bob.learn.tensorflow.dataset.bio import BioGenerator
try:
import bob.bio.video
except ModuleNotFoundError:
......@@ -42,7 +46,7 @@ def make_output_path(output_dir, key):
str
The path for the provided key.
"""
return os.path.join(output_dir, key + '.hdf5')
return os.path.join(output_dir, key + ".hdf5")
def non_existing_files(paths, force=False):
......@@ -69,110 +73,133 @@ def save_predictions(output_dir, key, pred_buffer, video_container):
save(data, outpath)
@click.command(
entry_point_group='bob.learn.tensorflow.config', cls=ConfigCommand)
@click.command(entry_point_group="bob.learn.tensorflow.config", cls=ConfigCommand)
@click.option(
'--estimator',
'-e',
"--estimator",
"-e",
required=True,
cls=ResourceOption,
entry_point_group='bob.learn.tensorflow.estimator',
help='The estimator that will be evaluated.')
entry_point_group="bob.learn.tensorflow.estimator",
help="The estimator that will be evaluated.",
)
@click.option(
'--database',
'-d',
"--database",
"-d",
required=True,
cls=ResourceOption,
entry_point_group='bob.bio.database',
help='A bio database. Its original_directory must point to the correct '
'path.')
entry_point_group="bob.bio.database",
help="A bio database. Its original_directory must point to the correct " "path.",
)
@click.option(
'--biofiles',
"--biofiles",
required=True,
cls=ResourceOption,
help='The list of the bio files. You can only provide this through config '
'files.')
help="The list of the bio files. You can only provide this through config "
"files.",
)
@click.option(
'--bio-predict-input-fn',
"--bio-predict-input-fn",
required=True,
cls=ResourceOption,
entry_point_group='bob.learn.tensorflow.biogenerator_input',
help='A callable with the signature of '
'`input_fn = bio_predict_input_fn(generator, output_types, output_shapes)`'
' The inputs are documented in :any:`tf.data.Dataset.from_generator`'
' and the output should be a function with no arguments and is passed'
' to :any:`tf.estimator.Estimator.predict`.')
entry_point_group="bob.learn.tensorflow.biogenerator_input",
help="A callable with the signature of "
"`input_fn = bio_predict_input_fn(generator, output_types, output_shapes)`"
" The inputs are documented in :any:`tf.data.Dataset.from_generator`"
" and the output should be a function with no arguments and is passed"
" to :any:`tf.estimator.Estimator.predict`.",
)
@click.option(
'--output-dir',
'-o',
"--output-dir",
"-o",
required=True,
cls=ResourceOption,
help='The directory to save the predictions.')
help="The directory to save the predictions.",
)
@click.option(
'--load-data',
"--load-data",
cls=ResourceOption,
entry_point_group='bob.learn.tensorflow.load_data',
help='A callable with the signature of '
'``data = load_data(database, biofile)``. '
':any:`bob.bio.base.read_original_data` is used by default.')
entry_point_group="bob.learn.tensorflow.load_data",
help="A callable with the signature of "
"``data = load_data(database, biofile)``. "
":any:`bob.bio.base.read_original_data` is used by default.",
)
@click.option(
'--hooks',
"--hooks",
cls=ResourceOption,
multiple=True,
entry_point_group='bob.learn.tensorflow.hook',
help='List of SessionRunHook subclass instances.')
entry_point_group="bob.learn.tensorflow.hook",
help="List of SessionRunHook subclass instances.",
)
@click.option(
'--predict-keys',
'-k',
"--predict-keys",
"-k",
multiple=True,
default=None,
cls=ResourceOption,
help='List of `str`, name of the keys to predict. It is used if the '
'`EstimatorSpec.predictions` is a `dict`. If `predict_keys` is used '
'then rest of the predictions will be filtered from the dictionary. '
'If `None`, returns all.')
help="List of `str`, name of the keys to predict. It is used if the "
"`EstimatorSpec.predictions` is a `dict`. If `predict_keys` is used "
"then rest of the predictions will be filtered from the dictionary. "
"If `None`, returns all.",
)
@click.option(
'--checkpoint-path',
'-c',
"--checkpoint-path",
"-c",
cls=ResourceOption,
help='Path of a specific checkpoint to predict. If `None`, the '
'latest checkpoint in `model_dir` is used. This can also '
help="Path of a specific checkpoint to predict. If `None`, the "
"latest checkpoint in `model_dir` is used. This can also "
'be a folder which contains a "checkpoint" file where the '
'latest checkpoint from inside this file will be used as '
'checkpoint_path.')
"latest checkpoint from inside this file will be used as "
"checkpoint_path.",
)
@click.option(
'--multiple-samples',
'-m',
"--multiple-samples",
"-m",
is_flag=True,
cls=ResourceOption,
help='If provided, it assumes that the db interface returns '
'several samples from a biofile. This option can be used '
'when you are working with videos.')
help="If provided, it assumes that the db interface returns "
"several samples from a biofile. This option can be used "
"when you are working with videos.",
)
@click.option(
'--array',
'-t',
"--array",
"-t",
type=click.INT,
default=1,
cls=ResourceOption,
help='Use this option alongside gridtk to submit this script as '
'an array job.')
help="Use this option alongside gridtk to submit this script as " "an array job.",
)
@click.option(
'--force',
'-f',
"--force",
"-f",
is_flag=True,
cls=ResourceOption,
help='Whether to overwrite existing predictions.')
help="Whether to overwrite existing predictions.",
)
@click.option(
'--video-container',
'-vc',
"--video-container",
"-vc",
is_flag=True,
cls=ResourceOption,
help='If provided, the predictions will be written in FrameContainers from'
' bob.bio.video. You need to install bob.bio.video as well.')
help="If provided, the predictions will be written in FrameContainers from"
" bob.bio.video. You need to install bob.bio.video as well.",
)
@verbosity_option(cls=ResourceOption)
def predict_bio(estimator, database, biofiles, bio_predict_input_fn,
output_dir, load_data, hooks, predict_keys, checkpoint_path,
multiple_samples, array, force, video_container, **kwargs):
def predict_bio(
estimator,
database,
biofiles,
bio_predict_input_fn,
output_dir,
load_data,
hooks,
predict_keys,
checkpoint_path,
multiple_samples,
array,
force,
video_container,
**kwargs
):
"""Saves predictions or embeddings of tf.estimators.
This script works with bob.bio.base databases. This script works with
......@@ -209,7 +236,7 @@ def predict_bio(estimator, database, biofiles, bio_predict_input_fn,
return {'data': images, 'key': keys}, labels
return input_fn
"""
log_parameters(logger, ignore=('biofiles', ))
log_parameters(logger, ignore=("biofiles",))
logger.debug("len(biofiles): %d", len(biofiles))
assert len(biofiles), "biofiles are empty!"
......@@ -219,98 +246,125 @@ def predict_bio(estimator, database, biofiles, bio_predict_input_fn,
biofiles = biofiles[start:end]
# filter the existing files
paths = [
make_output_path(output_dir, f.make_path("", "")) for f in biofiles
]
paths = [make_output_path(output_dir, f.make_path("", "")) for f in biofiles]
indexes = non_existing_files(paths, force)
biofiles = [biofiles[i] for i in indexes]
if len(biofiles) == 0:
logger.warning(
"The biofiles are empty after checking for existing files.")
logger.warning("The biofiles are empty after checking for existing files.")
return
generator = BioGenerator(
database,
biofiles,
load_data=load_data,
multiple_samples=multiple_samples)
database, biofiles, load_data=load_data, multiple_samples=multiple_samples
)
predict_input_fn = bio_predict_input_fn(generator, generator.output_types,
generator.output_shapes)
predict_input_fn = bio_predict_input_fn(
generator, generator.output_types, generator.output_shapes
)
logger.info("Saving the predictions of %d files in %s", len(generator),
output_dir)
logger.info("Saving the predictions of %d files in %s", len(generator), output_dir)
generic_predict(
estimator, predict_input_fn, output_dir, predict_keys=predict_keys,
checkpoint_path=checkpoint_path, hooks=hooks,
video_container=video_container)
estimator=estimator,
predict_input_fn=predict_input_fn,
output_dir=output_dir,
predict_keys=predict_keys,
checkpoint_path=checkpoint_path,
hooks=hooks,
video_container=video_container,
)
@click.command(entry_point_group="bob.learn.tensorflow.config", cls=ConfigCommand)
@click.option(
    "--estimator",
    "-e",
    required=True,
    cls=ResourceOption,
    entry_point_group="bob.learn.tensorflow.estimator",
    help="The estimator that will be evaluated.",
)
@click.option(
    "--predict-input-fn",
    required=True,
    cls=ResourceOption,
    entry_point_group="bob.learn.tensorflow.input_fn",
    help="A callable with no arguments which will be used in estimator.predict.",
)
@click.option(
    "--output-dir",
    "-o",
    required=True,
    cls=ResourceOption,
    help="The directory to save the predictions.",
)
@click.option(
    "--predict-keys",
    "-k",
    multiple=True,
    default=None,
    cls=ResourceOption,
    help="List of `str`, name of the keys to predict. It is used if the "
    "`EstimatorSpec.predictions` is a `dict`. If `predict_keys` is used "
    "then rest of the predictions will be filtered from the dictionary. "
    "If `None`, returns all.",
)
@click.option(
    "--checkpoint-path",
    "-c",
    cls=ResourceOption,
    help="Path of a specific checkpoint to predict. If `None`, the "
    "latest checkpoint in `model_dir` is used. This can also "
    'be a folder which contains a "checkpoint" file where the '
    "latest checkpoint from inside this file will be used as "
    "checkpoint_path.",
)
@click.option(
    "--hooks",
    cls=ResourceOption,
    multiple=True,
    entry_point_group="bob.learn.tensorflow.hook",
    help="List of SessionRunHook subclass instances.",
)
@click.option(
    "--video-container",
    "-vc",
    is_flag=True,
    cls=ResourceOption,
    help="If provided, the predictions will be written in FrameContainers from"
    " bob.bio.video. You need to install bob.bio.video as well.",
)
@verbosity_option(cls=ResourceOption)
def predict(
    estimator,
    predict_input_fn,
    output_dir,
    predict_keys,
    checkpoint_path,
    hooks,
    video_container,
    **kwargs
):
    """Saves predictions of an estimator using a plain input function.

    This command is a thin wrapper: every option is forwarded unchanged, by
    keyword, to ``generic_predict``.

    Parameters
    ----------
    estimator
        The estimator that will be evaluated.
    predict_input_fn
        A callable with no arguments which will be used in
        ``estimator.predict``.
    output_dir
        The directory to save the predictions.
    predict_keys
        Names of the keys to predict; see the ``--predict-keys`` option help.
    checkpoint_path
        Path of a specific checkpoint (or a folder containing a "checkpoint"
        file); see the ``--checkpoint-path`` option help.
    hooks
        List of SessionRunHook subclass instances.
    video_container
        If true, the predictions are written in FrameContainers from
        bob.bio.video.
    **kwargs
        Any extra options injected by the command decorators (e.g. the
        verbosity option). They are accepted but not used here.
    """
    # Keyword arguments keep this call correct even if the parameter order of
    # generic_predict changes.
    generic_predict(
        estimator=estimator,
        predict_input_fn=predict_input_fn,
        output_dir=output_dir,
        predict_keys=predict_keys,
        checkpoint_path=checkpoint_path,
        hooks=hooks,
        video_container=video_container,
    )
def generic_predict(estimator, predict_input_fn, output_dir, predict_keys=None,
checkpoint_path=None, hooks=None, video_container=False):
def generic_predict(
estimator,
predict_input_fn,
output_dir,
predict_keys=None,
checkpoint_path=None,
hooks=None,
video_container=False,
):
# if the checkpoint_path is a directory, pick the latest checkpoint from
# that directory
if checkpoint_path:
......@@ -333,16 +387,21 @@ def generic_predict(estimator, predict_input_fn, output_dir, predict_keys=None,
import bob.bio.video
except ModuleNotFoundError:
raise click.ClickException(
'Could not import bob.bio.video. Have you installed it?')
"Could not import bob.bio.video. Have you installed it?"
)
pred_buffer = defaultdict(list)
for i, pred in enumerate(predictions):
key = pred['key']
key = pred["key"]
# key is in bytes format in Python 3
if sys.version_info >= (3, ):
key = key.decode(errors='replace')
prob = pred.get('probabilities', pred.get(
'embeddings', pred.get('predictions')))
if sys.version_info >= (3,):
key = key.decode(errors="replace")
if predict_keys:
prob = pred[predict_keys[0]]
else:
prob = pred.get(
"probabilities", pred.get("embeddings", pred.get("predictions"))
)
assert prob is not None
pred_buffer[key].append(prob)
if i == 0:
......@@ -350,13 +409,17 @@ def generic_predict(estimator, predict_input_fn, output_dir, predict_keys=None,
if last_key == key:
continue
else:
save_predictions(
output_dir, last_key, pred_buffer, video_container)
save_predictions(output_dir, last_key, pred_buffer, video_container)
# delete saved data so we don't run out of RAM
del pred_buffer[last_key]
# start saving this new key
last_key = key
try:
key
# save the final returned key as well:
save_predictions(output_dir, key, pred_buffer, video_container)
except UnboundLocalError:
# if the input_fn was empty and hence key is not defined
click.echo("predict_input_fn returned no samples.")
pass
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment