From 2e5982a7133543d7de595c4c73587c4db40098af Mon Sep 17 00:00:00 2001
From: Tiago Freitas Pereira <tiagofrepereira@gmail.com>
Date: Sat, 23 Sep 2017 14:00:01 +0200
Subject: [PATCH] Improved the loading from the last checkpoint

---
 bob/learn/tensorflow/script/train.py          | 21 +++++------
 .../test/data/train_scripts/siamese.py        | 10 +++---
 .../test/data/train_scripts/triplet.py        | 11 +++---
 .../tensorflow/test/test_train_script.py      | 31 ++++++++++------
 .../tensorflow/trainers/SiameseTrainer.py     |  5 +--
 bob/learn/tensorflow/trainers/Trainer.py      | 36 ++++++++++++-------
 .../tensorflow/trainers/TripletTrainer.py     |  7 ++--
 7 files changed, 67 insertions(+), 54 deletions(-)

diff --git a/bob/learn/tensorflow/script/train.py b/bob/learn/tensorflow/script/train.py
index 015f77a7..c9cbfed1 100644
--- a/bob/learn/tensorflow/script/train.py
+++ b/bob/learn/tensorflow/script/train.py
@@ -7,50 +7,45 @@
 Train a Neural network using bob.learn.tensorflow
 
 Usage:
-  train.py [--iterations=<arg> --validation-interval=<arg> --output-dir=<arg> --pretrained-net=<arg> --use-gpu --prefetch ] <configuration>
+  train.py [--iterations=<arg> --validation-interval=<arg> --output-dir=<arg> ] <configuration>
   train.py -h | --help
 Options:
   -h --help     Show this screen.
   --iterations=<arg>   [default: 1000]
   --validation-interval=<arg>   [default: 100]
   --output-dir=<arg>    If the directory exists, will try to get the last checkpoint [default: ./logs/]
