From 2d03b381f64e4e40d82f422a666133593e8856b5 Mon Sep 17 00:00:00 2001
From: Tiago Freitas Pereira <tiagofrepereira@gmail.com>
Date: Fri, 23 Sep 2016 17:32:06 +0200
Subject: [PATCH] Fixed synchronization bug

---
 bob/learn/tensorflow/analyzers/__init__.py    |  2 +-
 bob/learn/tensorflow/data/BaseDataShuffler.py |  7 +-
 bob/learn/tensorflow/script/train_mnist.py    |  4 +-
 .../tensorflow/script/train_mnist_siamese.py  | 70 ++++++++++++-------
 .../tensorflow/trainers/SiameseTrainer.py     | 11 +--
 bob/learn/tensorflow/trainers/Trainer.py      |  3 +-
 buildout.cfg                                  | 10 ++-
 7 files changed, 68 insertions(+), 39 deletions(-)

diff --git a/bob/learn/tensorflow/analyzers/__init__.py b/bob/learn/tensorflow/analyzers/__init__.py
index e170e6fc..7c594bde 100644
--- a/bob/learn/tensorflow/analyzers/__init__.py
+++ b/bob/learn/tensorflow/analyzers/__init__.py
@@ -2,7 +2,7 @@ from pkgutil import extend_path
 __path__ = extend_path(__path__, __name__)
 
-from .Analizer import Analizer
+from .ExperimentAnalizer import ExperimentAnalizer
 
 # gets sphinx autodoc done right - don't remove it
 __all__ = [_ for _ in dir() if not _.startswith('_')]
diff --git a/bob/learn/tensorflow/data/BaseDataShuffler.py b/bob/learn/tensorflow/data/BaseDataShuffler.py
index a7dd245a..a57d1ce6 100644
--- a/bob/learn/tensorflow/data/BaseDataShuffler.py
+++ b/bob/learn/tensorflow/data/BaseDataShuffler.py
@@ -66,10 +66,11 @@ class BaseDataShuffler(object):
         return data, labels
 
     def get_genuine_or_not(self, input_data, input_labels, genuine=True):
+
         if genuine:
             # Getting a client
             index = numpy.random.randint(len(self.possible_labels))
-            index = self.possible_labels[index]
+            index = int(self.possible_labels[index])
 
             # Getting the indexes of the data from a particular client
             indexes = numpy.where(input_labels == index)[0]
@@ -82,8 +83,8 @@ class BaseDataShuffler(object):
         else:
             # Picking a pair of labels from different clients
             index = numpy.random.choice(len(self.possible_labels), 2, replace=False)
-            index[0] = self.possible_labels[index[0]]
-            index[1] = self.possible_labels[index[1]]
+            index[0] = self.possible_labels[int(index[0])]
+            index[1] = self.possible_labels[int(index[1])]
 
             # Getting the indexes of the two clients
             indexes = numpy.where(input_labels == index[0])[0]
diff --git a/bob/learn/tensorflow/script/train_mnist.py b/bob/learn/tensorflow/script/train_mnist.py
index a13c575d..a8db168f 100644
--- a/bob/learn/tensorflow/script/train_mnist.py
+++ b/bob/learn/tensorflow/script/train_mnist.py
@@ -89,8 +89,8 @@ def main():
     # Preparing the architecture
     cnn = True
     if cnn:
-        #architecture = Lenet(seed=SEED)
-        architecture = Dummy(seed=SEED)
+        architecture = Lenet(seed=SEED)
+        #architecture = Dummy(seed=SEED)
         loss = BaseLoss(tf.nn.sparse_softmax_cross_entropy_with_logits, tf.reduce_mean)
         trainer = Trainer(architecture=architecture, loss=loss, iterations=ITERATIONS)
         trainer.train(train_data_shuffler, validation_data_shuffler)
diff --git a/bob/learn/tensorflow/script/train_mnist_siamese.py b/bob/learn/tensorflow/script/train_mnist_siamese.py
index 047a9c7e..2d88ebce 100644
--- a/bob/learn/tensorflow/script/train_mnist_siamese.py
+++ b/bob/learn/tensorflow/script/train_mnist_siamese.py
@@ -23,7 +23,7 @@ import tensorflow as tf
 from .. import util
 SEED = 10
 from bob.learn.tensorflow.data import MemoryDataShuffler, TextDataShuffler
-from bob.learn.tensorflow.network import Lenet, MLP, LenetDropout, VGG, Chopra
+from bob.learn.tensorflow.network import Lenet, MLP, LenetDropout, VGG, Chopra, Dummy
 from bob.learn.tensorflow.trainers import SiameseTrainer
 from bob.learn.tensorflow.loss import ContrastiveLoss
 import numpy
