From 14147b9a91d50392e961067c40b47c16352501c8 Mon Sep 17 00:00:00 2001
From: Tiago Freitas Pereira <tiagofrepereira@gmail.com>
Date: Fri, 7 Oct 2016 16:29:16 +0200
Subject: [PATCH] CASIA Training

---
 .../script/train_siamese_casia_webface.py     | 89 +++++++++++++++++++
 1 file changed, 89 insertions(+)
 create mode 100644 bob/learn/tensorflow/script/train_siamese_casia_webface.py

diff --git a/bob/learn/tensorflow/script/train_siamese_casia_webface.py b/bob/learn/tensorflow/script/train_siamese_casia_webface.py
new file mode 100644
index 00000000..9506fb4b
--- /dev/null
+++ b/bob/learn/tensorflow/script/train_siamese_casia_webface.py
@@ -0,0 +1,89 @@
+#!/usr/bin/env python
+# vim: set fileencoding=utf-8 :
+# @author: Tiago de Freitas Pereira <tiago.pereira@idiap.ch>
+# @date: Wed 11 May 2016 09:39:36 CEST 
+
+
+"""
+Simple script that trains CASIA WEBFACE
+
+Usage:
+  train_siamese_casia_webface.py [--batch-size=<arg> --validation-batch-size=<arg> --iterations=<arg> --validation-interval=<arg> --use-gpu]
+  train_siamese_casia_webface.py -h | --help
+Options:
+  -h --help     Show this screen.
+  --batch-size=<arg>  [default: 1]
+  --validation-batch-size=<arg>   [default:128]
+  --iterations=<arg>  [default: 30000]
+  --validation-interval=<arg>  [default: 100]
+"""
+
+from docopt import docopt
+import tensorflow as tf
+from .. import util
+SEED = 10
+from bob.learn.tensorflow.data import MemoryDataShuffler, TextDataShuffler
+from bob.learn.tensorflow.network import Lenet, MLP, LenetDropout, VGG, Chopra, Dummy
+from bob.learn.tensorflow.trainers import SiameseTrainer
+from bob.learn.tensorflow.loss import ContrastiveLoss
+import numpy
+
+
+def main():
+    args = docopt(__doc__, version='Mnist training with TensorFlow')
+
+    BATCH_SIZE = int(args['--batch-size'])
+    VALIDATION_BATCH_SIZE = int(args['--validation-batch-size'])
+    ITERATIONS = int(args['--iterations'])
+    VALIDATION_TEST = int(args['--validation-interval'])
+    USE_GPU = args['--use-gpu']
+    perc_train = 0.9
+
+    import bob.db.mobio
+    db_mobio = bob.db.mobio.Database()
+
+    import bob.db.casia_webface
+    db_casia = bob.db.casia_webface.Database()
+
+    # Preparing train set
+    train_objects = db_casia.objects(groups="world")
+    #train_objects = db.objects(groups="world")
+    train_labels = [int(o.client_id) for o in train_objects]
+    directory = "/idiap/resource/database/CASIA-WebFace/CASIA-WebFace"
+
+    train_file_names = [o.make_path(
+        directory=directory,
+        extension="")
+                        for o in train_objects]
+
+    train_data_shuffler = TextDataShuffler(train_file_names, train_labels,
+                                           input_shape=[125, 125, 3],
+                                           batch_size=BATCH_SIZE)
+
+    # Preparing train set
+    directory = "/idiap/temp/tpereira/DEEP_FACE/CASIA/preprocessed"
+    validation_objects = db_mobio.objects(protocol="male", groups="dev")
+    validation_labels = [o.client_id for o in validation_objects]
+
+    validation_file_names = [o.make_path(
+        directory=directory,
+        extension=".hdf5")
+                             for o in validation_objects]
+
+    validation_data_shuffler = TextDataShuffler(validation_file_names, validation_labels,
+                                                input_shape=[125, 125, 3],
+                                                batch_size=VALIDATION_BATCH_SIZE)
+    # Preparing the architecture
+    # LENET PAPER CHOPRA
+    architecture = Chopra(seed=SEED)
+
+    loss = ContrastiveLoss(contrastive_margin=50.)
+    optimizer = tf.train.GradientDescentOptimizer(0.00001)
+    trainer = SiameseTrainer(architecture=architecture,
+                             loss=loss,
+                             iterations=ITERATIONS,
+                             snapshot=VALIDATION_TEST,
+                             optimizer=optimizer)
+
+    trainer.train(train_data_shuffler, validation_data_shuffler)
+    #trainer.train(train_data_shuffler)
-- 
GitLab