-  --pretrained-net=<arg>
 """
 
-
 from docopt import docopt
 import imp
 import bob.learn.tensorflow
 import tensorflow as tf
 import os
 
+
 def main():
     args = docopt(__doc__, version='Train Neural Net')
 
-    USE_GPU = args['--use-gpu']
     OUTPUT_DIR = str(args['--output-dir'])
-    PREFETCH = args['--prefetch']
     ITERATIONS = int(args['--iterations'])
 
-    PRETRAINED_NET = ""
-    if not args['--pretrained-net'] is None:
-        PRETRAINED_NET = str(args['--pretrained-net'])
-
+    #PRETRAINED_NET = ""
+    #if not args['--pretrained-net'] is None:
+    #    PRETRAINED_NET = str(args['--pretrained-net'])
 
     config = imp.load_source('config', args['<configuration>'])
 
+    # Cleaning all variables in case you are loading the checkpoint
+    tf.reset_default_graph() if os.path.exists(OUTPUT_DIR) else None
+
     # One graph trainer
     trainer = config.Trainer(config.train_data_shuffler,
                              iterations=ITERATIONS,
                              analizer=None,
                              temp_dir=OUTPUT_DIR)
-
-
     if os.path.exists(OUTPUT_DIR):
         print("Directory already exists, trying to get the last checkpoint")
-        import ipdb; ipdb.set_trace();
         trainer.create_network_from_file(OUTPUT_DIR)
-
     else:
 
         # Preparing the architecture
diff --git a/bob/learn/tensorflow/test/data/train_scripts/siamese.py b/bob/learn/tensorflow/test/data/train_scripts/siamese.py
index 0b65ec26..d8b0de1f 100644
--- a/bob/learn/tensorflow/test/data/train_scripts/siamese.py
+++ b/bob/learn/tensorflow/test/data/train_scripts/siamese.py
@@ -22,16 +22,16 @@ train_data_shuffler = SiameseMemory(train_data, train_labels,
                                     normalizer=ScaleFactor())
 
 ### ARCHITECTURE ###
-architecture = Chopra(seed=SEED, fc1_output=10, batch_norm=False)
+architecture = Chopra(seed=SEED, n_classes=10)
 
 ### LOSS ###
 loss = ContrastiveLoss(contrastive_margin=4.)
 
-### SOLVER ###
-optimizer = tf.train.GradientDescentOptimizer(0.001)
-
 ### LEARNING RATE ###
-learning_rate = constant(base_learning_rate=0.001)
+learning_rate = constant(base_learning_rate=0.01)
+
+### SOLVER ###
+optimizer = tf.train.GradientDescentOptimizer(learning_rate)
 
 ### Trainer ###
 trainer = Trainer
diff --git a/bob/learn/tensorflow/test/data/train_scripts/triplet.py b/bob/learn/tensorflow/test/data/train_scripts/triplet.py
index 26ca494a..3ef79480 100644
--- a/bob/learn/tensorflow/test/data/train_scripts/triplet.py
+++ b/bob/learn/tensorflow/test/data/train_scripts/triplet.py
@@ -21,16 +21,19 @@ train_data_shuffler = TripletMemory(train_data, train_labels,
                                     batch_size=BATCH_SIZE)
 
 ### ARCHITECTURE ###
-architecture = Chopra(seed=SEED, fc1_output=10, batch_norm=False)
+architecture = Chopra(seed=SEED, n_classes=10)
 
 ### LOSS ###
 loss = TripletLoss(margin=4.)
 
-### SOLVER ###
-optimizer = tf.train.GradientDescentOptimizer(0.001)
 
 ### LEARNING RATE ###
-learning_rate = constant(base_learning_rate=0.001)
+learning_rate = constant(base_learning_rate=0.01)
+
+
+### SOLVER ###
+optimizer = tf.train.GradientDescentOptimizer(learning_rate)
+
 
 ### Trainer ###
 trainer = Trainer
diff --git a/bob/learn/tensorflow/test/test_train_script.py b/bob/learn/tensorflow/test/test_train_script.py
index 86ea0f4d..3f5ca992 100644
--- a/bob/learn/tensorflow/test/test_train_script.py
+++ b/bob/learn/tensorflow/test/test_train_script.py
@@ -10,22 +10,29 @@ import shutil
 def test_train_script_softmax():
     directory = "./temp/train-script"
     train_script = pkg_resources.resource_filename(__name__, './data/train_scripts/softmax.py')
-    train_script = './data/train_scripts/softmax.py'
 
     from subprocess import call
+    # Start the training
     call(["./bin/train.py", "--iterations", "5", "--output-dir", directory, train_script])
-    #shutil.rmtree(directory)
+
+    # Continuing from the last checkpoint
+    call(["./bin/train.py", "--iterations", "5", "--output-dir", directory, train_script])
+    shutil.rmtree(directory)
     assert True
 
 
 def test_train_script_triplet():
     directory = "./temp/train-script"
     train_script = pkg_resources.resource_filename(__name__, './data/train_scripts/triplet.py')
-    #train_script = './data/train_scripts/triplet.py'
 
-    #from subprocess import call
-    #call(["./bin/train.py", "--iterations", "5", "--output-dir", directory, train_script])
-    #shutil.rmtree(directory)
+    from subprocess import call
+    # Start the training
+    call(["./bin/train.py", "--iterations", "5", "--output-dir", directory, train_script])
+
+    # Continuing from the last checkpoint
+    call(["./bin/train.py", "--iterations", "5", "--output-dir", directory, train_script])
+
+    shutil.rmtree(directory)
 
     assert True
 
@@ -33,10 +40,14 @@ def test_train_script_triplet():
 def test_train_script_siamese():
     directory = "./temp/train-script"
     train_script = pkg_resources.resource_filename(__name__, './data/train_scripts/siamese.py')
-    #train_script = './data/train_scripts/siamese.py'
 
-    #from subprocess import call
-    #call(["./bin/train.py", "--iterations", "5", "--output-dir", directory, train_script])
-    #shutil.rmtree(directory)
+    from subprocess import call
+    # Start the training
+    call(["./bin/train.py", "--iterations", "5", "--output-dir", directory, train_script])
+
+    # Continuing from the last checkpoint
+    call(["./bin/train.py", "--iterations", "5", "--output-dir", directory, train_script])
+
+    shutil.rmtree(directory)
 
     assert True
diff --git a/bob/learn/tensorflow/trainers/SiameseTrainer.py b/bob/learn/tensorflow/trainers/SiameseTrainer.py
index b8f7bcca..27e4a31b 100644
--- a/bob/learn/tensorflow/trainers/SiameseTrainer.py
+++ b/bob/learn/tensorflow/trainers/SiameseTrainer.py
@@ -179,9 +179,7 @@ class SiameseTrainer(Trainer):
 
     def create_network_from_file(self, model_from_file, clear_devices=True):
 
-        #saver = self.architecture.load(self.model_from_file, clear_devices=False)
-        self.saver = tf.train.import_meta_graph(model_from_file + ".meta", clear_devices=clear_devices)
-        self.saver.restore(self.session, model_from_file)
+        self.load_checkpoint(model_from_file, clear_devices=clear_devices)
 
         # Loading the graph from the graph pointers
         self.graph = dict()
@@ -206,7 +204,6 @@ class SiameseTrainer(Trainer):
         self.summaries_train = tf.get_collection("summaries_train")[0]
         self.global_step = tf.get_collection("global_step")[0]
         self.from_scratch = False
-        
 
     def get_feed_dict(self, data_shuffler):
 
diff --git a/bob/learn/tensorflow/trainers/Trainer.py b/bob/learn/tensorflow/trainers/Trainer.py
index a19bb9e0..77dd1d5a 100644
--- a/bob/learn/tensorflow/trainers/Trainer.py
+++ b/bob/learn/tensorflow/trainers/Trainer.py
@@ -122,7 +122,6 @@ class Trainer(object):
         self.session = Session.instance(new=True).session
         self.from_scratch = True
         
-        
     def train(self):
         """
         Train the network        
@@ -197,7 +196,6 @@ class Trainer(object):
             #if not isinstance(self.train_data_shuffler, TFRecord):
             #    self.thread_pool.join(threads)        
 
