Commit 89da1f70 authored by Tiago Pereira's avatar Tiago Pereira
Browse files

[trainers] Fixed the SiameseTrainer validation

parent 66948f20
Pipeline #11232 passed with stages
in 11 minutes and 48 seconds
......@@ -77,6 +77,7 @@ class SiameseTrainer(Trainer):
# Validation data
self.validation_summary_writter = None
self.summaries_validation = None
# Analizer
self.analizer = analizer
......@@ -156,6 +157,9 @@ class SiameseTrainer(Trainer):
self.summaries_train = self.create_general_summary()
tf.add_to_collection("summaries_train", self.summaries_train)
self.summaries_validation = self.create_general_summary()
tf.add_to_collection("summaries_validation", self.summaries_validation)
# Creating the variables
tf.global_variables_initializer().run(session=self.session)
......@@ -219,3 +223,23 @@ class SiameseTrainer(Trainer):
tf.summary.scalar('within_class_loss', self.predictor['within_class'])
tf.summary.scalar('lr', self.learning_rate)
return tf.summary.merge_all()
def compute_validation(self, data_shuffler, step):
"""
Computes the loss in the validation set
** Parameters **
session: Tensorflow session
data_shuffler: The data shuffler to be used
step: Iteration number
"""
# Opening a new session for validation
feed_dict = self.get_feed_dict(data_shuffler)
l, summary = self.session.run([self.predictor, self.summaries_validation], feed_dict=feed_dict)
self.validation_summary_writter.add_summary(summary, step)
#summaries = [summary_pb2.Summary.Value(tag="loss", simple_value=float(l))]
#self.validation_summary_writter.add_summary(summary_pb2.Summary(value=summaries), step)
logger.info("Loss VALIDATION set step={0} = {1}".format(step, l))
......@@ -81,6 +81,7 @@ class Trainer(object):
# Validation data
self.validation_summary_writter = None
self.summaries_validation = None
# Analizer
self.analizer = analizer
......@@ -160,6 +161,8 @@ class Trainer(object):
self.summaries_train = self.create_general_summary()
tf.add_to_collection("summaries_train", self.summaries_train)
self.summaries_validation = self.create_general_summary()
self.summaries_validation = tf.add_to_collection("summaries_validation", self.summaries_validation)
# Creating the variables
tf.global_variables_initializer().run(session=self.session)
......@@ -186,6 +189,7 @@ class Trainer(object):
self.optimizer = tf.get_collection("optimizer")[0]
self.learning_rate = tf.get_collection("learning_rate")[0]
self.summaries_train = tf.get_collection("summaries_train")[0]
self.summaries_validation = tf.get_collection("summaries_validation")[0]
self.global_step = tf.get_collection("global_step")[0]
self.from_scratch = False
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment