#!/usr/bin/env python
# vim: set fileencoding=utf-8 :
# @author: Tiago de Freitas Pereira <tiago.pereira@idiap.ch>
# @date: Wed 11 May 2016 09:39:36 CEST 


"""
Simple script that trains CASIA WEBFACE

Usage:
  train_siamese_casia_webface.py [--batch-size=<arg> --validation-batch-size=<arg> --iterations=<arg> --validation-interval=<arg> --use-gpu]
  train_siamese_casia_webface.py -h | --help
Options:
  -h --help     Show this screen.
  --batch-size=<arg>  [default: 1]
  --validation-batch-size=<arg>   [default:128]
  --iterations=<arg>  [default: 30000]
  --validation-interval=<arg>  [default: 100]
"""

from docopt import docopt
import tensorflow as tf
from .. import util
SEED = 10
from bob.learn.tensorflow.datashuffler import TripletDisk, TripletWithSelectionDisk, TripletWithFastSelectionDisk
from bob.learn.tensorflow.network import Lenet, MLP, LenetDropout, VGG, Chopra, Dummy
from bob.learn.tensorflow.trainers import SiameseTrainer, TripletTrainer
from bob.learn.tensorflow.loss import ContrastiveLoss, TripletLoss
import numpy


def main():
    args = docopt(__doc__, version='CASIA-WebFace training with TensorFlow')

    BATCH_SIZE = int(args['--batch-size'])
    VALIDATION_BATCH_SIZE = int(args['--validation-batch-size'])
    ITERATIONS = int(args['--iterations'])
    VALIDATION_TEST = int(args['--validation-interval'])
    USE_GPU = args['--use-gpu']
    perc_train = 0.9

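    # CASIA-WebFace ("world" group) supplies the training identities; MOBIO
    # ("dev" group, protocol "male") is used for validation.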
    import bob.db.mobio
    db_mobio = bob.db.mobio.Database()

    import bob.db.casia_webface
    db_casia = bob.db.casia_webface.Database()

    # Preparing the training set
    train_objects = sorted(db_casia.objects(groups="world"), key=lambda x: x.id)
    #train_objects = db.objects(groups="world")
    train_labels = [int(o.client_id) for o in train_objects]
    directory = "/idiap/temp/tpereira/DEEP_FACE/CASIA_WEBFACE/casia_webface/preprocessed"

    train_file_names = [o.make_path(
        directory=directory,
        extension=".hdf5")
                        for o in train_objects]

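    # Training triplets are assembled from the preprocessed 224x224x3 HDF5 files
    # on disk; the active shuffler uses the "fast selection" strategy (the
    # commented variant below uses an alternative selection strategy).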
    #train_data_shuffler = TripletWithSelectionDisk(train_file_names, train_labels,
    #                                               input_shape=[125, 125, 3],
    #                                               batch_size=BATCH_SIZE)

    train_data_shuffler = TripletWithFastSelectionDisk(train_file_names, train_labels,
                                                       input_shape=[224, 224, 3],
                                                       batch_size=BATCH_SIZE)


    # Preparing the validation set
    directory = "/idiap/temp/tpereira/DEEP_FACE/CASIA_WEBFACE/mobio/preprocessed"
    validation_objects = sorted(db_mobio.objects(protocol="male", groups="dev"), key=lambda x: x.id)
    validation_labels = [o.client_id for o in validation_objects]

    validation_file_names = [o.make_path(
        directory=directory,
        extension=".hdf5")
                             for o in validation_objects]

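    # Validation batches are drawn with the plain TripletDisk shuffler.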
    validation_data_shuffler = TripletDisk(validation_file_names, validation_labels,
                                           input_shape=[224, 224, 3],
                                           batch_size=VALIDATION_BATCH_SIZE)
    # Preparing the architecture
    # CNN architecture from Chopra et al. (siamese face verification paper)
    architecture = Chopra(seed=SEED)

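    # Alternative setup, kept commented out: contrastive loss trained with the
    # SiameseTrainer.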
    #loss = ContrastiveLoss(contrastive_margin=50.)
    #optimizer = tf.train.GradientDescentOptimizer(0.00001)
    #trainer = SiameseTrainer(architecture=architecture,
    #                         loss=loss,
    #                         iterations=ITERATIONS,
    #                         snapshot=VALIDATION_TEST,
    #                         optimizer=optimizer)

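    # Active setup: triplet loss (margin 1.0) trained with the TripletTrainer
    # (prefetching disabled).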
    loss = TripletLoss(margin=1.)
    trainer = TripletTrainer(architecture=architecture, loss=loss,
                             iterations=ITERATIONS,
                             prefetch=False,
                             temp_dir="./LOGS_CASIA/triplet-cnn-fast-selection")


    trainer.train(train_data_shuffler, validation_data_shuffler)
    #trainer.train(train_data_shuffler)
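

# Convenience guard, not part of the original entry-point setup (the package is
# assumed to expose main() as a console script): it only allows executing this
# module directly, e.g. via ``python -m``, where the relative import resolves.
if __name__ == "__main__":
    main()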