test_cnn_scratch.py 3.99 KB
Newer Older
1
2
3
4
5
6
7
8
#!/usr/bin/env python
# vim: set fileencoding=utf-8 :
# @author: Tiago de Freitas Pereira <tiago.pereira@idiap.ch>
# @date: Thu 13 Oct 2016 13:35 CEST

import numpy
import bob.io.base
import os
Tiago de Freitas Pereira's avatar
Tiago de Freitas Pereira committed
9
10
from bob.learn.tensorflow.datashuffler import Memory, ImageAugmentation, ScaleFactor
from bob.learn.tensorflow.network import Embedding
11
from bob.learn.tensorflow.loss import BaseLoss
Tiago de Freitas Pereira's avatar
Tiago de Freitas Pereira committed
12
from bob.learn.tensorflow.trainers import Trainer, learning_rate
13
from bob.learn.tensorflow.utils import load_mnist
Tiago de Freitas Pereira's avatar
Tiago de Freitas Pereira committed
14
from bob.learn.tensorflow.layers import Conv2D, FullyConnected
15
16
17
18
19
20
21
22
23
import tensorflow as tf
import shutil

# Some unit tests that create networks on the fly.

# Hyper-parameters shared by the tests in this module.
batch_size = 16  # mini-batch size of the data shufflers used while training
validation_batch_size = 400  # number of samples per batch in validate_network
iterations = 300  # passed to the Trainer as ``iterations``
seed = 10  # seed of the Xavier weight initializer in scratch_network

# Scratch directory handed to the Trainer as ``temp_dir``; the test removes it.
directory = "./temp/cnn_scratch"

# Short alias for the TF-Slim layer API used to build the scratch network.
slim = tf.contrib.slim

def scratch_network():
    """Build a small randomly initialised CNN for 28x28x1 inputs.

    The graph is conv(10 filters, 3x3, ReLU) -> max-pool(4x4) -> flatten ->
    fully connected (10 units, no activation). Both weighted layers share a
    Xavier initializer seeded with the module-level ``seed``.

    Returns:
        ``(inputs, logits)`` where ``inputs`` is a dict holding the
        ``'data'`` placeholder (float32, ``[None, 28, 28, 1]``) and the
        ``'label'`` placeholder (int64, ``[None]``), and ``logits`` is the
        output tensor of the final fully connected layer.
    """
    data = tf.placeholder(tf.float32, shape=[None, 28, 28, 1], name="train_data")
    label = tf.placeholder(tf.int64, shape=[None], name="train_label")
    inputs = {'data': data, 'label': label}

    init = tf.contrib.layers.xavier_initializer(seed=seed)

    net = slim.conv2d(data, 10, [3, 3], activation_fn=tf.nn.relu, stride=1,
                      scope='conv1', weights_initializer=init)
    net = slim.max_pool2d(net, [4, 4], scope='pool1')
    net = slim.flatten(net, scope='flatten1')
    logits = slim.fully_connected(net, 10, activation_fn=None, scope='fc1',
                                  weights_initializer=init)
    return inputs, logits

def validate_network(embedding, validation_data, validation_labels):
    """Return the classification accuracy, in percent, of ``embedding``.

    One batch of ``validation_batch_size`` samples is fetched from a
    ``Memory`` shuffler over ``validation_data``/``validation_labels``
    (normalised with ``ScaleFactor``). The predicted class of each sample is
    the argmax of the embedding output along axis 1.

    Args:
        embedding: callable that maps a batch of images to per-class scores
            (the test passes an ``Embedding`` built on the scratch network).
        validation_data: images; presumably already reshaped to
            ``[N, 28, 28, 1]`` by the caller (the shuffler is built with
            that ``input_shape``) -- TODO confirm.
        validation_labels: class labels aligned with ``validation_data``.

    Returns:
        Accuracy as a float in ``[0, 100]``.
    """
    shuffler = Memory(validation_data, validation_labels,
                      input_shape=[28, 28, 1],
                      batch_size=validation_batch_size,
                      normalizer=ScaleFactor())
    data, labels = shuffler.get_batch()

    predicted = numpy.argmax(embedding(data), axis=1)
    n_correct = numpy.sum(predicted == labels)
    n_samples = predicted.shape[0]
    return 100. * n_correct / n_samples

def test_cnn_trainer_scratch():
    """Train the scratch CNN on MNIST and check it reaches > 80% accuracy.

    Training batches are augmented with ``ImageAugmentation``. Validation
    during training uses the validation split returned by ``load_mnist``,
    without augmentation. The final accuracy is measured on the same split
    with ``validate_network``. The trainer's temporary directory is removed
    even if training or the accuracy check fails.
    """
    train_data, train_labels, validation_data, validation_labels = load_mnist()
    # Both splits must be image-shaped before any shuffler is built over them.
    train_data = numpy.reshape(train_data, (train_data.shape[0], 28, 28, 1))
    validation_data = numpy.reshape(validation_data, (validation_data.shape[0], 28, 28, 1))

    # Creating datashufflers. Augmentation is applied to the training set
    # only; the validation shuffler must read the held-out split, not the
    # training data, or the validation metric is meaningless.
    data_augmentation = ImageAugmentation()
    train_data_shuffler = Memory(train_data, train_labels,
                                 input_shape=[28, 28, 1],
                                 batch_size=batch_size,
                                 data_augmentation=data_augmentation,
                                 normalizer=ScaleFactor())
    validation_data_shuffler = Memory(validation_data, validation_labels,
                                      input_shape=[28, 28, 1],
                                      batch_size=validation_batch_size,
                                      normalizer=ScaleFactor())

    # Create scratch network
    inputs, scratch = scratch_network()
    embedding = Embedding(inputs['data'], scratch)

    # Loss for the softmax
    loss = BaseLoss(tf.nn.sparse_softmax_cross_entropy_with_logits, tf.reduce_mean)

    trainer = None
    try:
        # One graph trainer
        trainer = Trainer(inputs=inputs,
                          graph=scratch,
                          iterations=iterations,
                          loss=loss,
                          analizer=None,
                          prefetch=False,
                          temp_dir=directory,
                          optimizer=tf.train.GradientDescentOptimizer(0.01),
                          learning_rate=learning_rate.constant(base_learning_rate=0.01,
                                                               name="constant_learning_rate"),
                          validation_snapshot=20)

        trainer.train(train_data_shuffler, validation_data_shuffler)

        accuracy = validate_network(embedding, validation_data, validation_labels)
        assert accuracy > 80
    finally:
        # Always clean up the scratch directory. ignore_errors keeps a missing
        # directory (e.g. the Trainer failed early) from masking the real error.
        shutil.rmtree(directory, ignore_errors=True)
        del trainer