#!/usr/bin/env python
# vim: set fileencoding=utf-8 :
# @author: Tiago de Freitas Pereira <tiago.pereira@idiap.ch>
# @date: Wed 11 May 2016 09:39:36 CEST


"""
Train a Siamese network on CASIA-WebFace, validating on MOBIO.

Usage:
  train_siamese_casia_webface.py [--batch-size=<arg> --validation-batch-size=<arg> --iterations=<arg> --validation-interval=<arg> --use-gpu]
  train_siamese_casia_webface.py -h | --help

Options:
  -h --help                      Show this screen.
  --batch-size=<arg>             [default: 1]
  --validation-batch-size=<arg>  [default: 128]
  --iterations=<arg>             [default: 30000]
  --validation-interval=<arg>    [default: 100]
"""

from docopt import docopt
import tensorflow as tf

from bob.learn.tensorflow.data import TextDataShuffler
from bob.learn.tensorflow.network import Chopra
from bob.learn.tensorflow.trainers import SiameseTrainer
from bob.learn.tensorflow.loss import ContrastiveLoss

# Fixed seed so the network initialization is reproducible across runs.
SEED = 10


def main():
    """Parse CLI options, build train/validation shufflers and run training.

    Trains a Chopra-style Siamese network with a contrastive loss on the
    CASIA-WebFace "world" group, using the MOBIO "male"/"dev" protocol as
    the validation set. All paths below are Idiap-internal locations.
    """
    args = docopt(__doc__, version='CASIA-WebFace Siamese training with TensorFlow')

    # docopt returns strings; missing values fall back to the docstring
    # defaults (note: '[default: 128]' needs the space after the colon to be
    # recognized by docopt — without it int(None) would raise TypeError).
    BATCH_SIZE = int(args['--batch-size'])
    VALIDATION_BATCH_SIZE = int(args['--validation-batch-size'])
    ITERATIONS = int(args['--iterations'])
    VALIDATION_TEST = int(args['--validation-interval'])
    # NOTE(review): parsed but not forwarded to the trainer yet — confirm
    # whether SiameseTrainer is expected to take a GPU flag.
    USE_GPU = args['--use-gpu']

    # Databases are imported lazily so the script only needs them at runtime.
    import bob.db.mobio
    db_mobio = bob.db.mobio.Database()

    import bob.db.casia_webface
    db_casia = bob.db.casia_webface.Database()

    # Preparing the train set: CASIA-WebFace "world" group, raw images.
    train_objects = db_casia.objects(groups="world")
    train_labels = [int(o.client_id) for o in train_objects]
    directory = "/idiap/resource/database/CASIA-WebFace/CASIA-WebFace"

    train_file_names = [o.make_path(directory=directory, extension="")
                        for o in train_objects]

    train_data_shuffler = TextDataShuffler(train_file_names, train_labels,
                                           input_shape=[125, 125, 3],
                                           batch_size=BATCH_SIZE)

    # Preparing the validation set: MOBIO "male" protocol, "dev" group,
    # preprocessed crops stored as HDF5.
    directory = "/idiap/temp/tpereira/DEEP_FACE/CASIA/preprocessed"
    validation_objects = db_mobio.objects(protocol="male", groups="dev")
    validation_labels = [o.client_id for o in validation_objects]

    validation_file_names = [o.make_path(directory=directory, extension=".hdf5")
                             for o in validation_objects]

    validation_data_shuffler = TextDataShuffler(validation_file_names,
                                                validation_labels,
                                                input_shape=[125, 125, 3],
                                                batch_size=VALIDATION_BATCH_SIZE)

    # Architecture from Chopra et al. (the "LeNet paper Chopra" network).
    architecture = Chopra(seed=SEED)

    loss = ContrastiveLoss(contrastive_margin=50.)
    optimizer = tf.train.GradientDescentOptimizer(0.00001)
    trainer = SiameseTrainer(architecture=architecture,
                             loss=loss,
                             iterations=ITERATIONS,
                             snapshot=VALIDATION_TEST,
                             optimizer=optimizer)

    trainer.train(train_data_shuffler, validation_data_shuffler)


if __name__ == "__main__":
    main()