#!/usr/bin/env python
# vim: set fileencoding=utf-8 :
# @author: Tiago de Freitas Pereira <tiago.pereira@idiap.ch>
# @date: Wed 11 May 2016 09:39:36 CEST 


"""
Simple script that trains MNIST with LENET using Tensor flow

Usage:
11
  train_mnist.py [--batch-size=<arg> --validation-batch-size=<arg> --iterations=<arg> --validation-interval=<arg> --use-gpu]
Tiago de Freitas Pereira's avatar
Scratch  
Tiago de Freitas Pereira committed
12 13 14 15
  train_mnist.py -h | --help
Options:
  -h --help     Show this screen.
  --batch-size=<arg>  [default: 1]
16
  --validation-batch-size=<arg>   [default:128]
Tiago de Freitas Pereira's avatar
Scratch  
Tiago de Freitas Pereira committed
17 18 19 20 21 22 23 24
  --iterations=<arg>  [default: 30000]
  --validation-interval=<arg>  [default: 100]  
"""

from docopt import docopt
import numpy
import tensorflow as tf

import bob.io.base
from bob.learn.tensorflow.datashuffler import Memory, SiameseMemory, TripletMemory
from bob.learn.tensorflow.network import Lenet, MLP, Dummy, Chopra
from bob.learn.tensorflow.trainers import Trainer
from bob.learn.tensorflow.loss import BaseLoss

from .. import util
from ..analyzers import ExperimentAnalizer, SoftmaxAnalizer

SEED = 10

def main():
    args = docopt(__doc__, version='Mnist training with TensorFlow')

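    # Hyper-parameters parsed from the command line.  Note that
    # VALIDATION_TEST and USE_GPU are read here but not forwarded to the
    # Trainer below.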
    BATCH_SIZE = int(args['--batch-size'])
    VALIDATION_BATCH_SIZE = int(args['--validation-batch-size'])
    ITERATIONS = int(args['--iterations'])
    VALIDATION_TEST = int(args['--validation-interval'])
    USE_GPU = args['--use-gpu']
    perc_train = 0.9

    mnist = True

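    # Load MNIST from a local checkout of bob.db.mnist; util.load_mnist
    # returns a train/validation split of the images and labels.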
    train_data, train_labels, validation_data, validation_labels = \
        util.load_mnist(data_dir="./src/bob.db.mnist/bob/db/mnist/")

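    # The raw samples are flat 784-dimensional vectors; reshape them into
    # 28x28 single-channel images (NHWC layout).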
    train_data = numpy.reshape(train_data, (train_data.shape[0], 28, 28, 1))
    validation_data = numpy.reshape(validation_data, (validation_data.shape[0], 28, 28, 1))


    # Creating datashufflers
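    # Memory holds the whole dataset in RAM and serves random mini-batches
    # through get_batch().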
    train_data_shuffler = Memory(train_data, train_labels,
                                 input_shape=[28, 28, 1],
                                 batch_size=BATCH_SIZE)

    validation_data_shuffler = Memory(validation_data, validation_labels,
                                      input_shape=[28, 28, 1],
                                      batch_size=VALIDATION_BATCH_SIZE)

    # Preparing the architecture
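    # Toggle between the CNN below and the MLP baseline in the else branch.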
    cnn = True
    if cnn:
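        # A small CNN (the Chopra architecture shipped with
        # bob.learn.tensorflow) emitting 10 logits, one per digit class;
        # batch normalization disabled.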
        architecture = Chopra(seed=SEED, fc1_output=10, batch_norm=False)
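        # Loss: per-sample softmax cross-entropy, averaged over the batch.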
        loss = BaseLoss(tf.nn.sparse_softmax_cross_entropy_with_logits, tf.reduce_mean)
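        # The Trainer drives the optimization loop; checkpoints and logs are
        # written under temp_dir.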
        trainer = Trainer(architecture=architecture,
                          loss=loss,
                          iterations=ITERATIONS,
                          prefetch=False, temp_dir="./temp/cnn/no-batch-norm")

        #prefetch = False, temp_dir = "./temp/cnn/batch-norm-2convs-all-relu")

        trainer.train(train_data_shuffler, validation_data_shuffler)
        #trainer.train(train_data_shuffler)
    else:
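        # Fallback baseline: a small MLP with two hidden layers (15 and 20
        # units) and 10 outputs.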
        mlp = MLP(10, hidden_layers=[15, 20])
        loss = BaseLoss(tf.nn.sparse_softmax_cross_entropy_with_logits, tf.reduce_mean)
        trainer = Trainer(architecture=mlp, loss=loss, iterations=ITERATIONS, temp_dir="./LOGS/dnn")
        trainer.train(train_data_shuffler, validation_data_shuffler)

    # Loading
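    # The block below shows how the trained model could be restored from its
    # HDF5 checkpoint and scored on a fresh batch; kept commented out for
    # reference.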
    #test_data_shuffler = Memory(validation_data, validation_labels,
    #                            input_shape=[28, 28, 1],
    #                            batch_size=400)

    #with tf.Session() as session:
        #new_net = Chopra(seed=SEED, fc1_output=10)
        #new_net.load(bob.io.base.HDF5File("./temp/cnn/model.hdf5"), shape=[400, 28, 28, 1], session=session)

        #[data, labels] = test_data_shuffler.get_batch()
        #print(new_net(data, session))
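    # A minimal sketch of how the restored network could report accuracy
    # (assumes the same 2016-era bob.learn.tensorflow API used above; the
    # argmax/accuracy lines are an addition, not part of the original script):
    #
    #   with tf.Session() as session:
    #       new_net = Chopra(seed=SEED, fc1_output=10)
    #       new_net.load(bob.io.base.HDF5File("./temp/cnn/model.hdf5"),
    #                    shape=[400, 28, 28, 1], session=session)
    #       data, labels = test_data_shuffler.get_batch()
    #       scores = new_net(data, session)              # (400, 10) logits
    #       predictions = numpy.argmax(scores, axis=1)   # predicted digit
    #       print((predictions == labels).mean())        # batch accuracy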