Commit 12964af5 authored by Amir MOHAMMADI's avatar Amir MOHAMMADI
Browse files

Handle folder checkpoints

parent 55d973ee
Pipeline #20493 passed with stage
in 41 minutes and 26 seconds
......@@ -13,6 +13,7 @@ from bob.extension.scripts.click_helper import (
from multiprocessing import Pool
from collections import defaultdict
import numpy as np
import tensorflow as tf
from bob.io.base import create_directories_safe
from bob.bio.base.utils import save
from bob.bio.base.tools.grid import indices
......@@ -121,7 +122,9 @@ def predict_bio(estimator, database, biofiles, bio_predict_input_fn,
`None`, returns all.
checkpoint_path : str, optional
Path of a specific checkpoint to predict. If `None`, the latest
checkpoint in `model_dir` is used.
checkpoint in `model_dir` is used. This can also be a folder which
contains a "checkpoint" file where the latest checkpoint from inside
this file will be used as checkpoint_path.
multiple_samples : bool, optional
If provided, it assumes that the db interface returns several samples
from a biofile. This option can be used when you are working with
......@@ -216,6 +219,11 @@ def predict_bio(estimator, database, biofiles, bio_predict_input_fn,
generator.output_shapes)
if checkpoint_path:
if os.path.isdir(checkpoint_path):
ckpt = tf.train.get_checkpoint_state(estimator.model_dir)
if ckpt and ckpt.model_checkpoint_path:
checkpoint_path = ckpt.model_checkpoint_path
logger.info("Restoring the model from %s", checkpoint_path)
predictions = estimator.predict(
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment