train_mnist_siamese.py 4.1 KB
Newer Older
Tiago de Freitas Pereira's avatar
Scratch  
Tiago de Freitas Pereira committed
1 2 3 4 5 6 7 8 9 10
#!/usr/bin/env python
# vim: set fileencoding=utf-8 :
# @author: Tiago de Freitas Pereira <tiago.pereira@idiap.ch>
# @date: Wed 11 May 2016 09:39:36 CEST 


"""
Simple script that trains MNIST with LENET using Tensor flow

Usage:
11
  train_mnist_siamese.py [--batch-size=<arg> --validation-batch-size=<arg> --iterations=<arg> --validation-interval=<arg> --use-gpu]
12
  train_mnist_siamese.py -h | --help
Tiago de Freitas Pereira's avatar
Scratch  
Tiago de Freitas Pereira committed
13 14 15
Options:
  -h --help     Show this screen.
  --batch-size=<arg>  [default: 1]
16
  --validation-batch-size=<arg>   [default:128]
Tiago de Freitas Pereira's avatar
Scratch  
Tiago de Freitas Pereira committed
17 18 19 20 21 22 23 24
  --iterations=<arg>  [default: 30000]
  --validation-interval=<arg>  [default: 100]
"""

from docopt import docopt
import tensorflow as tf
from .. import util
SEED = 10
25
from bob.learn.tensorflow.data import MemoryDataShuffler, TextDataShuffler
26 27 28
from bob.learn.tensorflow.network import Lenet
from bob.learn.tensorflow.trainers import SiameseTrainer
from bob.learn.tensorflow.loss import ContrastiveLoss
29
import bob.db.mobio
30
import numpy
Tiago de Freitas Pereira's avatar
Scratch  
Tiago de Freitas Pereira committed
31 32 33 34 35

# Directory holding the preprocessed MOBIO ("male" protocol) samples as .hdf5 files.
# NOTE(review): site-specific absolute path; consider exposing it as a command-line option.
_MOBIO_PREPROCESSED_DIR = "/remote/lustre/2/temp/tpereira/FACEREC_EXPERIMENTS/mobio_male/lda/preprocessed"


def _mnist_shufflers(batch_size, validation_batch_size):
    """Build the (train, validation) in-memory shufflers over MNIST.

    The images are reshaped to 28x28x1 and scaled by the shuffler.

    :param batch_size: batch size of the training shuffler.
    :param validation_batch_size: batch size of the validation shuffler.
    :return: tuple ``(train_data_shuffler, validation_data_shuffler)``.
    """
    train_data, train_labels, validation_data, validation_labels = \
        util.load_mnist(data_dir="./src/bob.db.mnist/bob/db/mnist/")
    train_data = numpy.reshape(train_data, (train_data.shape[0], 28, 28, 1))
    validation_data = numpy.reshape(validation_data, (validation_data.shape[0], 28, 28, 1))

    train_data_shuffler = MemoryDataShuffler(train_data, train_labels,
                                             input_shape=[28, 28, 1],
                                             scale=True,
                                             batch_size=batch_size)

    validation_data_shuffler = MemoryDataShuffler(validation_data, validation_labels,
                                                  input_shape=[28, 28, 1],
                                                  scale=True,
                                                  batch_size=validation_batch_size)
    return train_data_shuffler, validation_data_shuffler


def _mobio_shuffler(db, group, batch_size):
    """Build a file-backed shuffler over one group of the MOBIO "male" protocol.

    :param db: a ``bob.db.mobio.Database`` instance.
    :param group: the database group to load (e.g. ``"world"`` or ``"dev"``).
    :param batch_size: batch size of the returned shuffler.
    :return: a ``TextDataShuffler`` over the group's preprocessed .hdf5 files,
        labelled with each object's ``client_id``.
    """
    objects = db.objects(protocol="male", groups=group)
    labels = [o.client_id for o in objects]
    file_names = [o.make_path(directory=_MOBIO_PREPROCESSED_DIR, extension=".hdf5")
                  for o in objects]
    return TextDataShuffler(file_names, labels,
                            input_shape=[80, 64, 1],
                            batch_size=batch_size)


def main():
    """Train a LeNet siamese network with a contrastive loss.

    Command-line arguments are parsed from the module docstring with docopt.
    Training uses the MOBIO "male" protocol ("world" group for training and
    "dev" group for validation) unless the hard-coded ``mnist`` switch below
    is turned on.
    """
    args = docopt(__doc__, version='Mnist training with TensorFlow')

    batch_size = int(args['--batch-size'])
    validation_batch_size = int(args['--validation-batch-size'])
    iterations = int(args['--iterations'])
    validation_interval = int(args['--validation-interval'])
    # NOTE(review): --use-gpu is accepted on the command line but nothing here reads it.

    # Loading data
    mnist = False
    if mnist:
        train_data_shuffler, validation_data_shuffler = \
            _mnist_shufflers(batch_size, validation_batch_size)
    else:
        db = bob.db.mobio.Database()
        # Training set comes from the "world" group, validation from "dev".
        train_data_shuffler = _mobio_shuffler(db, "world", batch_size)
        validation_data_shuffler = _mobio_shuffler(db, "dev", validation_batch_size)

    # Preparing the architecture
    n_classes = len(train_data_shuffler.possible_labels)
    lenet = Lenet(default_feature_layer="fc2", n_classes=n_classes)

    loss = ContrastiveLoss()
    trainer = SiameseTrainer(architecture=lenet,
                             loss=loss,
                             iterations=iterations,
                             base_lr=0.0001,
                             save_intermediate=False,
                             # snapshot interval is driven by --validation-interval
                             snapshot=validation_interval)
    trainer.train(train_data_shuffler, validation_data_shuffler)
Tiago de Freitas Pereira's avatar
Scratch  
Tiago de Freitas Pereira committed
102 103