test_cnn_scratch.py 2.88 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
#!/usr/bin/env python
# vim: set fileencoding=utf-8 :
# @author: Tiago de Freitas Pereira <tiago.pereira@idiap.ch>
# @date: Thu 13 Oct 2016 13:35 CEST

import numpy
import bob.io.base
import os
from bob.learn.tensorflow.datashuffler import Memory, ImageAugmentation
from bob.learn.tensorflow.network import SequenceNetwork
from bob.learn.tensorflow.loss import BaseLoss
from bob.learn.tensorflow.trainers import Trainer
13
from bob.learn.tensorflow.utils import load_mnist
Tiago de Freitas Pereira's avatar
Tiago de Freitas Pereira committed
14
from bob.learn.tensorflow.layers import Conv2D, FullyConnected
15
16
17
18
19
20
21
22
23
24
25
import tensorflow as tf
import shutil

"""
Some unit tests that create networks on the fly
"""

# Hyper-parameters shared by the helpers and the test below.
batch_size = 16              # samples per batch for the training shuffler
validation_batch_size = 400  # samples per batch drawn by validate_network
iterations = 50              # training iterations handed to Trainer
seed = 10                    # NOTE(review): not referenced in the visible code
# Scratch directory passed to Trainer as temp_dir; the test deletes it at the end.
directory = "./temp/cnn_scratch"
27
28


29
def scratch_network():
    """Assemble a small, untrained CNN used by the scratch-training test.

    Layers, in order:
      * ``conv1`` -- Conv2D, kernel_size=3, 10 filters, tanh activation;
      * ``fc1``   -- FullyConnected, output_dim=10, no activation.
    Batch normalisation is disabled on both layers, and ``fc1`` is the
    network's default feature layer.

    Returns:
        SequenceNetwork: the freshly built network (weights not yet trained).
    """
    network = SequenceNetwork(default_feature_layer="fc1")

    conv = Conv2D(name="conv1",
                  kernel_size=3,
                  filters=10,
                  activation=tf.nn.tanh,
                  batch_norm=False)
    network.add(conv)

    # Left linear (activation=None): the test feeds this output to a
    # softmax cross-entropy loss that works on logits.
    fc = FullyConnected(name="fc1",
                        output_dim=10,
                        activation=None,
                        batch_norm=False)
    network.add(fc)

    return network
42
43


44
def validate_network(validation_data, validation_labels, network):
    """Score ``network`` on one validation batch and return accuracy in percent.

    A ``Memory`` shuffler is built over the validation set with
    ``input_shape=[28, 28, 1]`` and ``batch_size=validation_batch_size``.
    A single batch is drawn from it, and the network's predictions are compared
    element-wise with the labels. Whether that batch is a random subset depends
    on ``Memory.get_batch``, which is not visible here.

    Args:
        validation_data: validation images, presumably already shaped
            (N, 28, 28, 1) by the caller -- TODO confirm.
        validation_labels: integer class labels matching ``validation_data``.
        network: an object exposing ``predict(data)``.

    Returns:
        Percentage (0-100) of correctly classified samples in the batch.
    """
    shuffler = Memory(validation_data, validation_labels,
                      input_shape=[28, 28, 1],
                      batch_size=validation_batch_size)
    data, labels = shuffler.get_batch()

    predictions = network.predict(data)
    n_correct = numpy.sum(predictions == labels)
    return 100. * n_correct / predictions.shape[0]


def test_cnn_trainer_scratch():
    """Train the scratch CNN on MNIST and require > 80 % validation accuracy.

    The training set is served by an augmenting ``Memory`` shuffler. The network
    from ``scratch_network`` is trained for ``iterations`` steps with a sparse
    softmax cross-entropy loss, then scored with ``validate_network``.

    The trainer's ``temp_dir`` (``directory``) is removed in a ``finally``
    clause. Failed training or a failed accuracy check therefore does not leave
    the directory behind for the next run.
    """
    train_data, train_labels, validation_data, validation_labels = load_mnist()
    # Reshape to (N, 28, 28, 1) to match the shufflers' input_shape.
    train_data = numpy.reshape(train_data, (train_data.shape[0], 28, 28, 1))

    # Creating datashufflers
    data_augmentation = ImageAugmentation()
    train_data_shuffler = Memory(train_data, train_labels,
                                 input_shape=[28, 28, 1],
                                 batch_size=batch_size,
                                 data_augmentation=data_augmentation)
    validation_data = numpy.reshape(validation_data,
                                    (validation_data.shape[0], 28, 28, 1))

    # Create scratch network
    scratch = scratch_network()

    # Loss for the softmax
    loss = BaseLoss(tf.nn.sparse_softmax_cross_entropy_with_logits, tf.reduce_mean)

    # One graph trainer
    trainer = Trainer(architecture=scratch,
                      loss=loss,
                      iterations=iterations,
                      analizer=None,
                      prefetch=False,
                      temp_dir=directory)

    try:
        trainer.train(train_data_shuffler)
        accuracy = validate_network(validation_data, validation_labels, scratch)
        assert accuracy > 80
    finally:
        # Always clean up, even on failure. ignore_errors=True so a missing
        # directory cannot replace the real failure with a FileNotFoundError.
        shutil.rmtree(directory, ignore_errors=True)

    del trainer