# test_estimator_scripts.py
from __future__ import print_function
import os
from tempfile import mkdtemp
import shutil
import logging
logging.getLogger("tensorflow").setLevel(logging.WARNING)
from bob.io.base.test_utils import datafile

from bob.learn.tensorflow.script.db_to_tfrecords import main as tfrecords
from bob.learn.tensorflow.script.train_generic import main as train_generic
from bob.learn.tensorflow.script.eval_generic import main as eval_generic

# Config for the db_to_tfrecords script. Its text contains a ``TEST_DIR``
# placeholder that _create_tfrecord replaces with the real test directory.
dummy_tfrecord_config = datafile('dummy_verify_config.py', __name__)

# Template of the config file passed to the train_generic and eval_generic
# scripts. It is rendered with ``%`` formatting using a mapping that provides
# the keys ``model_dir`` and ``tfrecord_filenames``; any literal percent sign
# added to this template must therefore be written as ``%%``.
CONFIG = '''
import tensorflow as tf
from bob.learn.tensorflow.utils.reproducible import run_config
from bob.learn.tensorflow.dataset.tfrecords import (shuffle_data_and_labels,
                                                    batch_data_and_labels)

model_dir = "%(model_dir)s"
tfrecord_filenames = ['%(tfrecord_filenames)s']
data_shape = (1, 112, 92)  # size of atnt images
data_type = tf.uint8
batch_size = 2
epochs = 2
learning_rate = 0.00001
run_once = True


def train_input_fn():
    return shuffle_data_and_labels(tfrecord_filenames, data_shape, data_type,
                                   batch_size, epochs=epochs)


def eval_input_fn():
    return batch_data_and_labels(tfrecord_filenames, data_shape, data_type,
                                 batch_size, epochs=1)


def architecture(images):
    images = tf.cast(images, tf.float32)
    logits = tf.reshape(images, [-1, 92 * 112])
    logits = tf.layers.dense(inputs=logits, units=20,
                             activation=tf.nn.relu)
    return logits


def model_fn(features, labels, mode, params, config):
    key = features['key']
    features = features['data']

    logits = architecture(features)

    predictions = {
        # Generate predictions (for PREDICT and EVAL mode)
        "classes": tf.argmax(input=logits, axis=1),
        # Add `softmax_tensor` to the graph. It is used for PREDICT and by the
        # `logging_hook`.
        "probabilities": tf.nn.softmax(logits, name="softmax_tensor"),
        "key": key,
    }
    if mode == tf.estimator.ModeKeys.PREDICT:
        return tf.estimator.EstimatorSpec(mode=mode, predictions=predictions)

    # Calculate Loss (for both TRAIN and EVAL modes)
    loss = tf.losses.sparse_softmax_cross_entropy(
        logits=logits, labels=labels)

    accuracy = tf.metrics.accuracy(
        labels=labels, predictions=predictions["classes"])
    metrics = {'accuracy': accuracy}

    # Configure the training op
    if mode == tf.estimator.ModeKeys.TRAIN:
        optimizer = tf.train.GradientDescentOptimizer(
            learning_rate=learning_rate)
        train_op = optimizer.minimize(
            loss=loss, global_step=tf.train.get_or_create_global_step())
        # Log accuracy and loss
        with tf.name_scope('train_metrics'):
            tf.summary.scalar('accuracy', accuracy[1])
            tf.summary.scalar('loss', loss)
    else:
        train_op = None

    return tf.estimator.EstimatorSpec(
        mode=mode,
        predictions=predictions,
        loss=loss,
        train_op=train_op,
        eval_metric_ops=metrics)


estimator = tf.estimator.Estimator(model_fn=model_fn, model_dir=model_dir,
                                   config=run_config)
'''


def _create_tfrecord(test_dir):
    """Build a dummy tfrecord file under ``test_dir`` and return its path.

    The bundled ``dummy_tfrecord_config`` is copied to
    ``<test_dir>/tfrecordconfig.py`` with every ``TEST_DIR`` occurrence
    replaced by ``test_dir``. The copy is then run through the
    ``db_to_tfrecords`` script.

    Args:
        test_dir: Directory in which the config copy is written.

    Returns:
        The path ``<test_dir>/dev.tfrecords``. This assumes the config makes
        the script write that file there; existence is not checked here.
    """
    config_path = os.path.join(test_dir, 'tfrecordconfig.py')
    with open(dummy_tfrecord_config) as src, open(config_path, 'w') as dst:
        dst.write(src.read().replace('TEST_DIR', test_dir))
    tfrecords([config_path])
    return os.path.join(test_dir, 'dev.tfrecords')


def _write_config(tmpdir, filename, model_dir, dummy_tfrecord):
    """Render ``CONFIG`` and write it to ``<tmpdir>/<filename>``.

    Args:
        tmpdir: Directory that receives the config file.
        filename: Base name of the config file to write.
        model_dir: Value substituted for ``%(model_dir)s`` in the template.
        dummy_tfrecord: Value substituted for ``%(tfrecord_filenames)s``.

    Returns:
        The path of the written config file.
    """
    config = CONFIG % {'model_dir': model_dir,
                       'tfrecord_filenames': dummy_tfrecord}
    config_path = os.path.join(tmpdir, filename)
    with open(config_path, 'w') as f:
        f.write(config)
    return config_path


def _create_checkpoint(tmpdir, model_dir, dummy_tfrecord):
    """Run the train_generic script on a config pointing at ``model_dir``.

    The rendered config is written to ``<tmpdir>/train_config.py``.
    """
    config_path = _write_config(tmpdir, 'train_config.py', model_dir,
                                dummy_tfrecord)
    train_generic([config_path])


def _eval(tmpdir, model_dir, dummy_tfrecord):
    """Run the eval_generic script on a config pointing at ``model_dir``.

    The rendered config is written to ``<tmpdir>/eval_config.py``.
    """
    config_path = _write_config(tmpdir, 'eval_config.py', model_dir,
                                dummy_tfrecord)
    eval_generic([config_path])


def test_eval():
    """End-to-end check: build a dummy tfrecord, train on it, then evaluate.

    All artifacts live under a fresh temporary directory. That directory is
    removed afterwards, even if one of the steps fails.
    """
    tmpdir = mkdtemp(prefix='bob_')
    try:
        model_dir = os.path.join(tmpdir, 'model_dir')
        eval_dir = os.path.join(model_dir, 'eval')

        print('\nCreating a dummy tfrecord')
        dummy_tfrecord = _create_tfrecord(tmpdir)

        print('Training a dummy network')
        _create_checkpoint(tmpdir, model_dir, dummy_tfrecord)

        print('Evaluating a dummy network')
        _eval(tmpdir, model_dir, dummy_tfrecord)

        # The evaluation step must leave an 'evaluated' file under
        # <model_dir>/eval.
        evaluated_path = os.path.join(eval_dir, 'evaluated')
        assert os.path.exists(evaluated_path), evaluated_path
        with open(evaluated_path) as f:
            doc = f.read()

        # NOTE(review): '1' and '200' presumably correspond to checkpoint
        # step numbers recorded by eval_generic -- confirm against its
        # output format.
        assert '1' in doc, doc
        assert '200' in doc, doc
    finally:
        # Best-effort cleanup: a failure here must not mask the test result.
        shutil.rmtree(tmpdir, ignore_errors=True)