From 89da1f7080cb0caca357ba35f82d583dfee87420 Mon Sep 17 00:00:00 2001 From: Tiago Pereira <tiago.pereira@partner.samsung.com> Date: Mon, 10 Jul 2017 13:38:14 -0700 Subject: [PATCH] [trainers] Fixed the SiameseTrainer validation --- .../tensorflow/trainers/SiameseTrainer.py | 24 +++++++++++++++++++ bob/learn/tensorflow/trainers/Trainer.py | 4 ++++ 2 files changed, 28 insertions(+) diff --git a/bob/learn/tensorflow/trainers/SiameseTrainer.py b/bob/learn/tensorflow/trainers/SiameseTrainer.py index ff69ec1b..1bb2438f 100644 --- a/bob/learn/tensorflow/trainers/SiameseTrainer.py +++ b/bob/learn/tensorflow/trainers/SiameseTrainer.py @@ -77,6 +77,7 @@ class SiameseTrainer(Trainer): # Validation data self.validation_summary_writter = None + self.summaries_validation = None # Analizer self.analizer = analizer @@ -156,6 +157,9 @@ class SiameseTrainer(Trainer): self.summaries_train = self.create_general_summary() tf.add_to_collection("summaries_train", self.summaries_train) + self.summaries_validation = self.create_general_summary() + tf.add_to_collection("summaries_validation", self.summaries_validation) + # Creating the variables tf.global_variables_initializer().run(session=self.session) @@ -219,3 +223,23 @@ class SiameseTrainer(Trainer): tf.summary.scalar('within_class_loss', self.predictor['within_class']) tf.summary.scalar('lr', self.learning_rate) return tf.summary.merge_all() + + def compute_validation(self, data_shuffler, step): + """ + Computes the loss in the validation set + + ** Parameters ** + session: Tensorflow session + data_shuffler: The data shuffler to be used + step: Iteration number + + """ + # Opening a new session for validation + feed_dict = self.get_feed_dict(data_shuffler) + + l, summary = self.session.run([self.predictor, self.summaries_validation], feed_dict=feed_dict) + self.validation_summary_writter.add_summary(summary, step) + + #summaries = [summary_pb2.Summary.Value(tag="loss", simple_value=float(l))] + #self.validation_summary_writter.add_summary(summary_pb2.Summary(value=summaries), step) + logger.info("Loss VALIDATION set step={0} = {1}".format(step, l)) diff --git a/bob/learn/tensorflow/trainers/Trainer.py b/bob/learn/tensorflow/trainers/Trainer.py index 34ad7749..f7999e1c 100644 --- a/bob/learn/tensorflow/trainers/Trainer.py +++ b/bob/learn/tensorflow/trainers/Trainer.py @@ -81,6 +81,7 @@ class Trainer(object): # Validation data self.validation_summary_writter = None + self.summaries_validation = None # Analizer self.analizer = analizer @@ -160,6 +161,8 @@ class Trainer(object): self.summaries_train = self.create_general_summary() tf.add_to_collection("summaries_train", self.summaries_train) + self.summaries_validation = self.create_general_summary() + self.summaries_validation = tf.add_to_collection("summaries_validation", self.summaries_validation) # Creating the variables tf.global_variables_initializer().run(session=self.session) @@ -186,6 +189,7 @@ class Trainer(object): self.optimizer = tf.get_collection("optimizer")[0] self.learning_rate = tf.get_collection("learning_rate")[0] self.summaries_train = tf.get_collection("summaries_train")[0] + self.summaries_validation = tf.get_collection("summaries_validation")[0] self.global_step = tf.get_collection("global_step")[0] self.from_scratch = False -- GitLab