Fixed synchronization bug
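
The feeding threads used to push data for a fixed "for i in range(self.iterations)"
loop, so they could stop enqueueing before the training loop finished (or keep
pushing after it was done), leaving producer and consumer out of sync. Both
Trainer and SiameseTrainer now feed with "while not thread_pool.should_stop()",
so the threads run exactly until the coordinator requests a stop.

A minimal sketch of the pattern, assuming the TensorFlow 0.x API this repo uses
(tf.train.Coordinator as the "thread_pool", a FIFOQueue fed from a Python
thread); make_batch() is a hypothetical stand-in for
train_data_shuffler.get_batch():

    import threading
    import numpy
    import tensorflow as tf

    def make_batch():
        # Hypothetical stand-in for the data shuffler.
        return numpy.random.rand(32, 28, 28, 1).astype("float32")

    data = tf.placeholder(tf.float32, shape=(32, 28, 28, 1))
    queue = tf.FIFOQueue(capacity=10, dtypes=[tf.float32], shapes=[(32, 28, 28, 1)])
    enqueue_op = queue.enqueue([data])

    thread_pool = tf.train.Coordinator()
    session = tf.Session()

    def feeder():
        # The fix: run until the consumer asks us to stop, not for a fixed
        # iteration count that can drift from the training loop.
        try:
            while not thread_pool.should_stop():
                session.run(enqueue_op, feed_dict={data: make_batch()})
        except tf.errors.CancelledError:
            pass  # the queue was closed while we were blocked on enqueue

    thread = threading.Thread(target=feeder)
    thread.start()

    batch = queue.dequeue()
    for step in range(100):  # the training loop consumes batches
        session.run(batch)

    thread_pool.request_stop()  # tell the feeder to finish
    session.run(queue.close(cancel_pending_enqueues=True))  # unblock a full queue
    thread_pool.join([thread])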

parent 8bc92f21
@@ -2,7 +2,7 @@
 from pkgutil import extend_path
 __path__ = extend_path(__path__, __name__)
-from .Analizer import Analizer
+from .ExperimentAnalizer import ExperimentAnalizer
 
 # gets sphinx autodoc done right - don't remove it
 __all__ = [_ for _ in dir() if not _.startswith('_')]
...
@@ -66,10 +66,11 @@ class BaseDataShuffler(object):
         return data, labels
 
     def get_genuine_or_not(self, input_data, input_labels, genuine=True):
         if genuine:
             # Getting a client
             index = numpy.random.randint(len(self.possible_labels))
-            index = self.possible_labels[index]
+            index = int(self.possible_labels[index])
 
             # Getting the indexes of the data from a particular client
             indexes = numpy.where(input_labels == index)[0]
@@ -82,8 +83,8 @@ class BaseDataShuffler(object):
         else:
             # Picking a pair of labels from different clients
             index = numpy.random.choice(len(self.possible_labels), 2, replace=False)
-            index[0] = self.possible_labels[index[0]]
-            index[1] = self.possible_labels[index[1]]
+            index[0] = self.possible_labels[int(index[0])]
+            index[1] = self.possible_labels[int(index[1])]
 
             # Getting the indexes of the two clients
             indexes = numpy.where(input_labels == index[0])[0]
...
@@ -89,8 +89,8 @@ def main():
     # Preparing the architecture
     cnn = True
     if cnn:
-        #architecture = Lenet(seed=SEED)
-        architecture = Dummy(seed=SEED)
+        architecture = Lenet(seed=SEED)
+        #architecture = Dummy(seed=SEED)
         loss = BaseLoss(tf.nn.sparse_softmax_cross_entropy_with_logits, tf.reduce_mean)
         trainer = Trainer(architecture=architecture, loss=loss, iterations=ITERATIONS)
         trainer.train(train_data_shuffler, validation_data_shuffler)
...
@@ -23,7 +23,7 @@ import tensorflow as tf
 from .. import util
 SEED = 10
 from bob.learn.tensorflow.data import MemoryDataShuffler, TextDataShuffler
-from bob.learn.tensorflow.network import Lenet, MLP, LenetDropout, VGG, Chopra
+from bob.learn.tensorflow.network import Lenet, MLP, LenetDropout, VGG, Chopra, Dummy
 from bob.learn.tensorflow.trainers import SiameseTrainer
 from bob.learn.tensorflow.loss import ContrastiveLoss
 import numpy
@@ -39,7 +39,7 @@ def main():
     perc_train = 0.9
 
     # Loading data
-    mnist = True
+    mnist = False
 
     if mnist:
         train_data, train_labels, validation_data, validation_labels = \
@@ -58,55 +58,70 @@ def main():
                                                      batch_size=VALIDATION_BATCH_SIZE)
     else:
-        import bob.db.atnt
-        db = bob.db.atnt.Database()
-        #import bob.db.mobio
-        #db = bob.db.mobio.Database()
+        import bob.db.mobio
+        db_mobio = bob.db.mobio.Database()
+        import bob.db.casia_webface
+        db_casia = bob.db.casia_webface.Database()
 
         # Preparing train set
-        #train_objects = db.objects(protocol="male", groups="world")
-        train_objects = db.objects(groups="world")
-        train_labels = [o.client_id for o in train_objects]
-        #directory = "/idiap/user/tpereira/face/baselines/eigenface/preprocessed",
+        train_objects = db_casia.objects(groups="world")
+        #train_objects = db.objects(groups="world")
+        train_labels = [int(o.client_id) for o in train_objects]
+        directory = "/idiap/resource/database/CASIA-WebFace/CASIA-WebFace"
         train_file_names = [o.make_path(
-            directory="/idiap/group/biometric/databases/orl",
-            extension=".pgm")
+            directory=directory,
+            extension="")
         for o in train_objects]
+        #import ipdb;
+        #ipdb.set_trace();
+
+        #train_file_names = [o.make_path(
+        #    directory="/idiap/group/biometric/databases/orl",
+        #    extension=".pgm")
+        #for o in train_objects]
+
+        #train_data_shuffler = TextDataShuffler(train_file_names, train_labels,
+        #                                       input_shape=[80, 64, 1],
+        #                                       batch_size=BATCH_SIZE)
 
         train_data_shuffler = TextDataShuffler(train_file_names, train_labels,
-                                               input_shape=[56, 46, 1],
+                                               input_shape=[250, 250, 3],
                                                batch_size=BATCH_SIZE)
+        #train_data_shuffler = TextDataShuffler(train_file_names, train_labels,
+        #                                       input_shape=[56, 46, 1],
+        #                                       batch_size=BATCH_SIZE)
 
         # Preparing train set
-        #validation_objects = db.objects(protocol="male", groups="dev")
-        validation_objects = db.objects(groups="dev")
+        directory = "/idiap/temp/tpereira/DEEP_FACE/CASIA/preprocessed"
+        validation_objects = db_mobio.objects(protocol="male", groups="dev")
         validation_labels = [o.client_id for o in validation_objects]
 
+        #validation_file_names = [o.make_path(
+        #    directory="/idiap/group/biometric/databases/orl",
+        #    extension=".pgm")
+        #for o in validation_objects]
         validation_file_names = [o.make_path(
-            directory="/idiap/group/biometric/databases/orl",
-            extension=".pgm")
+            directory=directory,
+            extension=".hdf5")
         for o in validation_objects]
 
+        #validation_data_shuffler = TextDataShuffler(validation_file_names, validation_labels,
+        #                                            input_shape=[80, 64, 1],
+        #                                            batch_size=VALIDATION_BATCH_SIZE)
         validation_data_shuffler = TextDataShuffler(validation_file_names, validation_labels,
-                                                    input_shape=[56, 46, 1],
+                                                    input_shape=[250, 250, 3],
                                                     batch_size=VALIDATION_BATCH_SIZE)
+        #validation_data_shuffler = TextDataShuffler(validation_file_names, validation_labels,
+        #                                            input_shape=[56, 46, 1],
+        #                                            batch_size=VALIDATION_BATCH_SIZE)
 
     # Preparing the architecture
     n_classes = len(train_data_shuffler.possible_labels)
+    n_classes = 200
     cnn = True
     if cnn:
         # LENET PAPER CHOPRA
         #architecture = Chopra(default_feature_layer="fc7")
-        architecture = Lenet(default_feature_layer="fc2", n_classes=n_classes, conv1_output=4, conv2_output=8,use_gpu=USE_GPU)
+        architecture = Lenet(default_feature_layer="fc2", n_classes=n_classes, conv1_output=8, conv2_output=16,use_gpu=USE_GPU)
         #architecture = VGG(n_classes=n_classes, use_gpu=USE_GPU)
+        #architecture = Dummy(seed=SEED)
         #architecture = LenetDropout(default_feature_layer="fc2", n_classes=n_classes, conv1_output=4, conv2_output=8, use_gpu=USE_GPU)
@@ -115,7 +130,8 @@ def main():
         trainer = SiameseTrainer(architecture=architecture,
                                  loss=loss,
                                  iterations=ITERATIONS,
-                                 snapshot=VALIDATION_TEST)
+                                 snapshot=VALIDATION_TEST,
+                                 )
         trainer.train(train_data_shuffler, validation_data_shuffler)
     else:
         mlp = MLP(n_classes, hidden_layers=[15, 20])
...
@@ -12,7 +12,7 @@ from ..network import SequenceNetwork
 import bob.io.base
 from .Trainer import Trainer
 import os
+import sys
 
 class SiameseTrainer(Trainer):
@@ -64,7 +64,8 @@ class SiameseTrainer(Trainer):
         """
         Injecting data in the place holder queue
         """
-        for i in range(self.iterations):
+        #for i in range(self.iterations+5):
+        while not thread_pool.should_stop():
             batch_left, batch_right, labels = train_data_shuffler.get_pair()
 
             feed_dict = {train_placeholder_left_data: batch_left,
@@ -151,13 +152,13 @@ class SiameseTrainer(Trainer):
             self.architecture.generate_summaries()
             merged_validation = tf.merge_all_summaries()
 
         for step in range(self.iterations):
             _, l, lr, summary = session.run([optimizer, loss_train, learning_rate, merged])
             #_, l, lr= session.run([optimizer, loss_train, learning_rate])
             train_writer.add_summary(summary, step)
+            print str(step)
+            sys.stdout.flush()
 
             if validation_data_shuffler is not None and step % self.snapshot == 0:
@@ -167,7 +168,9 @@ class SiameseTrainer(Trainer):
                 summary = analizer()
                 train_writer.add_summary(summary, step)
                 print str(step)
+                sys.stdout.flush()
 
+        print("#######DONE##########")
         self.architecture.save(hdf5)
         del hdf5
         train_writer.close()
...
@@ -79,7 +79,8 @@ class Trainer(object):
         """
         #while not thread_pool.should_stop():
-        for i in range(self.iterations):
+        #for i in range(self.iterations):
+        while not thread_pool.should_stop():
             train_data, train_labels = train_data_shuffler.get_batch()
 
             feed_dict = {train_placeholder_data: train_data,
...
@@ -5,6 +5,7 @@
 [buildout]
 parts = scripts
 eggs = bob.learn.tensorflow
+       bob.db.casia_webface
        gridtk
 
 extensions = bob.buildout
@@ -12,6 +13,9 @@ extensions = bob.buildout
 auto-checkout = *
 develop = src/bob.db.mnist
           src/gridtk
+          src/bob.db.casia_webface
+          src/bob.db.mobio
+          src/bob.db.lfw
           .
 
 ; options for bob.buildout
@@ -21,7 +25,11 @@ newest = false
 [sources]
-bob.db.mnist = git git@github.com:tiagofrepereira2012/bob.db.mnist
+bob.db.mnist = git git@github.com:tiagofrepereira2012/bob.db.mnist.git
+bob.db.base = git git@gitlab.idiap.ch:bob/bob.db.base.git
+bob.db.mobio = git git@gitlab.idiap.ch:bob/bob.db.mobio.git
+bob.db.lfw = git git@gitlab.idiap.ch:bob/bob.db.lfw.git
+bob.db.casia_webface = git git@gitlab.idiap.ch:bob/bob.db.casia_webface.git
 gridtk = git git@github.com:bioidiap/gridtk
...