#!/usr/bin/env python
# vim: set fileencoding=utf-8 :
# @author: Tiago de Freitas Pereira <tiago.pereira@idiap.ch>
# @date: Wed 11 May 2016 09:39:36 CEST


"""
Simple script that trains CASIA WEBFACE

Usage:
  train_siamese_casia_webface.py [--batch-size=<arg> --validation-batch-size=<arg> --iterations=<arg> --validation-interval=<arg> --use-gpu]
  train_siamese_casia_webface.py -h | --help

Options:
  -h --help                       Show this screen.
  --batch-size=<arg>              Training batch size [default: 1]
  --validation-batch-size=<arg>   Validation batch size [default: 128]
  --iterations=<arg>              Number of training iterations [default: 30000]
  --validation-interval=<arg>     Validation interval (parsed, currently unused) [default: 100]
  --use-gpu                       Parsed, currently unused.
"""

import numpy
import tensorflow as tf
from docopt import docopt

from bob.learn.tensorflow.datashuffler import TripletDisk, TripletWithSelectionDisk, TripletWithFastSelectionDisk
from bob.learn.tensorflow.loss import ContrastiveLoss, TripletLoss
from bob.learn.tensorflow.network import Lenet, MLP, LenetDropout, VGG, Chopra, Dummy
from bob.learn.tensorflow.trainers import SiameseTrainer, TripletTrainer

from .. import util

# Random seed handed to the network architecture built in main().
SEED = 10


def main():
    """Train a triplet-loss Chopra network on CASIA-WebFace, validating on MOBIO.

    Command-line options are parsed by docopt from the module docstring:

    * ``--batch-size`` -- batch size of the training data shuffler.
    * ``--validation-batch-size`` -- batch size of the validation data shuffler.
    * ``--iterations`` -- number of iterations given to ``TripletTrainer``.
    * ``--validation-interval`` and ``--use-gpu`` -- parsed but not forwarded
      to the trainer, so they currently have no effect.

    Training data is the CASIA-WebFace ``world`` group, fed through
    ``TripletWithFastSelectionDisk``. Validation data is the MOBIO ``male``
    protocol, ``dev`` group, read from ``.hdf5`` files through ``TripletDisk``.
    The trainer is given ``./LOGS_CASIA/triplet-cnn-fast-selection`` as its
    ``temp_dir``.
    """
    args = docopt(__doc__, version='CASIA-WebFace triplet training with TensorFlow')

    batch_size = int(args['--batch-size'])
    validation_batch_size = int(args['--validation-batch-size'])
    iterations = int(args['--iterations'])
    # NOTE(review): the next two options are parsed but never passed to
    # TripletTrainer below. Wire them in once the trainer's keyword names
    # (e.g. a snapshot/validation interval and a GPU flag) are confirmed.
    validation_interval = int(args['--validation-interval'])  # noqa: F841
    use_gpu = args['--use-gpu']  # noqa: F841

    # Imported here rather than at module level, so `--help` (handled inside
    # docopt above) does not require these database packages.
    import bob.db.casia_webface
    import bob.db.mobio

    db_casia = bob.db.casia_webface.Database()
    db_mobio = bob.db.mobio.Database()

    # Input shape shared by both shufflers; each receives its own list copy
    # so neither can alias the other's argument.
    image_shape = [125, 125, 3]

    # --- Training set: CASIA-WebFace "world" group --------------------------
    train_directory = "/idiap/resource/database/CASIA-WebFace/CASIA-WebFace"
    train_objects = db_casia.objects(groups="world")
    train_labels = [int(o.client_id) for o in train_objects]
    # extension="" -- presumably the stored paths already carry the image
    # extension; TODO confirm.
    train_file_names = [o.make_path(directory=train_directory, extension="")
                        for o in train_objects]
    train_data_shuffler = TripletWithFastSelectionDisk(
        train_file_names, train_labels,
        input_shape=list(image_shape),
        batch_size=batch_size)

    # --- Validation set: MOBIO "male" protocol, "dev" group -----------------
    # NOTE(review): the directory is named after CASIA although the objects
    # come from MOBIO; confirm the preprocessed MOBIO .hdf5 files live here.
    validation_directory = "/idiap/temp/tpereira/DEEP_FACE/CASIA/preprocessed"
    validation_objects = db_mobio.objects(protocol="male", groups="dev")
    # NOTE(review): unlike train_labels these are not cast with int();
    # confirm MOBIO client ids are already integers.
    validation_labels = [o.client_id for o in validation_objects]
    validation_file_names = [o.make_path(directory=validation_directory,
                                         extension=".hdf5")
                             for o in validation_objects]
    validation_data_shuffler = TripletDisk(
        validation_file_names, validation_labels,
        input_shape=list(image_shape),
        batch_size=validation_batch_size)

    # --- Model, loss and trainer --------------------------------------------
    architecture = Chopra(seed=SEED)
    loss = TripletLoss(margin=1.)
    trainer = TripletTrainer(architecture=architecture, loss=loss,
                             iterations=iterations,
                             prefetch=False,
                             temp_dir="./LOGS_CASIA/triplet-cnn-fast-selection")

    trainer.train(train_data_shuffler, validation_data_shuffler)