-
     def create_network_from_scratch(self,
                                     graph,
                                     validation_graph=None,
@@ -222,9 +220,6 @@ class Trainer(object):
             learning_rate: Learning rate
         """
 
-       # Putting together the training data + graph  + loss
-
-
         # Getting the pointer to the placeholders
         self.data_ph = self.train_data_shuffler("data", from_queue=True)
         self.label_ph = self.train_data_shuffler("label", from_queue=True)
@@ -243,7 +238,6 @@ class Trainer(object):
         self.optimizer_class._learning_rate = self.learning_rate
         self.optimizer = self.optimizer_class.minimize(self.predictor, global_step=self.global_step)
 
-
         # Saving all the variables
         self.saver = tf.train.Saver(var_list=tf.global_variables() + tf.local_variables(), 
                                     keep_checkpoint_every_n_hours=self.keep_checkpoint_every_n_hours)
@@ -264,7 +258,7 @@ class Trainer(object):
         tf.add_to_collection("summaries_train", self.summaries_train)
 
         # Same business with the validation
-        if(self.validation_data_shuffler is not None):
+        if self.validation_data_shuffler is not None:
             self.validation_data_ph = self.validation_data_shuffler("data", from_queue=True)
             self.validation_label_ph = self.validation_data_shuffler("label", from_queue=True)
 
@@ -286,6 +280,24 @@ class Trainer(object):
         tf.local_variables_initializer().run(session=self.session)
         tf.global_variables_initializer().run(session=self.session)
 
+    def load_checkpoint(self, file_name, clear_devices=True):
+        """
+        Load a checkpoint
+
+        ** Parameters **
+
+           file_name:
+                Name of the metafile to be loaded.
+                If a directory is passed, the last checkpoint will be loaded
+
+        """
+        if os.path.isdir(file_name):
+            checkpoint_path = tf.train.get_checkpoint_state(file_name).model_checkpoint_path
+            self.saver = tf.train.import_meta_graph(checkpoint_path + ".meta", clear_devices=clear_devices)
+            self.saver.restore(self.session, tf.train.latest_checkpoint(file_name))
+        else:
+            self.saver = tf.train.import_meta_graph(file_name, clear_devices=clear_devices)
+            self.saver.restore(self.session, tf.train.latest_checkpoint(os.path.dirname(file_name)))
 
     def create_network_from_file(self, file_name, clear_devices=True):
         """
@@ -295,9 +307,9 @@ class Trainer(object):
 
            file_name: Name of of the checkpoing
         """
-        #self.saver = tf.train.import_meta_graph(file_name + ".meta", clear_devices=clear_devices)
-        self.saver = tf.train.import_meta_graph(file_name, clear_devices=clear_devices)        
-        self.saver.restore(self.session, tf.train.latest_checkpoint(os.path.dirname(file_name)))
+
+        logger.info("Loading last checkpoint !!")
+        self.load_checkpoint(file_name, clear_devices=True)
 
         # Loading training graph
         self.data_ph = tf.get_collection("data_ph")[0]
@@ -314,10 +326,9 @@ class Trainer(object):
         self.from_scratch = False
         
         # Loading the validation bits
-        if(self.validation_data_shuffler is not None):
+        if self.validation_data_shuffler is not None:
             self.summaries_validation = tf.get_collection("summaries_validation")[0]
 
-
             self.validation_graph = tf.get_collection("validation_graph")[0]
             self.validation_data_ph = tf.get_collection("validation_data_ph")[0]
             self.validation_label = tf.get_collection("validation_label_ph")[0]
@@ -325,7 +336,6 @@ class Trainer(object):
             self.validation_predictor = tf.get_collection("validation_predictor")[0]
             self.summaries_validation = tf.get_collection("summaries_validation")[0]
 
-
     def __del__(self):
         tf.reset_default_graph()
 
diff --git a/bob/learn/tensorflow/trainers/TripletTrainer.py b/bob/learn/tensorflow/trainers/TripletTrainer.py
index a941c8b1..6dbc5624 100644
--- a/bob/learn/tensorflow/trainers/TripletTrainer.py
+++ b/bob/learn/tensorflow/trainers/TripletTrainer.py
@@ -120,7 +120,6 @@ class TripletTrainer(Trainer):
         self.session = Session.instance(new=True).session
         self.from_scratch = True
 
-
     def create_network_from_scratch(self,
                                     graph,
                                     optimizer=tf.train.AdamOptimizer(),
@@ -177,11 +176,9 @@ class TripletTrainer(Trainer):
         # Creating the variables
         tf.global_variables_initializer().run(session=self.session)
 
-    def create_network_from_file(self, model_from_file, clear_devices=True):
+    def create_network_from_file(self, file_name, clear_devices=True):
 
-        #saver = self.architecture.load(self.model_from_file, clear_devices=False)
-        self.saver = tf.train.import_meta_graph(model_from_file + ".meta", clear_devices=clear_devices)
-        self.saver.restore(self.session, model_from_file)
+        self.load_checkpoint(file_name, clear_devices=clear_devices)
 
         # Loading the graph from the graph pointers
         self.graph = dict()
-- 
GitLab