From 615b9059cd37a65048bb3461fa60b462f1edf30b Mon Sep 17 00:00:00 2001
From: Tiago Freitas Pereira <tiagofrepereira@gmail.com>
Date: Fri, 28 Oct 2016 10:37:28 +0200
Subject: [PATCH] Implemented a trainer that loads a pretrained network

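This patch adds support for starting a training session from a
pretrained network:

* SequenceNetwork gains a load_variables_only() method that restores
  the saved variables into an already built graph; load() now reuses it
  and accepts a use_gpu flag.
* Trainer, SiameseTrainer and TripletTrainer gain a model_from_file
  option; when it is set, the variables are restored from that HDF5
  file before the training loop starts.
* The scratch CNN test is split into reusable scratch_network() and
  validate_network() helpers.

A minimal usage sketch follows (the import and model paths are
assumptions based on the repository layout; `scratch`, `loss` and
`train_data_shuffler` are built as in
bob/learn/tensorflow/test/test_cnn_scratch.py):

    from bob.learn.tensorflow.trainers import Trainer

    trainer = Trainer(architecture=scratch,
                      loss=loss,
                      iterations=50,
                      analizer=None,
                      prefetch=False,
                      temp_dir="./temp/cnn",
                      # new option: restore a previously saved network
                      # before the training loop starts
                      model_from_file="./pretrained/mnist_cnn.hdf5")
    trainer.train(train_data_shuffler)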
---
 .../tensorflow/network/SequenceNetwork.py     | 41 +++++++-----
 bob/learn/tensorflow/test/test_cnn_scratch.py | 67 +++++++++++--------
 .../tensorflow/trainers/SiameseTrainer.py     |  4 ++
 bob/learn/tensorflow/trainers/Trainer.py      | 11 +++
 .../tensorflow/trainers/TripletTrainer.py     |  3 +
 5 files changed, 80 insertions(+), 46 deletions(-)

diff --git a/bob/learn/tensorflow/network/SequenceNetwork.py b/bob/learn/tensorflow/network/SequenceNetwork.py
index 8c7ecf95..595a9283 100644
--- a/bob/learn/tensorflow/network/SequenceNetwork.py
+++ b/bob/learn/tensorflow/network/SequenceNetwork.py
@@ -225,17 +225,32 @@ class SequenceNetwork(six.with_metaclass(abc.ABCMeta, object)):
             self.sequence_net[k].weights_initialization.use_gpu = state
             self.sequence_net[k].bias_initialization.use_gpu = state
 
-    def load(self, hdf5, shape=None, session=None, batch=1):
+    def load_variables_only(self, hdf5, session):
         """
-        Load the network
+        Load only the variables of the model (the graph must already be built)
+        """
+        hdf5.cd('/tensor_flow')
+        for k in self.sequence_net:
+            # TODO: checking each layer's type here is not ideal
+            if not isinstance(self.sequence_net[k], MaxPooling):
+                self.sequence_net[k].W.assign(hdf5.read(self.sequence_net[k].W.name)).eval(session=session)
+                session.run(self.sequence_net[k].W)
+                self.sequence_net[k].b.assign(hdf5.read(self.sequence_net[k].b.name)).eval(session=session)
+                session.run(self.sequence_net[k].b)
+        hdf5.cd("..")
+
+    def load(self, hdf5, shape=None, session=None, batch=1, use_gpu=False):
+        """
+        Load the network, building the graph from scratch
+        before assigning the saved variables.
 
         **Parameters**
 
             hdf5: The saved network in the :py:class:`bob.io.base.HDF5File` format
-
-            shape: Input shape of the network
-
-            session: tensorflow `session <https://www.tensorflow.org/versions/r0.11/api_docs/python/client.html#Session>`_
+            shape: Input shape of the network. If `None`, the stored deployment shape is used
+            session: An opened tensorflow `session <https://www.tensorflow.org/versions/r0.11/api_docs/python/client.html#Session>`_. If `None`, a new one will be opened
+            batch: The size of the batch
+            use_gpu: Place all the variables on the GPU?
         """
 
         if session is None:
@@ -249,7 +264,7 @@ class SequenceNetwork(six.with_metaclass(abc.ABCMeta, object)):
         self.sequence_net = pickle.loads(hdf5.read('architecture'))
         self.deployment_shape = hdf5.read('deployment_shape')
 
-        self.turn_gpu_onoff(False)
+        self.turn_gpu_onoff(use_gpu)
 
         if shape is None:
             shape = self.deployment_shape
@@ -259,17 +274,7 @@ class SequenceNetwork(six.with_metaclass(abc.ABCMeta, object)):
         place_holder = tf.placeholder(tf.float32, shape=shape, name="load")
         self.compute_graph(place_holder)
         tf.initialize_all_variables().run(session=session)
-
-        hdf5.cd('/tensor_flow')
-        for k in self.sequence_net:
-            # TODO: IT IS NOT SMART TESTING ALONG THIS PAGE
-            if not isinstance(self.sequence_net[k], MaxPooling):
-                #self.sequence_net[k].W.assign(hdf5.read(self.sequence_net[k].W.name))
-                self.sequence_net[k].W.assign(hdf5.read(self.sequence_net[k].W.name)).eval(session=session)
-                session.run(self.sequence_net[k].W)
-                self.sequence_net[k].b.assign(hdf5.read(self.sequence_net[k].b.name)).eval(session=session)
-                session.run(self.sequence_net[k].b)
-
+        self.load_variables_only(hdf5, session)
 
     """
     def save(self, session, path, step=None):
diff --git a/bob/learn/tensorflow/test/test_cnn_scratch.py b/bob/learn/tensorflow/test/test_cnn_scratch.py
index 469c3def..66b813f8 100644
--- a/bob/learn/tensorflow/test/test_cnn_scratch.py
+++ b/bob/learn/tensorflow/test/test_cnn_scratch.py
@@ -26,20 +26,7 @@ iterations = 50
 seed = 10
 
 
-def test_cnn_trainer_scratch():
-    train_data, train_labels, validation_data, validation_labels = load_mnist()
-    train_data = numpy.reshape(train_data, (train_data.shape[0], 28, 28, 1))
-
-    # Creating datashufflers
-    data_augmentation = ImageAugmentation()
-    train_data_shuffler = Memory(train_data, train_labels,
-                                 input_shape=[28, 28, 1],
-                                 batch_size=batch_size,
-                                 data_augmentation=data_augmentation)
-    validation_data = numpy.reshape(validation_data, (validation_data.shape[0], 28, 28, 1))
-
-    directory = "./temp/cnn"
-
+def scratch_network():
     # Creating a random network
     scratch = SequenceNetwork()
     scratch.add(Conv2D(name="conv1", kernel_size=3,
@@ -51,20 +38,11 @@ def test_cnn_trainer_scratch():
                                activation=None,
                                weights_initialization=Xavier(seed=seed, use_gpu=False),
                                bias_initialization=Constant(use_gpu=False)))
-    # Loss for the softmax
-    loss = BaseLoss(tf.nn.sparse_softmax_cross_entropy_with_logits, tf.reduce_mean)
 
-    # One graph trainer
-    trainer = Trainer(architecture=scratch,
-                      loss=loss,
-                      iterations=iterations,
-                      analizer=None,
-                      prefetch=False,
-                      temp_dir=directory)
-    trainer.train(train_data_shuffler)
+    return scratch
 
-    del scratch
 
+def validate_network(validation_data, validation_labels, directory):
     # Testing
     validation_data_shuffler = Memory(validation_data, validation_labels,
                                       input_shape=[28, 28, 1],
@@ -79,7 +57,40 @@ def test_cnn_trainer_scratch():
         predictions = scratch(data, session=session)
         accuracy = 100. * numpy.sum(numpy.argmax(predictions, 1) == labels) / predictions.shape[0]
 
-        # At least 80% of accuracy
-        assert accuracy > 80.
-        shutil.rmtree(directory)
+    return accuracy
+
+
+def test_cnn_trainer_scratch():
+    train_data, train_labels, validation_data, validation_labels = load_mnist()
+    train_data = numpy.reshape(train_data, (train_data.shape[0], 28, 28, 1))
+
+    # Creating datashufflers
+    data_augmentation = ImageAugmentation()
+    train_data_shuffler = Memory(train_data, train_labels,
+                                 input_shape=[28, 28, 1],
+                                 batch_size=batch_size,
+                                 data_augmentation=data_augmentation)
+    validation_data = numpy.reshape(validation_data, (validation_data.shape[0], 28, 28, 1))
+
+    directory = "./temp/cnn"
+
+    # Create scratch network
+    scratch = scratch_network()
+
+    # Loss for the softmax
+    loss = BaseLoss(tf.nn.sparse_softmax_cross_entropy_with_logits, tf.reduce_mean)
+
+    # One graph trainer
+    trainer = Trainer(architecture=scratch,
+                      loss=loss,
+                      iterations=iterations,
+                      analizer=None,
+                      prefetch=False,
+                      temp_dir=directory)
+    trainer.train(train_data_shuffler)
+
+    accuracy = validate_network(validation_data, validation_labels, directory)
+    assert accuracy > 80
+    del scratch
+
 
diff --git a/bob/learn/tensorflow/trainers/SiameseTrainer.py b/bob/learn/tensorflow/trainers/SiameseTrainer.py
index c7cd48d9..7cf1f206 100644
--- a/bob/learn/tensorflow/trainers/SiameseTrainer.py
+++ b/bob/learn/tensorflow/trainers/SiameseTrainer.py
@@ -62,6 +62,8 @@ class SiameseTrainer(Trainer):
                  ## Analizer
                  analizer=ExperimentAnalizer(),
 
+                 model_from_file="",
+
                  verbosity_level=2):
 
         super(SiameseTrainer, self).__init__(
@@ -86,6 +88,8 @@ class SiameseTrainer(Trainer):
             ## Analizer
             analizer=analizer,
 
+            model_from_file=model_from_file,
+
             verbosity_level=verbosity_level
         )
 
diff --git a/bob/learn/tensorflow/trainers/Trainer.py b/bob/learn/tensorflow/trainers/Trainer.py
index 6cce163c..68538236 100644
--- a/bob/learn/tensorflow/trainers/Trainer.py
+++ b/bob/learn/tensorflow/trainers/Trainer.py
@@ -65,6 +65,9 @@ class Trainer(object):
                  ## Analizer
                  analizer=SoftmaxAnalizer(),
 
+                 ### Pretrained model
+                 model_from_file="",
+
                  verbosity_level=2):
 
         if not isinstance(architecture, SequenceNetwork):
@@ -107,6 +110,8 @@ class Trainer(object):
         self.enqueue_op = None
         self.global_step = None
 
+        self.model_from_file = model_from_file
+
         bob.core.log.set_verbosity_level(logger, verbosity_level)
 
     def __del__(self):
@@ -289,6 +294,12 @@ class Trainer(object):
         with tf.Session(config=config) as session:
             tf.initialize_all_variables().run()
 
+            # Loading a pretrained model
+            if self.model_from_file != "":
+                logger.info("Loading pretrained model from {0}".format(self.model_from_file))
+                hdf5 = bob.io.base.HDF5File(self.model_from_file)
+                self.architecture.load_variables_only(hdf5, session)
+
             if isinstance(train_data_shuffler, OnLineSampling):
                 train_data_shuffler.set_feature_extractor(self.architecture, session=session)
 
diff --git a/bob/learn/tensorflow/trainers/TripletTrainer.py b/bob/learn/tensorflow/trainers/TripletTrainer.py
index 2379510f..9452d5c8 100644
--- a/bob/learn/tensorflow/trainers/TripletTrainer.py
+++ b/bob/learn/tensorflow/trainers/TripletTrainer.py
@@ -62,6 +62,8 @@ class TripletTrainer(Trainer):
                  ## Analizer
                  analizer=ExperimentAnalizer(),
 
+                 model_from_file="",
+
                  verbosity_level=2):
 
         super(TripletTrainer, self).__init__(
@@ -85,6 +87,7 @@ class TripletTrainer(Trainer):
 
             ## Analizer
             analizer=analizer,
+            model_from_file=model_from_file,
 
             verbosity_level=verbosity_level
         )
-- 
GitLab