From 12964af50ced76f467fe1b4d0dd3986d4ebbb584 Mon Sep 17 00:00:00 2001 From: Amir MOHAMMADI <amir.mohammadi@idiap.ch> Date: Thu, 24 May 2018 17:11:55 +0200 Subject: [PATCH] Handle folder checkpoints --- bob/learn/tensorflow/script/predict_bio.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/bob/learn/tensorflow/script/predict_bio.py b/bob/learn/tensorflow/script/predict_bio.py index 50097d23..c4b39634 100644 --- a/bob/learn/tensorflow/script/predict_bio.py +++ b/bob/learn/tensorflow/script/predict_bio.py @@ -13,6 +13,7 @@ from bob.extension.scripts.click_helper import ( from multiprocessing import Pool from collections import defaultdict import numpy as np +import tensorflow as tf from bob.io.base import create_directories_safe from bob.bio.base.utils import save from bob.bio.base.tools.grid import indices @@ -121,7 +122,9 @@ def predict_bio(estimator, database, biofiles, bio_predict_input_fn, `None`, returns all. checkpoint_path : str, optional Path of a specific checkpoint to predict. If `None`, the latest - checkpoint in `model_dir` is used. + checkpoint in `model_dir` is used. This can also be a folder which + contains a "checkpoint" file where the latest checkpoint from inside + this file will be used as checkpoint_path. multiple_samples : bool, optional If provided, it assumes that the db interface returns several samples from a biofile. This option can be used when you are working with @@ -216,6 +219,11 @@ def predict_bio(estimator, database, biofiles, bio_predict_input_fn, generator.output_shapes) if checkpoint_path: + if os.path.isdir(checkpoint_path): + ckpt = tf.train.get_checkpoint_state(estimator.model_dir) + if ckpt and ckpt.model_checkpoint_path: + checkpoint_path = ckpt.model_checkpoint_path + logger.info("Restoring the model from %s", checkpoint_path) predictions = estimator.predict( -- GitLab