@@ -39,7 +39,7 @@ def main():
     perc_train = 0.9
 
     # Loading data
-    mnist = True
+    mnist = False
 
     if mnist:
         train_data, train_labels, validation_data, validation_labels = \
@@ -58,55 +58,70 @@ def main():
                                                       batch_size=VALIDATION_BATCH_SIZE)
 
     else:
-        import bob.db.atnt
-        db = bob.db.atnt.Database()
+        import bob.db.mobio
+        db_mobio = bob.db.mobio.Database()
 
-        #import bob.db.mobio
-        #db = bob.db.mobio.Database()
+        import bob.db.casia_webface
+        db_casia = bob.db.casia_webface.Database()
 
         # Preparing train set
-        #train_objects = db.objects(protocol="male", groups="world")
-        train_objects = db.objects(groups="world")
-        train_labels = [o.client_id for o in train_objects]
-        #directory = "/idiap/user/tpereira/face/baselines/eigenface/preprocessed",
+        train_objects = db_casia.objects(groups="world")
+        #train_objects = db.objects(groups="world")
+        train_labels = [int(o.client_id) for o in train_objects]
+        directory = "/idiap/resource/database/CASIA-WebFace/CASIA-WebFace"
+
         train_file_names = [o.make_path(
-            directory="/idiap/group/biometric/databases/orl",
-            extension=".pgm")
+            directory=directory,
+            extension="")
             for o in train_objects]
+        #import ipdb;
+        #ipdb.set_trace();
+
+        #train_file_names = [o.make_path(
+        #    directory="/idiap/group/biometric/databases/orl",
+        #    extension=".pgm")
+        #    for o in train_objects]
 
-        #train_data_shuffler = TextDataShuffler(train_file_names, train_labels,
-        #                                       input_shape=[80, 64, 1],
-        #                                       batch_size=BATCH_SIZE)
         train_data_shuffler = TextDataShuffler(train_file_names, train_labels,
-                                               input_shape=[56, 46, 1],
+                                               input_shape=[250, 250, 3],
                                                batch_size=BATCH_SIZE)
+        #train_data_shuffler = TextDataShuffler(train_file_names, train_labels,
+        #                                       input_shape=[56, 46, 1],
+        #                                       batch_size=BATCH_SIZE)
+
         # Preparing train set
-        #validation_objects = db.objects(protocol="male", groups="dev")
-        validation_objects = db.objects(groups="dev")
+        directory = "/idiap/temp/tpereira/DEEP_FACE/CASIA/preprocessed"
+        validation_objects = db_mobio.objects(protocol="male", groups="dev")
         validation_labels = [o.client_id for o in validation_objects]
 
+        #validation_file_names = [o.make_path(
+        #    directory="/idiap/group/biometric/databases/orl",
+        #    extension=".pgm")
+        #    for o in validation_objects]
+
         validation_file_names = [o.make_path(
-            directory="/idiap/group/biometric/databases/orl",
-            extension=".pgm")
+            directory=directory,
+            extension=".hdf5")
             for o in validation_objects]
 
-        #validation_data_shuffler = TextDataShuffler(validation_file_names, validation_labels,
-        #                                            input_shape=[80, 64, 1],
-        #                                            batch_size=VALIDATION_BATCH_SIZE)
         validation_data_shuffler = TextDataShuffler(validation_file_names, validation_labels,
-                                                    input_shape=[56, 46, 1],
+                                                    input_shape=[250, 250, 3],
                                                     batch_size=VALIDATION_BATCH_SIZE)
+        #validation_data_shuffler = TextDataShuffler(validation_file_names, validation_labels,
+        #                                            input_shape=[56, 46, 1],
+        #                                            batch_size=VALIDATION_BATCH_SIZE)
 
     # Preparing the architecture
     n_classes = len(train_data_shuffler.possible_labels)
-
+    n_classes = 200
     cnn = True
     if cnn:
         # LENET PAPER CHOPRA
         #architecture = Chopra(default_feature_layer="fc7")
-        architecture = Lenet(default_feature_layer="fc2", n_classes=n_classes, conv1_output=4, conv2_output=8,use_gpu=USE_GPU)
+        architecture = Lenet(default_feature_layer="fc2", n_classes=n_classes, conv1_output=8, conv2_output=16,use_gpu=USE_GPU)
         #architecture = VGG(n_classes=n_classes, use_gpu=USE_GPU)
