Skip to content
Snippets Groups Projects

Resolve "Adopt to the Estimators API"

Merged Tiago de Freitas Pereira requested to merge 40-adopt-to-the-estimators-api into master
9 files
+ 423
124
Compare changes
  • Side-by-side
  • Inline

Files

@@ -17,16 +17,16 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from bob.learn.tensorflow.utils.reproducible import session_conf
# by default create reproducible nets:
from bob.learn.tensorflow.utils.reproducible import run_config
# utils.reproducible import run_config
import tensorflow as tf
from bob.db.mnist import Database
model_dir = '/tmp/mnist_model'
train_tfrecords = ['/tmp/mnist_data/train.tfrecords']
eval_tfrecords = ['/tmp/mnist_data/test.tfrecords']
# by default create reproducible nets:
run_config = tf.estimator.RunConfig()
run_config = run_config.replace(session_config=session_conf)
run_config = run_config.replace(keep_checkpoint_max=10**3)
run_config = run_config.replace(save_checkpoints_secs=60)
@@ -39,22 +39,27 @@ def input_fn(mode, batch_size=1):
features = tf.parse_single_example(
serialized_example,
features={
'image_raw': tf.FixedLenFeature([], tf.string),
'data': tf.FixedLenFeature([], tf.string),
'label': tf.FixedLenFeature([], tf.int64),
'key': tf.FixedLenFeature([], tf.string),
})
image = tf.decode_raw(features['image_raw'], tf.uint8)
image = tf.decode_raw(features['data'], tf.uint8)
image.set_shape([28 * 28])
# Normalize the values of the image from the range
# [0, 255] to [-0.5, 0.5]
image = tf.cast(image, tf.float32) / 255 - 0.5
label = tf.cast(features['label'], tf.int32)
return image, tf.one_hot(label, 10)
key = tf.cast(features['key'], tf.string)
return image, tf.one_hot(label, 10), key
if mode == tf.estimator.ModeKeys.TRAIN:
tfrecords_files = train_tfrecords
elif mode == tf.estimator.ModeKeys.EVAL:
tfrecords_files = eval_tfrecords
else:
assert mode == tf.estimator.ModeKeys.EVAL, 'invalid mode'
assert mode == tf.estimator.ModeKeys.PREDICT, 'invalid mode'
tfrecords_files = eval_tfrecords
for tfrecords_file in tfrecords_files:
@@ -73,9 +78,9 @@ def input_fn(mode, batch_size=1):
dataset = dataset.map(
example_parser, num_threads=1, output_buffer_size=batch_size)
dataset = dataset.batch(batch_size)
images, labels = dataset.make_one_shot_iterator().get_next()
images, labels, keys = dataset.make_one_shot_iterator().get_next()
return images, labels
return {'images': images, 'keys': keys}, labels
def train_input_fn():
@@ -86,6 +91,10 @@ def eval_input_fn():
return input_fn(tf.estimator.ModeKeys.EVAL)
def predict_input_fn():
return input_fn(tf.estimator.ModeKeys.PREDICT)
def mnist_model(inputs, mode):
"""Takes the MNIST inputs and mode and outputs a tensor of logits."""
# Input Layer
@@ -164,13 +173,16 @@ def mnist_model(inputs, mode):
return logits
def model_fn(features, labels, mode):
def model_fn(features, labels=None, mode=tf.estimator.ModeKeys.TRAIN):
"""Model function for MNIST."""
keys = features['keys']
features = features['images']
logits = mnist_model(features, mode)
predictions = {
'classes': tf.argmax(input=logits, axis=1),
'probabilities': tf.nn.softmax(logits, name='softmax_tensor')
'probabilities': tf.nn.softmax(logits, name='softmax_tensor'),
'keys': keys,
}
if mode == tf.estimator.ModeKeys.PREDICT:
@@ -202,3 +214,22 @@ def model_fn(features, labels, mode):
loss=loss,
train_op=train_op,
eval_metric_ops=metrics)
estimator = tf.estimator.Estimator(model_fn=model_fn, model_dir=model_dir,
params=None, config=run_config)
output = train_tfrecords[0]
db = Database()
data, labels = db.data(groups='train')
# output = eval_tfrecords[0]
# db = Database()
# data, labels = db.data(groups='test')
samples = zip(data, labels, (str(i) for i in range(len(data))))
def reader(sample):
return sample
Loading