Skip to content
Snippets Groups Projects
Commit 12964af5 authored by Amir MOHAMMADI's avatar Amir MOHAMMADI
Browse files

Handle folder checkpoints

parent 55d973ee
No related branches found
No related tags found
1 merge request!52Implement model saving in bob tf eval. Fixes #54
Pipeline #
...@@ -13,6 +13,7 @@ from bob.extension.scripts.click_helper import ( ...@@ -13,6 +13,7 @@ from bob.extension.scripts.click_helper import (
from multiprocessing import Pool from multiprocessing import Pool
from collections import defaultdict from collections import defaultdict
import numpy as np import numpy as np
import tensorflow as tf
from bob.io.base import create_directories_safe from bob.io.base import create_directories_safe
from bob.bio.base.utils import save from bob.bio.base.utils import save
from bob.bio.base.tools.grid import indices from bob.bio.base.tools.grid import indices
...@@ -121,7 +122,9 @@ def predict_bio(estimator, database, biofiles, bio_predict_input_fn, ...@@ -121,7 +122,9 @@ def predict_bio(estimator, database, biofiles, bio_predict_input_fn,
`None`, returns all. `None`, returns all.
checkpoint_path : str, optional checkpoint_path : str, optional
Path of a specific checkpoint to predict. If `None`, the latest Path of a specific checkpoint to predict. If `None`, the latest
checkpoint in `model_dir` is used. checkpoint in `model_dir` is used. This can also be a folder which
contains a "checkpoint" file where the latest checkpoint from inside
this file will be used as checkpoint_path.
multiple_samples : bool, optional multiple_samples : bool, optional
If provided, it assumes that the db interface returns several samples If provided, it assumes that the db interface returns several samples
from a biofile. This option can be used when you are working with from a biofile. This option can be used when you are working with
...@@ -216,6 +219,11 @@ def predict_bio(estimator, database, biofiles, bio_predict_input_fn, ...@@ -216,6 +219,11 @@ def predict_bio(estimator, database, biofiles, bio_predict_input_fn,
generator.output_shapes) generator.output_shapes)
if checkpoint_path: if checkpoint_path:
if os.path.isdir(checkpoint_path):
ckpt = tf.train.get_checkpoint_state(estimator.model_dir)
if ckpt and ckpt.model_checkpoint_path:
checkpoint_path = ckpt.model_checkpoint_path
logger.info("Restoring the model from %s", checkpoint_path) logger.info("Restoring the model from %s", checkpoint_path)
predictions = estimator.predict( predictions = estimator.predict(
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment