diff --git a/bob/learn/tensorflow/script/predict_bio.py b/bob/learn/tensorflow/script/predict_bio.py index 50097d2394f83bb1f44776edf7ab3815901cb4cd..c4b396340a27ee1aad19243140c0b23a75cd33da 100644 --- a/bob/learn/tensorflow/script/predict_bio.py +++ b/bob/learn/tensorflow/script/predict_bio.py @@ -13,6 +13,7 @@ from bob.extension.scripts.click_helper import ( from multiprocessing import Pool from collections import defaultdict import numpy as np +import tensorflow as tf from bob.io.base import create_directories_safe from bob.bio.base.utils import save from bob.bio.base.tools.grid import indices @@ -121,7 +122,9 @@ def predict_bio(estimator, database, biofiles, bio_predict_input_fn, `None`, returns all. checkpoint_path : str, optional Path of a specific checkpoint to predict. If `None`, the latest - checkpoint in `model_dir` is used. + checkpoint in `model_dir` is used. This can also be a folder which + contains a "checkpoint" file where the latest checkpoint from inside + this file will be used as checkpoint_path. multiple_samples : bool, optional If provided, it assumes that the db interface returns several samples from a biofile. This option can be used when you are working with @@ -216,6 +219,11 @@ def predict_bio(estimator, database, biofiles, bio_predict_input_fn, generator.output_shapes) if checkpoint_path: + if os.path.isdir(checkpoint_path): + ckpt = tf.train.get_checkpoint_state(estimator.model_dir) + if ckpt and ckpt.model_checkpoint_path: + checkpoint_path = ckpt.model_checkpoint_path + logger.info("Restoring the model from %s", checkpoint_path) predictions = estimator.predict(