test_cnn_scratch.py 3.72 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28
#!/usr/bin/env python
# vim: set fileencoding=utf-8 :
# @author: Tiago de Freitas Pereira <tiago.pereira@idiap.ch>
# @date: Thu 13 Oct 2016 13:35 CEST

import numpy
import bob.io.base
import os
from bob.learn.tensorflow.datashuffler import Memory, ImageAugmentation
from bob.learn.tensorflow.initialization import Xavier, Constant
from bob.learn.tensorflow.network import SequenceNetwork
from bob.learn.tensorflow.loss import BaseLoss
from bob.learn.tensorflow.trainers import Trainer
from bob.learn.tensorflow.util import load_mnist
from bob.learn.tensorflow.layers import Conv2D, FullyConnected, MaxPooling
import tensorflow as tf
import shutil

"""
Some unit tests that create networks on the fly
"""

# Number of samples per batch produced by the training datashuffler.
batch_size = 16
# Number of samples in the batch produced by the validation datashuffler.
validation_batch_size = 400
# Value handed to the Trainer as its `iterations` argument.
iterations = 50
# Seed given to the Xavier weight initializers of the scratch network.
seed = 10


29
def scratch_network():
    """Build the small CNN used by the tests in this module.

    The network is a ``SequenceNetwork`` whose default feature layer is
    ``"fc1"``.  Two layers are added, in this order:

    * ``conv1`` -- ``Conv2D`` with ``kernel_size=3``, 10 filters, ``tanh``
      activation and ``batch_norm=True``;
    * ``fc1`` -- ``FullyConnected`` with ``output_dim=10`` and no activation.

    Weights are initialized with ``Xavier`` using the module-level ``seed``
    and biases with ``Constant``; both initializers are created with
    ``use_gpu=False``.

    Returns:
        The assembled ``SequenceNetwork``.
    """
    scratch = SequenceNetwork(default_feature_layer="fc1")

    scratch.add(Conv2D(name="conv1", kernel_size=3,
                       filters=10,
                       activation=tf.nn.tanh,
                       weights_initialization=Xavier(seed=seed, use_gpu=False),
                       bias_initialization=Constant(use_gpu=False),
                       batch_norm=True))

    scratch.add(FullyConnected(name="fc1", output_dim=10,
                               activation=None,
                               weights_initialization=Xavier(seed=seed, use_gpu=False),
                               bias_initialization=Constant(use_gpu=False)))

    return scratch
45 46


47
def validate_network(validation_data, validation_labels, directory):
    """Load the saved model from ``directory`` and score one validation batch.

    Args:
        validation_data: Validation samples given to the ``Memory``
            datashuffler, which is declared with ``input_shape=[28, 28, 1]``.
        validation_labels: Labels that go with ``validation_data``.
        directory: Directory holding the ``model.hdf5`` file to load.

    Returns:
        Accuracy in percent, computed on a single batch of
        ``validation_batch_size`` samples.
    """
    validation_data_shuffler = Memory(validation_data, validation_labels,
                                      input_shape=[28, 28, 1],
                                      batch_size=validation_batch_size)

    # The batch dimension is taken from validation_batch_size instead of a
    # second hard-coded 400, so the shape used to load the network always
    # matches the batch the shuffler produces.
    validation_shape = [validation_batch_size, 28, 28, 1]
    path = os.path.join(directory, "model.hdf5")

    with tf.Session() as session:
        scratch = SequenceNetwork()
        scratch.load(bob.io.base.HDF5File(path),
                     shape=validation_shape, session=session)

        [data, labels] = validation_data_shuffler.get_batch()
        predictions = scratch(data, session=session)

        # Predicted class = index of the largest output along axis 1.
        correct = numpy.sum(numpy.argmax(predictions, 1) == labels)
        accuracy = 100. * correct / predictions.shape[0]

    return accuracy


def test_cnn_trainer_scratch():
    """Train the scratch CNN on MNIST and check its validation accuracy.

    The data returned by ``load_mnist()`` is reshaped to
    ``(n_samples, 28, 28, 1)``, the network from ``scratch_network()`` is
    trained through a ``Trainer`` that receives ``./temp/cnn`` as its
    ``temp_dir``, and ``validate_network()`` then reads ``model.hdf5`` from
    that same directory.  The test passes when the reported accuracy is
    above 80 percent.
    """
    train_data, train_labels, validation_data, validation_labels = load_mnist()
    train_data = numpy.reshape(train_data, (train_data.shape[0], 28, 28, 1))
    validation_data = numpy.reshape(validation_data, (validation_data.shape[0], 28, 28, 1))

    # Creating datashufflers
    data_augmentation = ImageAugmentation()
    train_data_shuffler = Memory(train_data, train_labels,
                                 input_shape=[28, 28, 1],
                                 batch_size=batch_size,
                                 data_augmentation=data_augmentation)

    directory = "./temp/cnn"

    # Create scratch network
    scratch = scratch_network()

    # Loss for the softmax
    loss = BaseLoss(tf.nn.sparse_softmax_cross_entropy_with_logits, tf.reduce_mean)

    # One graph trainer
    trainer = Trainer(architecture=scratch,
                      loss=loss,
                      iterations=iterations,
                      analizer=None,
                      prefetch=False,
                      temp_dir=directory)

    try:
        trainer.train(train_data_shuffler)
        # No interactive debugger breakpoint here: the test has to run
        # unattended (a set_trace() call blocks waiting for input).
        accuracy = validate_network(validation_data, validation_labels, directory)
    finally:
        # Remove the artifacts written under the temp directory whether or
        # not training/validation succeeded; ignore_errors keeps a cleanup
        # problem from hiding the original failure.
        shutil.rmtree(directory, ignore_errors=True)

    assert accuracy > 80

105