+        #architecture = Dummy(seed=SEED)
         #architecture = LenetDropout(default_feature_layer="fc2", n_classes=n_classes, conv1_output=4, conv2_output=8, use_gpu=USE_GPU)
@@ -115,7 +130,8 @@ def main():
         trainer = SiameseTrainer(architecture=architecture,
                                  loss=loss,
                                  iterations=ITERATIONS,
-                                 snapshot=VALIDATION_TEST)
+                                 snapshot=VALIDATION_TEST,
+                                 )
         trainer.train(train_data_shuffler, validation_data_shuffler)
     else:
         mlp = MLP(n_classes, hidden_layers=[15, 20])
diff --git a/bob/learn/tensorflow/trainers/SiameseTrainer.py b/bob/learn/tensorflow/trainers/SiameseTrainer.py
index 15d79e34..c43d9990 100644
--- a/bob/learn/tensorflow/trainers/SiameseTrainer.py
+++ b/bob/learn/tensorflow/trainers/SiameseTrainer.py
@@ -12,7 +12,7 @@ from ..network import SequenceNetwork
 import bob.io.base
 from .Trainer import Trainer
 import os
-
+import sys
 
 
 class SiameseTrainer(Trainer):
@@ -64,7 +64,8 @@ class SiameseTrainer(Trainer):
         """
         Injecting data in the place holder queue
         """
-        for i in range(self.iterations):
+        #for i in range(self.iterations+5):
+        while not thread_pool.should_stop():
             batch_left, batch_right, labels = train_data_shuffler.get_pair()
 
             feed_dict = {train_placeholder_left_data: batch_left,
@@ -151,13 +152,13 @@ class SiameseTrainer(Trainer):
             self.architecture.generate_summaries()
             merged_validation = tf.merge_all_summaries()
-
-
         for step in range(self.iterations):
 
             _, l, lr, summary = session.run([optimizer, loss_train, learning_rate, merged])
             #_, l, lr= session.run([optimizer, loss_train, learning_rate])
             train_writer.add_summary(summary, step)
+            print str(step)
+            sys.stdout.flush()
 
             if validation_data_shuffler is not None and step % self.snapshot == 0:
@@ -167,7 +168,9 @@ class SiameseTrainer(Trainer):
                 summary = analizer()
                 train_writer.add_summary(summary, step)
                 print str(step)
+                sys.stdout.flush()
 
+        print("#######DONE##########")
         self.architecture.save(hdf5)
         del hdf5
         train_writer.close()
diff --git a/bob/learn/tensorflow/trainers/Trainer.py b/bob/learn/tensorflow/trainers/Trainer.py
index abf2c6cf..e66bc6ba 100644
--- a/bob/learn/tensorflow/trainers/Trainer.py
+++ b/bob/learn/tensorflow/trainers/Trainer.py
@@ -79,7 +79,8 @@ class Trainer(object):
         """
 
         #while not thread_pool.should_stop():
-        for i in range(self.iterations):
+        #for i in range(self.iterations):
+        while not thread_pool.should_stop():
             train_data, train_labels = train_data_shuffler.get_batch()
 
             feed_dict = {train_placeholder_data: train_data,
diff --git a/buildout.cfg b/buildout.cfg
index b47a6d3b..95369876 100644
--- a/buildout.cfg
+++ b/buildout.cfg
@@ -5,6 +5,7 @@
 [buildout]
 parts = scripts
 eggs = bob.learn.tensorflow
+       bob.db.casia_webface
        gridtk
 
 extensions = bob.buildout
@@ -12,6 +13,9 @@ extensions = bob.buildout
 auto-checkout = *
 develop = src/bob.db.mnist
           src/gridtk
+          src/bob.db.casia_webface
+          src/bob.db.mobio
+          src/bob.db.lfw
           .
 
 ; options for bob.buildout
@@ -21,7 +25,11 @@ newest = false
 
 [sources]
-bob.db.mnist = git git@github.com:tiagofrepereira2012/bob.db.mnist
+bob.db.mnist = git git@github.com:tiagofrepereira2012/bob.db.mnist.git
+bob.db.base = git git@gitlab.idiap.ch:bob/bob.db.base.git
+bob.db.mobio = git git@gitlab.idiap.ch:bob/bob.db.mobio.git
+bob.db.lfw = git git@gitlab.idiap.ch:bob/bob.db.lfw.git
+bob.db.casia_webface = git git@gitlab.idiap.ch:bob/bob.db.casia_webface.git
 gridtk = git git@github.com:bioidiap/gridtk
-- 
GitLab