From 89da1f7080cb0caca357ba35f82d583dfee87420 Mon Sep 17 00:00:00 2001
From: Tiago Pereira <tiago.pereira@partner.samsung.com>
Date: Mon, 10 Jul 2017 13:38:14 -0700
Subject: [PATCH] [trainers] Fixed the SiameseTrainer validation

---
 .../tensorflow/trainers/SiameseTrainer.py     | 24 +++++++++++++++++++
 bob/learn/tensorflow/trainers/Trainer.py      |  4 ++++
 2 files changed, 28 insertions(+)

diff --git a/bob/learn/tensorflow/trainers/SiameseTrainer.py b/bob/learn/tensorflow/trainers/SiameseTrainer.py
index ff69ec1b..1bb2438f 100644
--- a/bob/learn/tensorflow/trainers/SiameseTrainer.py
+++ b/bob/learn/tensorflow/trainers/SiameseTrainer.py
@@ -77,6 +77,7 @@ class SiameseTrainer(Trainer):
 
         # Validation data
         self.validation_summary_writter = None
+        self.summaries_validation = None
 
         # Analizer
         self.analizer = analizer
@@ -156,6 +157,9 @@ class SiameseTrainer(Trainer):
         self.summaries_train = self.create_general_summary()
         tf.add_to_collection("summaries_train", self.summaries_train)
 
+        self.summaries_validation = self.create_general_summary()
+        tf.add_to_collection("summaries_validation", self.summaries_validation)
+
         # Creating the variables
         tf.global_variables_initializer().run(session=self.session)
 
@@ -219,3 +223,23 @@ class SiameseTrainer(Trainer):
         tf.summary.scalar('within_class_loss', self.predictor['within_class'])
         tf.summary.scalar('lr', self.learning_rate)
         return tf.summary.merge_all()
+
+    def compute_validation(self, data_shuffler, step):
+        """
+        Computes the loss in the validation set
+
+        ** Parameters **
+            session: Tensorflow session
+            data_shuffler: The data shuffler to be used
+            step: Iteration number
+
+        """
+        # Opening a new session for validation
+        feed_dict = self.get_feed_dict(data_shuffler)
+
+        l, summary = self.session.run([self.predictor, self.summaries_validation], feed_dict=feed_dict)
+        self.validation_summary_writter.add_summary(summary, step)
+
+        #summaries = [summary_pb2.Summary.Value(tag="loss", simple_value=float(l))]
+        #self.validation_summary_writter.add_summary(summary_pb2.Summary(value=summaries), step)
+        logger.info("Loss VALIDATION set step={0} = {1}".format(step, l))
diff --git a/bob/learn/tensorflow/trainers/Trainer.py b/bob/learn/tensorflow/trainers/Trainer.py
index 34ad7749..f7999e1c 100644
--- a/bob/learn/tensorflow/trainers/Trainer.py
+++ b/bob/learn/tensorflow/trainers/Trainer.py
@@ -81,6 +81,7 @@ class Trainer(object):
 
         # Validation data
         self.validation_summary_writter = None
+        self.summaries_validation = None
 
         # Analizer
         self.analizer = analizer
@@ -160,6 +161,8 @@ class Trainer(object):
         self.summaries_train = self.create_general_summary()
         tf.add_to_collection("summaries_train", self.summaries_train)
 
+        self.summaries_validation = self.create_general_summary()
+        self.summaries_validation = tf.add_to_collection("summaries_validation", self.summaries_validation)
 
         # Creating the variables
         tf.global_variables_initializer().run(session=self.session)
@@ -186,6 +189,7 @@ class Trainer(object):
         self.optimizer = tf.get_collection("optimizer")[0]
         self.learning_rate = tf.get_collection("learning_rate")[0]
         self.summaries_train = tf.get_collection("summaries_train")[0]
+        self.summaries_validation = tf.get_collection("summaries_validation")[0]
         self.global_step = tf.get_collection("global_step")[0]
         self.from_scratch = False
 
-- 
GitLab