train_mnist.py 2.33 KB
Newer Older
Tiago de Freitas Pereira's avatar
Scratch  
Tiago de Freitas Pereira committed
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23
#!/usr/bin/env python
# vim: set fileencoding=utf-8 :
# @author: Tiago de Freitas Pereira <tiago.pereira@idiap.ch>
# @date: Wed 11 May 2016 09:39:36 CEST 


"""
Simple script that trains MNIST with LENET using Tensor flow

Usage:
  train_mnist.py [--batch-size=<arg> --iterations=<arg> --validation-interval=<arg> --use-gpu]
  train_mnist.py -h | --help
Options:
  -h --help     Show this screen.
  --batch-size=<arg>  [default: 1]
  --iterations=<arg>  [default: 30000]
  --validation-interval=<arg>  [default: 100]  
"""

from docopt import docopt
import tensorflow as tf
from .. import util
SEED = 10
24
from bob.learn.tensorflow.data import MemoryDataShuffler, TextDataShuffler
25 26
from bob.learn.tensorflow.network import Lenet
from bob.learn.tensorflow.trainers import Trainer
27
from bob.learn.tensorflow.loss import BaseLoss
28
import bob.db.mobio
29

30
import numpy
Tiago de Freitas Pereira's avatar
Scratch  
Tiago de Freitas Pereira committed
31 32 33 34 35 36 37 38 39 40 41

def main(argv=None):
    """Train a LeNet network on preprocessed MOBIO face images.

    Despite the module name, the data source used here is the MOBIO
    database (protocol ``"male"``), read from preprocessed ``.hdf5`` files,
    not MNIST.

    Command-line options are parsed with docopt from the module docstring.

    :param argv: Optional list of command-line arguments to parse instead of
        the real command line; forwarded unchanged to ``docopt``. ``None``
        (the default) keeps the original behaviour of reading ``sys.argv``.
    """
    args = docopt(__doc__, argv=argv, version='Mnist training with TensorFlow')

    batch_size = int(args['--batch-size'])
    iterations = int(args['--iterations'])
    # Parsed so that an invalid value still fails fast at start-up, but it is
    # not forwarded anywhere yet.
    # TODO(review): pass to Trainer once its validation API is confirmed.
    validation_interval = int(args['--validation-interval'])  # noqa: F841
    # NOTE(review): ``--use-gpu`` is accepted by the CLI but currently ignored.

    # Site-specific location of the preprocessed MOBIO images; adjust for
    # your environment.
    preprocessed_dir = ("/remote/lustre/2/temp/tpereira/FACEREC_EXPERIMENTS/"
                        "mobio_male/lda/preprocessed")

    # Build the list of sample files and their labels from the database.
    db = bob.db.mobio.Database()
    objects = db.objects(protocol="male")

    # NOTE(review): labels are the raw client ids. Sparse softmax
    # cross-entropy expects labels in [0, n_classes). Confirm that
    # TextDataShuffler/Lenet remap them, or that the ids already fit.
    labels = [o.client_id for o in objects]
    file_names = [o.make_path(directory=preprocessed_dir, extension=".hdf5")
                  for o in objects]

    # Validation batches are 100x the training batch size, as before.
    data_shuffler = TextDataShuffler(file_names, labels,
                                     input_shape=[80, 64, 1],
                                     train_batch_size=batch_size,
                                     validation_batch_size=batch_size * 100)

    # Preparing the architecture
    lenet = Lenet()

    # Per-sample sparse softmax cross-entropy, averaged over the batch.
    loss = BaseLoss(tf.nn.sparse_softmax_cross_entropy_with_logits, tf.reduce_mean)
    trainer = Trainer(architecture=lenet, loss=loss, iterations=iterations)
    trainer.train(data_shuffler)
Tiago de Freitas Pereira's avatar
Scratch  
Tiago de Freitas Pereira committed
71 72