Commit 754ef62d authored by Tiago Pereira's avatar Tiago Pereira
Browse files

Fixed serialization issues

parent 9950b499
......@@ -51,4 +51,9 @@ class ContrastiveLoss(BaseLoss):
loss = 0.5 * (within_class + between_class)
return tf.reduce_mean(loss), tf.reduce_mean(between_class), tf.reduce_mean(within_class)
loss_dict = dict()
loss_dict['loss'] = tf.reduce_mean(loss)
loss_dict['between_class'] = tf.reduce_mean(between_class)
loss_dict['within_class'] = tf.reduce_mean(within_class)
return loss_dict
......@@ -54,4 +54,9 @@ class TripletLoss(BaseLoss):
basic_loss = tf.add(tf.subtract(d_positive, d_negative), self.margin)
loss = tf.reduce_mean(tf.maximum(basic_loss, 0.0), 0)
return loss, tf.reduce_mean(d_negative), tf.reduce_mean(d_positive)
loss_dict = dict()
loss_dict['loss'] = loss
loss_dict['between_class'] = tf.reduce_mean(d_negative)
loss_dict['within_class'] = tf.reduce_mean(d_positive)
return loss_dict
......@@ -163,7 +163,6 @@ def test_siamesecnn_trainer():
trainer.train(train_data_shuffler)
embedding = Embedding(train_data_shuffler("data", from_queue=False)['left'], graph['left'])
eer = dummy_experiment(validation_data_shuffler, embedding)
assert eer < 0.15
shutil.rmtree(directory)
......
......@@ -145,18 +145,25 @@ class SiameseTrainer(Trainer):
# Saving all the variables
self.saver = tf.train.Saver(var_list=tf.global_variables())
tf.add_to_collection("global_step", self.global_step)
tf.add_to_collection("graph", self.graph)
tf.add_to_collection("predictor", self.predictor)
# Saving the pointers to the graph
tf.add_to_collection("graph_left", self.graph['left'])
tf.add_to_collection("graph_right", self.graph['right'])
# Saving pointers to the loss
tf.add_to_collection("predictor_loss", self.predictor['loss'])
tf.add_to_collection("predictor_between_class_loss", self.predictor['between_class'])
tf.add_to_collection("predictor_within_class_loss", self.predictor['within_class'])
tf.add_to_collection("data_ph", self.data_ph)
# Saving the pointers to the placeholders
tf.add_to_collection("data_ph_left", self.data_ph['left'])
tf.add_to_collection("data_ph_right", self.data_ph['right'])
tf.add_to_collection("label_ph", self.label_ph)
# Preparing the optimizer
self.optimizer_class._learning_rate = self.learning_rate
self.optimizer = self.optimizer_class.minimize(self.predictor[0], global_step=self.global_step)
self.optimizer = self.optimizer_class.minimize(self.predictor['loss'], global_step=self.global_step)
tf.add_to_collection("optimizer", self.optimizer)
tf.add_to_collection("learning_rate", self.learning_rate)
......@@ -166,6 +173,44 @@ class SiameseTrainer(Trainer):
# Creating the variables
tf.global_variables_initializer().run(session=self.session)
def create_network_from_file(self, model_from_file):
    """
    Bootstrap the trainer state from a previously saved model.

    Imports the saved meta-graph, restores the trained variables into
    ``self.session``, and rebuilds all trainer pointers (graph outputs,
    placeholders, loss terms, optimizer, learning rate, summaries,
    global step) from the named collections written at save time.

    **Parameters**

    model_from_file: Checkpoint path prefix; the meta-graph is expected
        at ``model_from_file + ".meta"``.
    """
    # Import the graph structure, then restore the variable values.
    self.saver = tf.train.import_meta_graph(model_from_file + ".meta")
    self.saver.restore(self.session, model_from_file)

    # Loading the graph from the graph pointers
    self.graph = dict()
    self.graph['left'] = tf.get_collection("graph_left")[0]
    self.graph['right'] = tf.get_collection("graph_right")[0]

    # Loading the place holders by the pointer
    self.data_ph = dict()
    self.data_ph['left'] = tf.get_collection("data_ph_left")[0]
    self.data_ph['right'] = tf.get_collection("data_ph_right")[0]
    self.label_ph = tf.get_collection("label_ph")[0]

    # Rebuild the predictor dict from the individual loss pointers.
    # A dict does not serialize through a single TF collection, so the
    # save path stores each tensor separately; mirror that here so
    # callers can keep indexing self.predictor['loss'] etc.
    self.predictor = dict()
    self.predictor['loss'] = tf.get_collection("predictor_loss")[0]
    self.predictor['between_class'] = tf.get_collection("predictor_between_class_loss")[0]
    self.predictor['within_class'] = tf.get_collection("predictor_within_class_loss")[0]

    # Loading other elements
    self.optimizer = tf.get_collection("optimizer")[0]
    self.learning_rate = tf.get_collection("learning_rate")[0]
    self.summaries_train = tf.get_collection("summaries_train")[0]
    self.global_step = tf.get_collection("global_step")[0]

    # Signal that training resumes from this checkpoint, not from scratch.
    self.from_scratch = False
