#!/usr/bin/env python
# vim: set fileencoding=utf-8 :
# @author: Tiago de Freitas Pereira <tiago.pereira@idiap.ch>
# @date: Wed 11 May 2016 09:39:36 CEST


"""
Simple script that trains a siamese network on CASIA-WebFace

Usage:
  train_siamese_casia_webface.py [--batch-size=<arg> --validation-batch-size=<arg> --iterations=<arg> --validation-interval=<arg> --use-gpu]
  train_siamese_casia_webface.py -h | --help
Options:
  -h --help     Show this screen.
  --batch-size=<arg>  [default: 1]
  --validation-batch-size=<arg>   [default: 128]
  --iterations=<arg>  [default: 30000]
  --validation-interval=<arg>  [default: 100]
"""

from docopt import docopt
import tensorflow as tf

from .. import util
from bob.learn.tensorflow.data import MemoryDataShuffler, TextDataShuffler
from bob.learn.tensorflow.network import Lenet, MLP, LenetDropout, VGG, Chopra, Dummy
from bob.learn.tensorflow.trainers import SiameseTrainer
from bob.learn.tensorflow.loss import ContrastiveLoss
import numpy

# Fixed RNG seed so network weight initialization is reproducible across runs.
SEED = 10


def main():
    """Train a siamese network on CASIA-WebFace and validate on MOBIO.

    Command-line options (parsed by docopt from the module docstring)
    control batch sizes, iteration count and the validation interval.
    Trains a Chopra-style siamese architecture with a contrastive loss.
    """
    args = docopt(__doc__, version='Siamese CASIA-WebFace training with TensorFlow')

    # docopt yields None for an option whose "[default: ...]" annotation is
    # missing or malformed, so fall back explicitly before int() conversion.
    BATCH_SIZE = int(args['--batch-size'] or 1)
    VALIDATION_BATCH_SIZE = int(args['--validation-batch-size'] or 128)
    ITERATIONS = int(args['--iterations'] or 30000)
    VALIDATION_TEST = int(args['--validation-interval'] or 100)
    USE_GPU = args['--use-gpu']  # NOTE(review): parsed but currently unused

    # Database packages are imported lazily so the module can be imported
    # without them installed (e.g. for --help).
    import bob.db.mobio
    db_mobio = bob.db.mobio.Database()

    import bob.db.casia_webface
    db_casia = bob.db.casia_webface.Database()

    # Training set: raw CASIA-WebFace images from the "world" group,
    # read from disk by the shuffler (125x125 RGB).
    train_objects = db_casia.objects(groups="world")
    train_labels = [int(o.client_id) for o in train_objects]
    directory = "/idiap/resource/database/CASIA-WebFace/CASIA-WebFace"

    train_file_names = [o.make_path(directory=directory, extension="")
                        for o in train_objects]

    train_data_shuffler = TextDataShuffler(train_file_names, train_labels,
                                           input_shape=[125, 125, 3],
                                           batch_size=BATCH_SIZE)

    # Validation set: preprocessed MOBIO faces ("male" protocol, dev group)
    # stored as HDF5 files.
    directory = "/idiap/temp/tpereira/DEEP_FACE/CASIA/preprocessed"
    validation_objects = db_mobio.objects(protocol="male", groups="dev")
    validation_labels = [o.client_id for o in validation_objects]

    validation_file_names = [o.make_path(directory=directory, extension=".hdf5")
                             for o in validation_objects]

    validation_data_shuffler = TextDataShuffler(validation_file_names, validation_labels,
                                                input_shape=[125, 125, 3],
                                                batch_size=VALIDATION_BATCH_SIZE)

    # Architecture: the "Chopra" network (siamese LeNet variant per the
    # original file's "LENET PAPER CHOPRA" note), seeded for reproducibility.
    architecture = Chopra(seed=SEED)

    loss = ContrastiveLoss(contrastive_margin=50.)
    optimizer = tf.train.GradientDescentOptimizer(0.00001)
    trainer = SiameseTrainer(architecture=architecture,
                             loss=loss,
                             iterations=ITERATIONS,
                             snapshot=VALIDATION_TEST,  # validate/snapshot every N iterations
                             optimizer=optimizer)

    trainer.train(train_data_shuffler, validation_data_shuffler)