#!/usr/bin/env python
# vim: set fileencoding=utf-8 :
# @author: Tiago de Freitas Pereira <tiago.pereira@idiap.ch>
# @date: Wed 11 May 2016 09:39:36 CEST 


"""
Simple script that trains a Siamese network on MNIST with LeNet using TensorFlow

Usage:
  train_mnist_siamese.py [--batch-size=<arg> --validation-batch-size=<arg> --iterations=<arg> --validation-interval=<arg> --use-gpu]
  train_mnist_siamese.py -h | --help
Options:
  -h --help     Show this screen.
  --batch-size=<arg>               Training batch size [default: 1]
  --validation-batch-size=<arg>    Validation batch size [default: 128]
  --iterations=<arg>               Number of training iterations [default: 30000]
  --validation-interval=<arg>      Number of iterations between validation/snapshot steps [default: 100]
  --use-gpu                        Run the network on the GPU
"""

from docopt import docopt
import tensorflow as tf
from .. import util
SEED = 10
from bob.learn.tensorflow.data import MemoryDataShuffler, TextDataShuffler
from bob.learn.tensorflow.network import Lenet, MLP, LenetDropout, VGG, Chopra, Dummy
from bob.learn.tensorflow.trainers import SiameseTrainer
from bob.learn.tensorflow.loss import ContrastiveLoss
import numpy

def main():
    args = docopt(__doc__, version='Mnist training with TensorFlow')

    BATCH_SIZE = int(args['--batch-size'])
    VALIDATION_BATCH_SIZE = int(args['--validation-batch-size'])
    ITERATIONS = int(args['--iterations'])
    VALIDATION_TEST = int(args['--validation-interval'])
    USE_GPU = args['--use-gpu']
    perc_train = 0.9

    # Loading data: the ``mnist`` flag toggles between the in-memory MNIST set
    # and the CASIA-WebFace/MOBIO file lists used below
    mnist = False

    if mnist:
        train_data, train_labels, validation_data, validation_labels = \
            util.load_mnist(data_dir="./src/bob.db.mnist/bob/db/mnist/")
        train_data = numpy.reshape(train_data, (train_data.shape[0], 28, 28, 1))
        validation_data = numpy.reshape(validation_data, (validation_data.shape[0], 28, 28, 1))

        train_data_shuffler = MemoryDataShuffler(train_data, train_labels,
                                                 input_shape=[28, 28, 1],
                                                 scale=True,
                                                 batch_size=BATCH_SIZE)

        validation_data_shuffler = MemoryDataShuffler(validation_data, validation_labels,
                                                      input_shape=[28, 28, 1],
                                                      scale=True,
                                                      batch_size=VALIDATION_BATCH_SIZE)

    else:
        import bob.db.mobio
        db_mobio = bob.db.mobio.Database()

        import bob.db.casia_webface
        db_casia = bob.db.casia_webface.Database()

        # Preparing train set
        train_objects = db_casia.objects(groups="world")
        #train_objects = db.objects(groups="world")
        train_labels = [int(o.client_id) for o in train_objects]
        directory = "/idiap/resource/database/CASIA-WebFace/CASIA-WebFace"

73
        train_file_names = [o.make_path(directory=directory, extension="")
                            for o in train_objects]
        #import ipdb;
        #ipdb.set_trace();

        #train_file_names = [o.make_path(
        #    directory="/idiap/group/biometric/databases/orl",
        #    extension=".pgm")
        #                    for o in train_objects]

        # Shuffler that assembles mini-batches from the raw image files listed above
        train_data_shuffler = TextDataShuffler(train_file_names, train_labels,
                                               input_shape=[250, 250, 3],
                                               batch_size=BATCH_SIZE)

        #train_data_shuffler = TextDataShuffler(train_file_names, train_labels,
        #                                       input_shape=[56, 46, 1],
        #                                       batch_size=BATCH_SIZE)

        # Preparing validation set
        directory = "/idiap/temp/tpereira/DEEP_FACE/CASIA/preprocessed"
        validation_objects = db_mobio.objects(protocol="male", groups="dev")
        validation_labels = [o.client_id for o in validation_objects]
        #validation_file_names = [o.make_path(
        #    directory="/idiap/group/biometric/databases/orl",
        #    extension=".pgm")
        #                         for o in validation_objects]

        validation_file_names = [o.make_path(directory=directory, extension=".hdf5")
                                 for o in validation_objects]

        validation_data_shuffler = TextDataShuffler(validation_file_names, validation_labels,
                                                    input_shape=[250, 250, 3],
                                                    batch_size=VALIDATION_BATCH_SIZE)
        #validation_data_shuffler = TextDataShuffler(validation_file_names, validation_labels,
        #                                            input_shape=[56, 46, 1],
        #                                            batch_size=VALIDATION_BATCH_SIZE)

    # Preparing the architecture
    n_classes = len(train_data_shuffler.possible_labels)
    #n_classes = 200
    cnn = True
    if cnn:

        # LENET PAPER CHOPRA
        #architecture = Chopra(default_feature_layer="fc7")
        architecture = Lenet(default_feature_layer="fc2", n_classes=n_classes,
                             conv1_output=8, conv2_output=16, use_gpu=USE_GPU)
        #architecture = VGG(n_classes=n_classes, use_gpu=USE_GPU)
        #architecture = Dummy(seed=SEED)

        #architecture = LenetDropout(default_feature_layer="fc2", n_classes=n_classes, conv1_output=4, conv2_output=8, use_gpu=USE_GPU)

        loss = ContrastiveLoss()
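        # ContrastiveLoss is expected to implement the standard pairwise contrastive
        # loss: for the distance d between the two embeddings of a pair, genuine
        # pairs contribute d**2 and impostor pairs contribute max(margin - d, 0)**2,
        # pulling matching pairs together and pushing non-matching pairs apart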
        #optimizer = tf.train.GradientDescentOptimizer(0.0001)
        trainer = SiameseTrainer(architecture=architecture,
                                 loss=loss,
                                 iterations=ITERATIONS,
                                 snapshot=VALIDATION_TEST)
        trainer.train(train_data_shuffler, validation_data_shuffler)
    else:
        mlp = MLP(n_classes, hidden_layers=[15, 20])

        loss = ContrastiveLoss()
        trainer = SiameseTrainer(architecture=mlp,
                                 loss=loss,
                                 iterations=ITERATIONS,
                                 snapshot=VALIDATION_TEST)
        trainer.train(train_data_shuffler, validation_data_shuffler)
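

# Minimal entry-point guard (an assumption: the package's console script is
# expected to call main() as well), so the module can also be run directly:
if __name__ == '__main__':
    main()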