def get_feed_dict(self, data_shuffler):
"""
Given a data shuffler prepared the dictionary to be injected in the graph
......@@ -194,8 +239,8 @@ class SiameseTrainer(Trainer):
feed_dict = self.get_feed_dict(self.train_data_shuffler)
_, l, bt_class, wt_class, lr, summary = self.session.run([
self.optimizer,
self.predictor[0], self.predictor[1],
self.predictor[2],
self.predictor['loss'], self.predictor['between_class'],
self.predictor['within_class'],
self.learning_rate, self.summaries_train], feed_dict=feed_dict)
logger.info("Loss training set step={0} = {1}".format(step, l))
......@@ -207,8 +252,8 @@ class SiameseTrainer(Trainer):
"""
# Train summary
tf.summary.scalar('loss', self.predictor[0])
tf.summary.scalar('between_class_loss', self.predictor[1])
tf.summary.scalar('within_class_loss', self.predictor[2])
tf.summary.scalar('loss', self.predictor['loss'])
tf.summary.scalar('between_class_loss', self.predictor['between_class'])
tf.summary.scalar('within_class_loss', self.predictor['within_class'])
tf.summary.scalar('lr', self.learning_rate)
return tf.summary.merge_all()
......@@ -165,14 +165,24 @@ class TripletTrainer(Trainer):
tf.add_to_collection("global_step", self.global_step)
tf.add_to_collection("graph", self.graph)
tf.add_to_collection("predictor", self.predictor)
# Saving the pointers to the graph
tf.add_to_collection("graph_anchor", self.graph['anchor'])
tf.add_to_collection("graph_positive", self.graph['positive'])
tf.add_to_collection("graph_negative", self.graph['negative'])
tf.add_to_collection("data_ph", self.data_ph)
# Saving pointers to the loss
tf.add_to_collection("predictor_loss", self.predictor['loss'])
tf.add_to_collection("predictor_between_class_loss", self.predictor['between_class'])
tf.add_to_collection("predictor_within_class_loss", self.predictor['within_class'])
# Saving the pointers to the placeholders
tf.add_to_collection("data_ph_anchor", self.data_ph['anchor'])
tf.add_to_collection("data_ph_positive", self.data_ph['positive'])
tf.add_to_collection("data_ph_negative", self.data_ph['negative'])
# Preparing the optimizer
self.optimizer_class._learning_rate = self.learning_rate
self.optimizer = self.optimizer_class.minimize(self.predictor[0], global_step=self.global_step)
self.optimizer = self.optimizer_class.minimize(self.predictor['loss'], global_step=self.global_step)
tf.add_to_collection("optimizer", self.optimizer)
tf.add_to_collection("learning_rate", self.learning_rate)
......@@ -210,8 +220,8 @@ class TripletTrainer(Trainer):
feed_dict = self.get_feed_dict(self.train_data_shuffler)
_, l, bt_class, wt_class, lr, summary = self.session.run([
self.optimizer,
self.predictor[0], self.predictor[1],
self.predictor[2],
self.predictor['loss'], self.predictor['between_class'],
self.predictor['within_class'],
self.learning_rate, self.summaries_train], feed_dict=feed_dict)
logger.info("Loss training set step={0} = {1}".format(step, l))
......@@ -224,9 +234,9 @@ class TripletTrainer(Trainer):
"""
# Train summary
tf.summary.scalar('loss', self.predictor[0])
tf.summary.scalar('between_class_loss', self.predictor[1])
tf.summary.scalar('within_class_loss', self.predictor[2])
tf.summary.scalar('loss', self.predictor['loss'])
tf.summary.scalar('between_class_loss', self.predictor['between_class'])
tf.summary.scalar('within_class_loss', self.predictor['within_class'])
tf.summary.scalar('lr', self.learning_rate)
return tf.summary.merge_all()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment