Skip to content
Snippets Groups Projects
Commit 2d03b381 authored by Tiago de Freitas Pereira's avatar Tiago de Freitas Pereira
Browse files

Fixed synchronization bug

parent 8bc92f21
No related branches found
No related tags found
No related merge requests found
......@@ -2,7 +2,7 @@
from pkgutil import extend_path
__path__ = extend_path(__path__, __name__)
from .Analizer import Analizer
from .ExperimentAnalizer import ExperimentAnalizer
# gets sphinx autodoc done right - don't remove it
__all__ = [_ for _ in dir() if not _.startswith('_')]
......
......@@ -66,10 +66,11 @@ class BaseDataShuffler(object):
return data, labels
def get_genuine_or_not(self, input_data, input_labels, genuine=True):
if genuine:
# Getting a client
index = numpy.random.randint(len(self.possible_labels))
index = self.possible_labels[index]
index = int(self.possible_labels[index])
# Getting the indexes of the data from a particular client
indexes = numpy.where(input_labels == index)[0]
......@@ -82,8 +83,8 @@ class BaseDataShuffler(object):
else:
# Picking a pair of labels from different clients
index = numpy.random.choice(len(self.possible_labels), 2, replace=False)
index[0] = self.possible_labels[index[0]]
index[1] = self.possible_labels[index[1]]
index[0] = self.possible_labels[int(index[0])]
index[1] = self.possible_labels[int(index[1])]
# Getting the indexes of the two clients
indexes = numpy.where(input_labels == index[0])[0]
......
......@@ -89,8 +89,8 @@ def main():
# Preparing the architecture
cnn = True
if cnn:
#architecture = Lenet(seed=SEED)
architecture = Dummy(seed=SEED)
architecture = Lenet(seed=SEED)
#architecture = Dummy(seed=SEED)
loss = BaseLoss(tf.nn.sparse_softmax_cross_entropy_with_logits, tf.reduce_mean)
trainer = Trainer(architecture=architecture, loss=loss, iterations=ITERATIONS)
trainer.train(train_data_shuffler, validation_data_shuffler)
......
......@@ -23,7 +23,7 @@ import tensorflow as tf
from .. import util
SEED = 10
from bob.learn.tensorflow.data import MemoryDataShuffler, TextDataShuffler
from bob.learn.tensorflow.network import Lenet, MLP, LenetDropout, VGG, Chopra
from bob.learn.tensorflow.network import Lenet, MLP, LenetDropout, VGG, Chopra, Dummy
from bob.learn.tensorflow.trainers import SiameseTrainer
from bob.learn.tensorflow.loss import ContrastiveLoss
import numpy
......@@ -39,7 +39,7 @@ def main():
perc_train = 0.9
# Loading data
mnist = True
mnist = False
if mnist:
train_data, train_labels, validation_data, validation_labels = \
......@@ -58,55 +58,70 @@ def main():
batch_size=VALIDATION_BATCH_SIZE)
else:
import bob.db.atnt
db = bob.db.atnt.Database()
import bob.db.mobio
db_mobio = bob.db.mobio.Database()
#import bob.db.mobio
#db = bob.db.mobio.Database()
import bob.db.casia_webface
db_casia = bob.db.casia_webface.Database()
# Preparing train set
#train_objects = db.objects(protocol="male", groups="world")
train_objects = db.objects(groups="world")
train_labels = [o.client_id for o in train_objects]
#directory = "/idiap/user/tpereira/face/baselines/eigenface/preprocessed",
train_objects = db_casia.objects(groups="world")
#train_objects = db.objects(groups="world")
train_labels = [int(o.client_id) for o in train_objects]
directory = "/idiap/resource/database/CASIA-WebFace/CASIA-WebFace"
train_file_names = [o.make_path(
directory="/idiap/group/biometric/databases/orl",
extension=".pgm")
directory=directory,
extension="")
for o in train_objects]
#import ipdb;
#ipdb.set_trace();
#train_file_names = [o.make_path(
# directory="/idiap/group/biometric/databases/orl",
# extension=".pgm")
# for o in train_objects]
#train_data_shuffler = TextDataShuffler(train_file_names, train_labels,
# input_shape=[80, 64, 1],
# batch_size=BATCH_SIZE)
train_data_shuffler = TextDataShuffler(train_file_names, train_labels,
input_shape=[56, 46, 1],
input_shape=[250, 250, 3],
batch_size=BATCH_SIZE)
#train_data_shuffler = TextDataShuffler(train_file_names, train_labels,
# input_shape=[56, 46, 1],
# batch_size=BATCH_SIZE)
# Preparing train set
#validation_objects = db.objects(protocol="male", groups="dev")
validation_objects = db.objects(groups="dev")
directory = "/idiap/temp/tpereira/DEEP_FACE/CASIA/preprocessed"
validation_objects = db_mobio.objects(protocol="male", groups="dev")
validation_labels = [o.client_id for o in validation_objects]
#validation_file_names = [o.make_path(
# directory="/idiap/group/biometric/databases/orl",
# extension=".pgm")
# for o in validation_objects]
validation_file_names = [o.make_path(
directory="/idiap/group/biometric/databases/orl",
extension=".pgm")
directory=directory,
extension=".hdf5")
for o in validation_objects]
#validation_data_shuffler = TextDataShuffler(validation_file_names, validation_labels,
# input_shape=[80, 64, 1],
# batch_size=VALIDATION_BATCH_SIZE)
validation_data_shuffler = TextDataShuffler(validation_file_names, validation_labels,
input_shape=[56, 46, 1],
input_shape=[250, 250, 3],
batch_size=VALIDATION_BATCH_SIZE)
#validation_data_shuffler = TextDataShuffler(validation_file_names, validation_labels,
# input_shape=[56, 46, 1],
# batch_size=VALIDATION_BATCH_SIZE)
# Preparing the architecture
n_classes = len(train_data_shuffler.possible_labels)
n_classes = 200
cnn = True
if cnn:
# LENET PAPER CHOPRA
#architecture = Chopra(default_feature_layer="fc7")
architecture = Lenet(default_feature_layer="fc2", n_classes=n_classes, conv1_output=4, conv2_output=8,use_gpu=USE_GPU)
architecture = Lenet(default_feature_layer="fc2", n_classes=n_classes, conv1_output=8, conv2_output=16,use_gpu=USE_GPU)
#architecture = VGG(n_classes=n_classes, use_gpu=USE_GPU)
#architecture = Dummy(seed=SEED)
#architecture = LenetDropout(default_feature_layer="fc2", n_classes=n_classes, conv1_output=4, conv2_output=8, use_gpu=USE_GPU)
......@@ -115,7 +130,8 @@ def main():
trainer = SiameseTrainer(architecture=architecture,
loss=loss,
iterations=ITERATIONS,
snapshot=VALIDATION_TEST)
snapshot=VALIDATION_TEST,
)
trainer.train(train_data_shuffler, validation_data_shuffler)
else:
mlp = MLP(n_classes, hidden_layers=[15, 20])
......
......@@ -12,7 +12,7 @@ from ..network import SequenceNetwork
import bob.io.base
from .Trainer import Trainer
import os
import sys
class SiameseTrainer(Trainer):
......@@ -64,7 +64,8 @@ class SiameseTrainer(Trainer):
"""
Injecting data in the place holder queue
"""
for i in range(self.iterations):
#for i in range(self.iterations+5):
while not thread_pool.should_stop():
batch_left, batch_right, labels = train_data_shuffler.get_pair()
feed_dict = {train_placeholder_left_data: batch_left,
......@@ -151,13 +152,13 @@ class SiameseTrainer(Trainer):
self.architecture.generate_summaries()
merged_validation = tf.merge_all_summaries()
for step in range(self.iterations):
_, l, lr, summary = session.run([optimizer, loss_train, learning_rate, merged])
#_, l, lr= session.run([optimizer, loss_train, learning_rate])
train_writer.add_summary(summary, step)
print str(step)
sys.stdout.flush()
if validation_data_shuffler is not None and step % self.snapshot == 0:
......@@ -167,7 +168,9 @@ class SiameseTrainer(Trainer):
summary = analizer()
train_writer.add_summary(summary, step)
print str(step)
sys.stdout.flush()
print("#######DONE##########")
self.architecture.save(hdf5)
del hdf5
train_writer.close()
......
......@@ -79,7 +79,8 @@ class Trainer(object):
"""
#while not thread_pool.should_stop():
for i in range(self.iterations):
#for i in range(self.iterations):
while not thread_pool.should_stop():
train_data, train_labels = train_data_shuffler.get_batch()
feed_dict = {train_placeholder_data: train_data,
......
......@@ -5,6 +5,7 @@
[buildout]
parts = scripts
eggs = bob.learn.tensorflow
bob.db.casia_webface
gridtk
extensions = bob.buildout
......@@ -12,6 +13,9 @@ extensions = bob.buildout
auto-checkout = *
develop = src/bob.db.mnist
src/gridtk
src/bob.db.casia_webface
src/bob.db.mobio
src/bob.db.lfw
.
; options for bob.buildout
......@@ -21,7 +25,11 @@ newest = false
[sources]
bob.db.mnist = git git@github.com:tiagofrepereira2012/bob.db.mnist
bob.db.mnist = git git@github.com:tiagofrepereira2012/bob.db.mnist.git
bob.db.base = git git@gitlab.idiap.ch:bob/bob.db.base.git
bob.db.mobio = git git@gitlab.idiap.ch:bob/bob.db.mobio.git
bob.db.lfw = git git@gitlab.idiap.ch:bob/bob.db.lfw.git
bob.db.casia_webface = git git@gitlab.idiap.ch:bob/bob.db.casia_webface.git
gridtk = git git@github.com:bioidiap/gridtk
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment