From 2d03b381f64e4e40d82f422a666133593e8856b5 Mon Sep 17 00:00:00 2001
From: Tiago Freitas Pereira <tiagofrepereira@gmail.com>
Date: Fri, 23 Sep 2016 17:32:06 +0200
Subject: [PATCH] Fixed synchronization bug

The threads that feed the placeholder queues in Trainer and SiameseTrainer no
longer run for a fixed number of iterations; they now loop on
thread_pool.should_stop(), so data keeps being fed until the training loop
signals them to stop. The label/index lookups in
BaseDataShuffler.get_genuine_or_not are cast to int, the siamese example
script is switched to train on CASIA-WebFace and validate on MOBIO, and the
corresponding database packages are added to buildout.cfg.
---
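Note: a minimal sketch of the enqueue pattern the trainers move to, assuming
thread_pool plays the role of a tf.train.Coordinator and feed_batch() stands
in for the data shuffler's feed_dict construction (neither name is defined in
this patch):

    import tensorflow as tf

    coord = tf.train.Coordinator()   # plays the role of thread_pool

    def load_and_enqueue(session, enqueue_op, feed_batch):
        # Feed the queue until the coordinator asks every thread to stop,
        # instead of counting off a fixed number of iterations.
        while not coord.should_stop():
            session.run(enqueue_op, feed_dict=feed_batch())

    # Once training finishes: coord.request_stop(); coord.join(threads)
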
 bob/learn/tensorflow/analyzers/__init__.py    |  2 +-
 bob/learn/tensorflow/data/BaseDataShuffler.py |  7 +-
 bob/learn/tensorflow/script/train_mnist.py    |  4 +-
 .../tensorflow/script/train_mnist_siamese.py  | 70 ++++++++++++-------
 .../tensorflow/trainers/SiameseTrainer.py     | 11 +--
 bob/learn/tensorflow/trainers/Trainer.py      |  3 +-
 buildout.cfg                                  | 10 ++-
 7 files changed, 68 insertions(+), 39 deletions(-)

diff --git a/bob/learn/tensorflow/analyzers/__init__.py b/bob/learn/tensorflow/analyzers/__init__.py
index e170e6fc..7c594bde 100644
--- a/bob/learn/tensorflow/analyzers/__init__.py
+++ b/bob/learn/tensorflow/analyzers/__init__.py
@@ -2,7 +2,7 @@
 from pkgutil import extend_path
 __path__ = extend_path(__path__, __name__)
 
-from .Analizer import Analizer
+from .ExperimentAnalizer import ExperimentAnalizer
 
 # gets sphinx autodoc done right - don't remove it
 __all__ = [_ for _ in dir() if not _.startswith('_')]
diff --git a/bob/learn/tensorflow/data/BaseDataShuffler.py b/bob/learn/tensorflow/data/BaseDataShuffler.py
index a7dd245a..a57d1ce6 100644
--- a/bob/learn/tensorflow/data/BaseDataShuffler.py
+++ b/bob/learn/tensorflow/data/BaseDataShuffler.py
@@ -66,10 +66,11 @@ class BaseDataShuffler(object):
         return data, labels
 
     def get_genuine_or_not(self, input_data, input_labels, genuine=True):
+
         if genuine:
             # Getting a client
             index = numpy.random.randint(len(self.possible_labels))
-            index = self.possible_labels[index]
+            index = int(self.possible_labels[index])
 
             # Getting the indexes of the data from a particular client
             indexes = numpy.where(input_labels == index)[0]
@@ -82,8 +83,8 @@ class BaseDataShuffler(object):
         else:
             # Picking a pair of labels from different clients
             index = numpy.random.choice(len(self.possible_labels), 2, replace=False)
-            index[0] = self.possible_labels[index[0]]
-            index[1] = self.possible_labels[index[1]]
+            index[0] = self.possible_labels[int(index[0])]
+            index[1] = self.possible_labels[int(index[1])]
 
             # Getting the indexes of the two clients
             indexes = numpy.where(input_labels == index[0])[0]
diff --git a/bob/learn/tensorflow/script/train_mnist.py b/bob/learn/tensorflow/script/train_mnist.py
index a13c575d..a8db168f 100644
--- a/bob/learn/tensorflow/script/train_mnist.py
+++ b/bob/learn/tensorflow/script/train_mnist.py
@@ -89,8 +89,8 @@ def main():
     # Preparing the architecture
     cnn = True
     if cnn:
-        #architecture = Lenet(seed=SEED)
-        architecture = Dummy(seed=SEED)
+        architecture = Lenet(seed=SEED)
+        #architecture = Dummy(seed=SEED)
         loss = BaseLoss(tf.nn.sparse_softmax_cross_entropy_with_logits, tf.reduce_mean)
         trainer = Trainer(architecture=architecture, loss=loss, iterations=ITERATIONS)
         trainer.train(train_data_shuffler, validation_data_shuffler)
diff --git a/bob/learn/tensorflow/script/train_mnist_siamese.py b/bob/learn/tensorflow/script/train_mnist_siamese.py
index 047a9c7e..2d88ebce 100644
--- a/bob/learn/tensorflow/script/train_mnist_siamese.py
+++ b/bob/learn/tensorflow/script/train_mnist_siamese.py
@@ -23,7 +23,7 @@ import tensorflow as tf
 from .. import util
 SEED = 10
 from bob.learn.tensorflow.data import MemoryDataShuffler, TextDataShuffler
-from bob.learn.tensorflow.network import Lenet, MLP, LenetDropout, VGG, Chopra
+from bob.learn.tensorflow.network import Lenet, MLP, LenetDropout, VGG, Chopra, Dummy
 from bob.learn.tensorflow.trainers import SiameseTrainer
 from bob.learn.tensorflow.loss import ContrastiveLoss
 import numpy
@@ -39,7 +39,7 @@ def main():
     perc_train = 0.9
 
     # Loading data
-    mnist = True
+    mnist = False
 
     if mnist:
         train_data, train_labels, validation_data, validation_labels = \
@@ -58,55 +58,70 @@ def main():
                                                       batch_size=VALIDATION_BATCH_SIZE)
 
     else:
-        import bob.db.atnt
-        db = bob.db.atnt.Database()
+        import bob.db.mobio
+        db_mobio = bob.db.mobio.Database()
 
-        #import bob.db.mobio
-        #db = bob.db.mobio.Database()
+        import bob.db.casia_webface
+        db_casia = bob.db.casia_webface.Database()
 
         # Preparing train set
-        #train_objects = db.objects(protocol="male", groups="world")
-        train_objects = db.objects(groups="world")
-        train_labels = [o.client_id for o in train_objects]
-        #directory = "/idiap/user/tpereira/face/baselines/eigenface/preprocessed",
+        train_objects = db_casia.objects(groups="world")
+        #train_objects = db.objects(groups="world")
+        train_labels = [int(o.client_id) for o in train_objects]
+        directory = "/idiap/resource/database/CASIA-WebFace/CASIA-WebFace"
+
         train_file_names = [o.make_path(
-            directory="/idiap/group/biometric/databases/orl",
-            extension=".pgm")
+            directory=directory,
+            extension="")
                             for o in train_objects]
+        #import ipdb;
+        #ipdb.set_trace();
+
+        #train_file_names = [o.make_path(
+        #    directory="/idiap/group/biometric/databases/orl",
+        #    extension=".pgm")
+        #                    for o in train_objects]
 
-        #train_data_shuffler = TextDataShuffler(train_file_names, train_labels,
-        #                                       input_shape=[80, 64, 1],
-        #                                       batch_size=BATCH_SIZE)
         train_data_shuffler = TextDataShuffler(train_file_names, train_labels,
-                                               input_shape=[56, 46, 1],
+                                               input_shape=[250, 250, 3],
                                                batch_size=BATCH_SIZE)
 
+        #train_data_shuffler = TextDataShuffler(train_file_names, train_labels,
+        #                                       input_shape=[56, 46, 1],
+        #                                       batch_size=BATCH_SIZE)
+
         # Preparing train set
-        #validation_objects = db.objects(protocol="male", groups="dev")
-        validation_objects = db.objects(groups="dev")
+        directory = "/idiap/temp/tpereira/DEEP_FACE/CASIA/preprocessed"
+        validation_objects = db_mobio.objects(protocol="male", groups="dev")
         validation_labels = [o.client_id for o in validation_objects]
+        #validation_file_names = [o.make_path(
+        #    directory="/idiap/group/biometric/databases/orl",
+        #    extension=".pgm")
+        #                         for o in validation_objects]
+
         validation_file_names = [o.make_path(
-            directory="/idiap/group/biometric/databases/orl",
-            extension=".pgm")
+            directory=directory,
+            extension=".hdf5")
                                  for o in validation_objects]
 
-        #validation_data_shuffler = TextDataShuffler(validation_file_names, validation_labels,
-        #                                           input_shape=[80, 64, 1],
-        #                                            batch_size=VALIDATION_BATCH_SIZE)
         validation_data_shuffler = TextDataShuffler(validation_file_names, validation_labels,
-                                                    input_shape=[56, 46, 1],
+                                                    input_shape=[250, 250, 3],
                                                     batch_size=VALIDATION_BATCH_SIZE)
+        #validation_data_shuffler = TextDataShuffler(validation_file_names, validation_labels,
+        #                                            input_shape=[56, 46, 1],
+        #                                            batch_size=VALIDATION_BATCH_SIZE)
 
     # Preparing the architecture
     n_classes = len(train_data_shuffler.possible_labels)
-
+    n_classes = 200
     cnn = True
     if cnn:
 
         # LENET PAPER CHOPRA
         #architecture = Chopra(default_feature_layer="fc7")
-        architecture = Lenet(default_feature_layer="fc2", n_classes=n_classes, conv1_output=4, conv2_output=8,use_gpu=USE_GPU)
+        architecture = Lenet(default_feature_layer="fc2", n_classes=n_classes, conv1_output=8, conv2_output=16,use_gpu=USE_GPU)
         #architecture = VGG(n_classes=n_classes, use_gpu=USE_GPU)
+        #architecture = Dummy(seed=SEED)
 
         #architecture = LenetDropout(default_feature_layer="fc2", n_classes=n_classes, conv1_output=4, conv2_output=8, use_gpu=USE_GPU)
 
@@ -115,7 +130,8 @@ def main():
         trainer = SiameseTrainer(architecture=architecture,
                                  loss=loss,
                                  iterations=ITERATIONS,
-                                 snapshot=VALIDATION_TEST)
+                                 snapshot=VALIDATION_TEST,
+                                 )
         trainer.train(train_data_shuffler, validation_data_shuffler)
     else:
         mlp = MLP(n_classes, hidden_layers=[15, 20])
diff --git a/bob/learn/tensorflow/trainers/SiameseTrainer.py b/bob/learn/tensorflow/trainers/SiameseTrainer.py
index 15d79e34..c43d9990 100644
--- a/bob/learn/tensorflow/trainers/SiameseTrainer.py
+++ b/bob/learn/tensorflow/trainers/SiameseTrainer.py
@@ -12,7 +12,7 @@ from ..network import SequenceNetwork
 import bob.io.base
 from .Trainer import Trainer
 import os
-
+import sys
 
 class SiameseTrainer(Trainer):
 
@@ -64,7 +64,8 @@ class SiameseTrainer(Trainer):
             """
             Injecting data in the place holder queue
             """
-            for i in range(self.iterations):
+            #for i in range(self.iterations+5):
+            while not thread_pool.should_stop():
                 batch_left, batch_right, labels = train_data_shuffler.get_pair()
 
                 feed_dict = {train_placeholder_left_data: batch_left,
@@ -151,13 +152,13 @@ class SiameseTrainer(Trainer):
             self.architecture.generate_summaries()
             merged_validation = tf.merge_all_summaries()
 
-
-
             for step in range(self.iterations):
 
                 _, l, lr, summary = session.run([optimizer, loss_train, learning_rate, merged])
                 #_, l, lr= session.run([optimizer, loss_train, learning_rate])
                 train_writer.add_summary(summary, step)
+                print str(step)
+                sys.stdout.flush()
 
                 if validation_data_shuffler is not None and step % self.snapshot == 0:
 
@@ -167,7 +168,9 @@ class SiameseTrainer(Trainer):
                     summary = analizer()
                     train_writer.add_summary(summary, step)
                     print str(step)
+                sys.stdout.flush()
 
+            print("#######DONE##########")
             self.architecture.save(hdf5)
             del hdf5
             train_writer.close()
diff --git a/bob/learn/tensorflow/trainers/Trainer.py b/bob/learn/tensorflow/trainers/Trainer.py
index abf2c6cf..e66bc6ba 100644
--- a/bob/learn/tensorflow/trainers/Trainer.py
+++ b/bob/learn/tensorflow/trainers/Trainer.py
@@ -79,7 +79,8 @@ class Trainer(object):
             """
 
             #while not thread_pool.should_stop():
-            for i in range(self.iterations):
+            #for i in range(self.iterations):
+            while not thread_pool.should_stop():
                 train_data, train_labels = train_data_shuffler.get_batch()
 
                 feed_dict = {train_placeholder_data: train_data,
diff --git a/buildout.cfg b/buildout.cfg
index b47a6d3b..95369876 100644
--- a/buildout.cfg
+++ b/buildout.cfg
@@ -5,6 +5,7 @@
 [buildout]
 parts = scripts
 eggs = bob.learn.tensorflow
+       bob.db.casia_webface
        gridtk
 
 extensions = bob.buildout
@@ -12,6 +13,9 @@ extensions = bob.buildout
 auto-checkout = *
 develop = src/bob.db.mnist
           src/gridtk
+          src/bob.db.casia_webface
+          src/bob.db.mobio
+          src/bob.db.lfw
           .
 
 ; options for bob.buildout
@@ -21,7 +25,11 @@ newest = false
 
 
 [sources]
-bob.db.mnist = git git@github.com:tiagofrepereira2012/bob.db.mnist
+bob.db.mnist = git git@github.com:tiagofrepereira2012/bob.db.mnist.git
+bob.db.base = git git@gitlab.idiap.ch:bob/bob.db.base.git
+bob.db.mobio = git git@gitlab.idiap.ch:bob/bob.db.mobio.git
+bob.db.lfw = git git@gitlab.idiap.ch:bob/bob.db.lfw.git
+bob.db.casia_webface = git git@gitlab.idiap.ch:bob/bob.db.casia_webface.git
 gridtk = git git@github.com:bioidiap/gridtk
 
 
-- 
GitLab