From 12964af50ced76f467fe1b4d0dd3986d4ebbb584 Mon Sep 17 00:00:00 2001
From: Amir MOHAMMADI <amir.mohammadi@idiap.ch>
Date: Thu, 24 May 2018 17:11:55 +0200
Subject: [PATCH] Handle folder checkpoints

---
 bob/learn/tensorflow/script/predict_bio.py | 10 +++++++++-
 1 file changed, 9 insertions(+), 1 deletion(-)

diff --git a/bob/learn/tensorflow/script/predict_bio.py b/bob/learn/tensorflow/script/predict_bio.py
index 50097d23..c4b39634 100644
--- a/bob/learn/tensorflow/script/predict_bio.py
+++ b/bob/learn/tensorflow/script/predict_bio.py
@@ -13,6 +13,7 @@ from bob.extension.scripts.click_helper import (
 from multiprocessing import Pool
 from collections import defaultdict
 import numpy as np
+import tensorflow as tf
 from bob.io.base import create_directories_safe
 from bob.bio.base.utils import save
 from bob.bio.base.tools.grid import indices
@@ -121,7 +122,9 @@ def predict_bio(estimator, database, biofiles, bio_predict_input_fn,
         `None`, returns all.
     checkpoint_path : str, optional
         Path of a specific checkpoint to predict. If `None`, the latest
-        checkpoint in `model_dir` is used.
+        checkpoint in `model_dir` is used. This can also be a folder which
+        contains a "checkpoint" file where the latest checkpoint from inside
+        this file will be used as checkpoint_path.
     multiple_samples : bool, optional
         If provided, it assumes that the db interface returns several samples
         from a biofile. This option can be used when you are working with
@@ -216,6 +219,11 @@ def predict_bio(estimator, database, biofiles, bio_predict_input_fn,
                                             generator.output_shapes)
 
     if checkpoint_path:
+        if os.path.isdir(checkpoint_path):
+            ckpt = tf.train.get_checkpoint_state(estimator.model_dir)
+            if ckpt and ckpt.model_checkpoint_path:
+                checkpoint_path = ckpt.model_checkpoint_path
+
         logger.info("Restoring the model from %s", checkpoint_path)
 
     predictions = estimator.predict(
-- 
GitLab