#!/usr/bin/env python
# vim: set fileencoding=utf-8 :
# @author: Tiago de Freitas Pereira <tiago.pereira@idiap.ch>
# @date: Wed 11 May 2016 09:39:36 CEST 


"""
Simple script that trains CASIA WEBFACE

Usage:
  train_siamese_casia_webface.py [--batch-size=<arg> --validation-batch-size=<arg> --iterations=<arg> --validation-interval=<arg> --use-gpu]
  train_siamese_casia_webface.py -h | --help

Options:
  -h --help     Show this screen.
  --batch-size=<arg>  [default: 1]
  --validation-batch-size=<arg>   [default: 128]
  --iterations=<arg>  [default: 30000]
  --validation-interval=<arg>  [default: 100]
"""

from docopt import docopt
import tensorflow as tf
from .. import util
SEED = 10
from bob.learn.tensorflow.datashuffler import TripletDisk, TripletWithSelectionDisk
from bob.learn.tensorflow.network import Lenet, MLP, LenetDropout, VGG, Chopra, Dummy
from bob.learn.tensorflow.trainers import SiameseTrainer, TripletTrainer
from bob.learn.tensorflow.loss import ContrastiveLoss, TripletLoss
import numpy


def _int_option(args, name, fallback):
    """Return the docopt option ``name`` from ``args`` as an ``int``.

    ``fallback`` is returned when docopt reports the option as ``None``, which
    happens when the option is absent from the command line and no default
    could be read from the usage text.
    """
    value = args[name]
    return fallback if value is None else int(value)


def _make_paths(objects, directory, extension):
    """Return ``o.make_path(directory=..., extension=...)`` for each object."""
    return [o.make_path(directory=directory, extension=extension)
            for o in objects]


def main():
    """Train the Chopra architecture with a triplet loss from the command line.

    Command-line options are read from the module docstring through docopt.
    Training samples are the CASIA WebFace "world" objects and validation
    samples are the MOBIO "male"/"dev" objects; both are handed to the trainer
    through disk-based triplet data shufflers declared with a 125x125x3 input
    shape.

    NOTE(review): ``--validation-interval`` and ``--use-gpu`` are accepted on
    the command line but are not forwarded to ``TripletTrainer``. The earlier
    Siamese configuration passed the interval as ``snapshot=``; confirm that
    ``TripletTrainer`` takes the same keyword before wiring it in.
    """
    args = docopt(__doc__, version='CASIA WebFace training with TensorFlow')

    # Fall back to the documented defaults so a missing or unparsable
    # "[default: ...]" annotation cannot turn into int(None).
    batch_size = _int_option(args, '--batch-size', 1)
    validation_batch_size = _int_option(args, '--validation-batch-size', 128)
    iterations = _int_option(args, '--iterations', 30000)

    # Database packages are imported lazily, only when training is launched.
    import bob.db.mobio
    import bob.db.casia_webface
    db_mobio = bob.db.mobio.Database()
    db_casia = bob.db.casia_webface.Database()

    # Preparing the train set (CASIA WebFace, original images)
    train_directory = "/idiap/resource/database/CASIA-WebFace/CASIA-WebFace"
    train_objects = db_casia.objects(groups="world")
    train_labels = [int(o.client_id) for o in train_objects]
    train_file_names = _make_paths(train_objects, train_directory, "")

    train_data_shuffler = TripletWithSelectionDisk(train_file_names, train_labels,
                                                   input_shape=[125, 125, 3],
                                                   batch_size=batch_size)

    # Preparing the validation set (MOBIO, preprocessed HDF5 files)
    validation_directory = "/idiap/temp/tpereira/DEEP_FACE/CASIA/preprocessed"
    validation_objects = db_mobio.objects(protocol="male", groups="dev")
    validation_labels = [o.client_id for o in validation_objects]
    validation_file_names = _make_paths(validation_objects,
                                        validation_directory, ".hdf5")

    validation_data_shuffler = TripletDisk(validation_file_names, validation_labels,
                                           input_shape=[125, 125, 3],
                                           batch_size=validation_batch_size)

    # Preparing the architecture, loss and trainer
    architecture = Chopra(seed=SEED)
    loss = TripletLoss(margin=4.)
    trainer = TripletTrainer(architecture=architecture, loss=loss,
                             iterations=iterations,
                             prefetch=False,
                             temp_dir="./LOGS_CASIA/triplet-cnn")

    trainer.train(train_data_shuffler, validation_data_shuffler)