Commit 12964af5 authored by Amir MOHAMMADI's avatar Amir MOHAMMADI
Browse files

Handle folder checkpoints

parent 55d973ee
Pipeline #20493 passed with stage
in 41 minutes and 26 seconds
...@@ -13,6 +13,7 @@ from bob.extension.scripts.click_helper import ( ...@@ -13,6 +13,7 @@ from bob.extension.scripts.click_helper import (
from multiprocessing import Pool from multiprocessing import Pool
from collections import defaultdict from collections import defaultdict
import numpy as np import numpy as np
import tensorflow as tf
from bob.io.base import create_directories_safe from bob.io.base import create_directories_safe
from bob.bio.base.utils import save from bob.bio.base.utils import save
from bob.bio.base.tools.grid import indices from bob.bio.base.tools.grid import indices
...@@ -121,7 +122,9 @@ def predict_bio(estimator, database, biofiles, bio_predict_input_fn, ...@@ -121,7 +122,9 @@ def predict_bio(estimator, database, biofiles, bio_predict_input_fn,
`None`, returns all. `None`, returns all.
checkpoint_path : str, optional checkpoint_path : str, optional
Path of a specific checkpoint to predict. If `None`, the latest Path of a specific checkpoint to predict. If `None`, the latest
checkpoint in `model_dir` is used. checkpoint in `model_dir` is used. This can also be a folder which
contains a "checkpoint" file where the latest checkpoint from inside
this file will be used as checkpoint_path.
multiple_samples : bool, optional multiple_samples : bool, optional
If provided, it assumes that the db interface returns several samples If provided, it assumes that the db interface returns several samples
from a biofile. This option can be used when you are working with from a biofile. This option can be used when you are working with
...@@ -216,6 +219,11 @@ def predict_bio(estimator, database, biofiles, bio_predict_input_fn, ...@@ -216,6 +219,11 @@ def predict_bio(estimator, database, biofiles, bio_predict_input_fn,
generator.output_shapes) generator.output_shapes)
if checkpoint_path: if checkpoint_path:
if os.path.isdir(checkpoint_path):
ckpt = tf.train.get_checkpoint_state(estimator.model_dir)
if ckpt and ckpt.model_checkpoint_path:
checkpoint_path = ckpt.model_checkpoint_path
logger.info("Restoring the model from %s", checkpoint_path) logger.info("Restoring the model from %s", checkpoint_path)
predictions = estimator.predict( predictions = estimator.predict(
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment