#!/usr/bin/env python
# vim: set fileencoding=utf-8 :
# @author: Tiago de Freitas Pereira <tiago.pereira@idiap.ch>
# @date: Wed 11 May 2016 09:39:36 CEST


"""
Simple script that trains a siamese network on MNIST using TensorFlow

Usage:
  train_mnist_siamese.py [--batch-size=<arg> --validation-batch-size=<arg> --iterations=<arg> --validation-interval=<arg> --use-gpu]
  train_mnist_siamese.py -h | --help

Options:
  -h --help     Show this screen.
  --batch-size=<arg>  Training batch size [default: 1]
  --validation-batch-size=<arg>  Validation batch size [default: 128]
  --iterations=<arg>  Number of training iterations [default: 30000]
  --validation-interval=<arg>  Snapshot/validation interval [default: 100]
"""

import numpy
import tensorflow as tf
from docopt import docopt

from bob.learn.tensorflow.datashuffler import SiameseMemory
from bob.learn.tensorflow.loss import ContrastiveLoss
from bob.learn.tensorflow.network import Lenet, MLP, LenetDropout, VGG, Chopra, Dummy
from bob.learn.tensorflow.trainers import SiameseTrainer

from .. import util

# Seed passed as ``seed=`` to the network constructor in ``main``;
# presumably fixes weight initialisation so runs are repeatable -- confirm in
# the network implementation.
SEED = 10


def main():
    """Train a siamese network on MNIST with a contrastive loss.

    Options are parsed by docopt from the module docstring:

    * ``--batch-size`` / ``--validation-batch-size`` set the batch size of the
      training / validation ``SiameseMemory`` shufflers;
    * ``--iterations`` is the trainer's iteration count;
    * ``--validation-interval`` is forwarded as the trainer's ``snapshot``
      argument.

    MNIST is read via ``util.load_mnist`` from a hard-coded relative directory
    (``./src/bob.db.mnist/bob/db/mnist/``), so the script must be started from
    the directory that contains ``src``. The ``temp_dir`` handed to the trainer
    is a sub-directory of ``./LOGS``.
    """
    args = docopt(__doc__, version='Mnist training with TensorFlow')

    batch_size = int(args['--batch-size'])
    validation_batch_size = int(args['--validation-batch-size'])
    iterations = int(args['--iterations'])
    # Forwarded to the trainer as ``snapshot``; presumably the number of
    # iterations between snapshot/validation steps -- TODO confirm in
    # SiameseTrainer.
    validation_interval = int(args['--validation-interval'])
    # NOTE(review): ``--use-gpu`` is accepted on the command line but nothing
    # in this function consumes it.

    # Loading data
    train_data, train_labels, validation_data, validation_labels = \
        util.load_mnist(data_dir="./src/bob.db.mnist/bob/db/mnist/")
    # Reshape to (N, 28, 28, 1) to match the ``input_shape`` given to the
    # shufflers below.
    train_data = numpy.reshape(train_data, (train_data.shape[0], 28, 28, 1))
    validation_data = numpy.reshape(validation_data, (validation_data.shape[0], 28, 28, 1))

    train_data_shuffler = SiameseMemory(train_data, train_labels,
                                        input_shape=[28, 28, 1],
                                        scale=True,
                                        batch_size=batch_size)
    validation_data_shuffler = SiameseMemory(validation_data, validation_labels,
                                             input_shape=[28, 28, 1],
                                             scale=True,
                                             batch_size=validation_batch_size)

    # Preparing the architecture.
    # The number of distinct training labels is used as the output size of
    # the network (``fc1_output`` for Chopra, first argument of MLP).
    n_classes = len(train_data_shuffler.possible_labels)

    # Hard-coded switch between the convolutional (Chopra) and MLP setups.
    cnn = True
    if cnn:
        # LeNet-like architecture from the Chopra paper.
        architecture = Chopra(seed=SEED, fc1_output=n_classes)
        loss = ContrastiveLoss(contrastive_margin=4.)
        trainer = SiameseTrainer(architecture=architecture,
                                 loss=loss,
                                 iterations=iterations,
                                 snapshot=validation_interval,
                                 prefetch=False,
                                 temp_dir="./LOGS/siamese-cnn-prefetch")
    else:
        architecture = MLP(n_classes, hidden_layers=[15, 20])
        loss = ContrastiveLoss()
        trainer = SiameseTrainer(architecture=architecture,
                                 loss=loss,
                                 iterations=iterations,
                                 snapshot=validation_interval,
                                 temp_dir="./LOGS/siamese-dnn")

    trainer.train(train_data_shuffler, validation_data_shuffler)