From 2e5982a7133543d7de595c4c73587c4db40098af Mon Sep 17 00:00:00 2001
From: Tiago Freitas Pereira <tiagofrepereira@gmail.com>
Date: Sat, 23 Sep 2017 14:00:01 +0200
Subject: [PATCH] Improved the loading from the last checkpoint

---
 bob/learn/tensorflow/script/train.py          | 21 +++++------
 .../test/data/train_scripts/siamese.py        | 10 +++---
 .../test/data/train_scripts/triplet.py        | 11 +++---
 .../tensorflow/test/test_train_script.py      | 31 ++++++++++------
 .../tensorflow/trainers/SiameseTrainer.py     |  5 +--
 bob/learn/tensorflow/trainers/Trainer.py      | 36 ++++++++++++-------
 .../tensorflow/trainers/TripletTrainer.py     |  7 ++--
 7 files changed, 67 insertions(+), 54 deletions(-)

diff --git a/bob/learn/tensorflow/script/train.py b/bob/learn/tensorflow/script/train.py
index 015f77a7..c9cbfed1 100644
--- a/bob/learn/tensorflow/script/train.py
+++ b/bob/learn/tensorflow/script/train.py
@@ -7,50 +7,45 @@
 Train a Neural network using bob.learn.tensorflow

 Usage:
-  train.py [--iterations=<arg> --validation-interval=<arg> --output-dir=<arg> --pretrained-net=<arg> --use-gpu --prefetch ] <configuration>
+  train.py [--iterations=<arg> --validation-interval=<arg> --output-dir=<arg> ] <configuration>
   train.py -h | --help

 Options:
   -h --help     Show this screen.
   --iterations=<arg>   [default: 1000]
   --validation-interval=<arg>   [default: 100]
   --output-dir=<arg>   If the directory exists, will try to get the last checkpoint [default: ./logs/]
-  --pretrained-net=<arg>
 """
-
 from docopt import docopt
 import imp
 import bob.learn.tensorflow
 import tensorflow as tf
 import os

+
 def main():
     args = docopt(__doc__, version='Train Neural Net')

-    USE_GPU = args['--use-gpu']
     OUTPUT_DIR = str(args['--output-dir'])
-    PREFETCH = args['--prefetch']
     ITERATIONS = int(args['--iterations'])

-    PRETRAINED_NET = ""
-    if not args['--pretrained-net'] is None:
-        PRETRAINED_NET = str(args['--pretrained-net'])
-
+    #PRETRAINED_NET = ""
+    #if not args['--pretrained-net'] is None:
+    #    PRETRAINED_NET = str(args['--pretrained-net'])

     config = imp.load_source('config', args['<configuration>'])

+    # Cleaning all variables in case you are loading the checkpoint
+    tf.reset_default_graph() if os.path.exists(OUTPUT_DIR) else None
+
     # One graph trainer
     trainer = config.Trainer(config.train_data_shuffler,
                              iterations=ITERATIONS,
                              analizer=None,
                              temp_dir=OUTPUT_DIR)
-
-
     if os.path.exists(OUTPUT_DIR):
         print("Directory already exists, trying to get the last checkpoint")
-        import ipdb; ipdb.set_trace();
         trainer.create_network_from_file(OUTPUT_DIR)
-
     else:
         # Preparing the architecture
diff --git a/bob/learn/tensorflow/test/data/train_scripts/siamese.py b/bob/learn/tensorflow/test/data/train_scripts/siamese.py
index 0b65ec26..d8b0de1f 100644
--- a/bob/learn/tensorflow/test/data/train_scripts/siamese.py
+++ b/bob/learn/tensorflow/test/data/train_scripts/siamese.py
@@ -22,16 +22,16 @@ train_data_shuffler = SiameseMemory(train_data, train_labels,
                                     normalizer=ScaleFactor())

 ### ARCHITECTURE ###
-architecture = Chopra(seed=SEED, fc1_output=10, batch_norm=False)
+architecture = Chopra(seed=SEED, n_classes=10)

 ### LOSS ###
 loss = ContrastiveLoss(contrastive_margin=4.)

-### SOLVER ###
-optimizer = tf.train.GradientDescentOptimizer(0.001)
-
 ### LEARNING RATE ###
-learning_rate = constant(base_learning_rate=0.001)
+learning_rate = constant(base_learning_rate=0.01)
+
+### SOLVER ###
+optimizer = tf.train.GradientDescentOptimizer(learning_rate)

 ### Trainer ###
 trainer = Trainer
diff --git a/bob/learn/tensorflow/test/data/train_scripts/triplet.py b/bob/learn/tensorflow/test/data/train_scripts/triplet.py
index 26ca494a..3ef79480 100644
--- a/bob/learn/tensorflow/test/data/train_scripts/triplet.py
+++ b/bob/learn/tensorflow/test/data/train_scripts/triplet.py
@@ -21,16 +21,19 @@ train_data_shuffler = TripletMemory(train_data, train_labels,
                                     batch_size=BATCH_SIZE)

 ### ARCHITECTURE ###
-architecture = Chopra(seed=SEED, fc1_output=10, batch_norm=False)
+architecture = Chopra(seed=SEED, n_classes=10)

 ### LOSS ###
 loss = TripletLoss(margin=4.)

-### SOLVER ###
-optimizer = tf.train.GradientDescentOptimizer(0.001)

 ### LEARNING RATE ###
-learning_rate = constant(base_learning_rate=0.001)
+learning_rate = constant(base_learning_rate=0.01)
+
+
+### SOLVER ###
+optimizer = tf.train.GradientDescentOptimizer(learning_rate)
+

 ### Trainer ###
 trainer = Trainer
diff --git a/bob/learn/tensorflow/test/test_train_script.py b/bob/learn/tensorflow/test/test_train_script.py
index 86ea0f4d..3f5ca992 100644
--- a/bob/learn/tensorflow/test/test_train_script.py
+++ b/bob/learn/tensorflow/test/test_train_script.py
@@ -10,22 +10,29 @@ import shutil

 def test_train_script_softmax():
     directory = "./temp/train-script"
     train_script = pkg_resources.resource_filename(__name__, './data/train_scripts/softmax.py')
-    train_script = './data/train_scripts/softmax.py'

     from subprocess import call
+    # Start the training
     call(["./bin/train.py", "--iterations", "5", "--output-dir", directory, train_script])
-    #shutil.rmtree(directory)
+
+    # Continuing from the last checkpoint
+    call(["./bin/train.py", "--iterations", "5", "--output-dir", directory, train_script])
+    shutil.rmtree(directory)

     assert True


 def test_train_script_triplet():
     directory = "./temp/train-script"
     train_script = pkg_resources.resource_filename(__name__, './data/train_scripts/triplet.py')
-    #train_script = './data/train_scripts/triplet.py'

-    #from subprocess import call
-    #call(["./bin/train.py", "--iterations", "5", "--output-dir", directory, train_script])
-    #shutil.rmtree(directory)
+    from subprocess import call
+    # Start the training
+    call(["./bin/train.py", "--iterations", "5", "--output-dir", directory, train_script])
+
+    # Continuing from the last checkpoint
+    call(["./bin/train.py", "--iterations", "5", "--output-dir", directory, train_script])
+
+    shutil.rmtree(directory)

     assert True

@@ -33,10 +40,14 @@ def test_train_script_triplet():

 def test_train_script_siamese():
     directory = "./temp/train-script"
     train_script = pkg_resources.resource_filename(__name__, './data/train_scripts/siamese.py')
-    #train_script = './data/train_scripts/siamese.py'

-    #from subprocess import call
-    #call(["./bin/train.py", "--iterations", "5", "--output-dir", directory, train_script])
-    #shutil.rmtree(directory)
+    from subprocess import call
+    # Start the training
+    call(["./bin/train.py", "--iterations", "5", "--output-dir", directory, train_script])
+
+    # Continuing from the last checkpoint
+    call(["./bin/train.py", "--iterations", "5", "--output-dir", directory, train_script])
+
+    shutil.rmtree(directory)

     assert True
diff --git a/bob/learn/tensorflow/trainers/SiameseTrainer.py b/bob/learn/tensorflow/trainers/SiameseTrainer.py
index b8f7bcca..27e4a31b 100644
--- a/bob/learn/tensorflow/trainers/SiameseTrainer.py
+++ b/bob/learn/tensorflow/trainers/SiameseTrainer.py
@@ -179,9 +179,7 @@ class SiameseTrainer(Trainer):

     def create_network_from_file(self, model_from_file, clear_devices=True):

-        #saver = self.architecture.load(self.model_from_file, clear_devices=False)
-        self.saver = tf.train.import_meta_graph(model_from_file + ".meta", clear_devices=clear_devices)
-        self.saver.restore(self.session, model_from_file)
+        self.load_checkpoint(model_from_file, clear_devices=clear_devices)

         # Loading the graph from the graph pointers
         self.graph = dict()
@@ -206,7 +204,6 @@ class SiameseTrainer(Trainer):
         self.summaries_train = tf.get_collection("summaries_train")[0]
         self.global_step = tf.get_collection("global_step")[0]
         self.from_scratch = False
-

     def get_feed_dict(self, data_shuffler):
diff --git a/bob/learn/tensorflow/trainers/Trainer.py b/bob/learn/tensorflow/trainers/Trainer.py
index a19bb9e0..77dd1d5a 100644
--- a/bob/learn/tensorflow/trainers/Trainer.py
+++ b/bob/learn/tensorflow/trainers/Trainer.py
@@ -122,7 +122,6 @@ class Trainer(object):
         self.session = Session.instance(new=True).session
         self.from_scratch = True
-

     def train(self):
         """
         Train the network
@@ -197,7 +196,6 @@ class Trainer(object):
         #if not isinstance(self.train_data_shuffler, TFRecord):
         #    self.thread_pool.join(threads)
-

     def create_network_from_scratch(self,
                                     graph,
                                     validation_graph=None,
@@ -222,9 +220,6 @@ class Trainer(object):
             learning_rate: Learning rate

         """
-        # Putting together the training data + graph + loss
-
-
         # Getting the pointer to the placeholders
         self.data_ph = self.train_data_shuffler("data", from_queue=True)
         self.label_ph = self.train_data_shuffler("label", from_queue=True)
@@ -243,7 +238,6 @@ class Trainer(object):
         self.optimizer_class._learning_rate = self.learning_rate
         self.optimizer = self.optimizer_class.minimize(self.predictor, global_step=self.global_step)
-

         # Saving all the variables
         self.saver = tf.train.Saver(var_list=tf.global_variables() + tf.local_variables(),
                                     keep_checkpoint_every_n_hours=self.keep_checkpoint_every_n_hours)
@@ -264,7 +258,7 @@ class Trainer(object):
         tf.add_to_collection("summaries_train", self.summaries_train)

         # Same business with the validation
-        if(self.validation_data_shuffler is not None):
+        if self.validation_data_shuffler is not None:

             self.validation_data_ph = self.validation_data_shuffler("data", from_queue=True)
             self.validation_label_ph = self.validation_data_shuffler("label", from_queue=True)
@@ -286,6 +280,24 @@ class Trainer(object):
         tf.local_variables_initializer().run(session=self.session)
         tf.global_variables_initializer().run(session=self.session)

+    def load_checkpoint(self, file_name, clear_devices=True):
+        """
+        Load a checkpoint
+
+        ** Parameters **
+
+        file_name:
+            Name of the metafile to be loaded.
+            If a directory is passed, the last checkpoint will be loaded
+
+        """
+        if os.path.isdir(file_name):
+            checkpoint_path = tf.train.get_checkpoint_state(file_name).model_checkpoint_path
+            self.saver = tf.train.import_meta_graph(checkpoint_path + ".meta", clear_devices=clear_devices)
+            self.saver.restore(self.session, tf.train.latest_checkpoint(file_name))
+        else:
+            self.saver = tf.train.import_meta_graph(file_name, clear_devices=clear_devices)
+            self.saver.restore(self.session, tf.train.latest_checkpoint(os.path.dirname(file_name)))

     def create_network_from_file(self, file_name, clear_devices=True):
         """
@@ -295,9 +307,9 @@ class Trainer(object):

         file_name: Name of of the checkpoing
         """
-        #self.saver = tf.train.import_meta_graph(file_name + ".meta", clear_devices=clear_devices)
-        self.saver = tf.train.import_meta_graph(file_name, clear_devices=clear_devices)
-        self.saver.restore(self.session, tf.train.latest_checkpoint(os.path.dirname(file_name)))
+
+        logger.info("Loading last checkpoint !!")
+        self.load_checkpoint(file_name, clear_devices=True)

         # Loading training graph
         self.data_ph = tf.get_collection("data_ph")[0]
@@ -314,10 +326,9 @@ class Trainer(object):
         self.from_scratch = False

         # Loading the validation bits
-        if(self.validation_data_shuffler is not None):
+        if self.validation_data_shuffler is not None:

             self.summaries_validation = tf.get_collection("summaries_validation")[0]
-
             self.validation_graph = tf.get_collection("validation_graph")[0]
             self.validation_data_ph = tf.get_collection("validation_data_ph")[0]
             self.validation_label = tf.get_collection("validation_label_ph")[0]
@@ -325,7 +336,6 @@ class Trainer(object):
             self.validation_predictor = tf.get_collection("validation_predictor")[0]
             self.summaries_validation = tf.get_collection("summaries_validation")[0]
-

     def __del__(self):
         tf.reset_default_graph()
diff --git a/bob/learn/tensorflow/trainers/TripletTrainer.py b/bob/learn/tensorflow/trainers/TripletTrainer.py
index a941c8b1..6dbc5624 100644
--- a/bob/learn/tensorflow/trainers/TripletTrainer.py
+++ b/bob/learn/tensorflow/trainers/TripletTrainer.py
@@ -120,7 +120,6 @@ class TripletTrainer(Trainer):
         self.session = Session.instance(new=True).session
         self.from_scratch = True
-

     def create_network_from_scratch(self,
                                     graph,
                                     optimizer=tf.train.AdamOptimizer(),
@@ -177,11 +176,9 @@ class TripletTrainer(Trainer):
         # Creating the variables
         tf.global_variables_initializer().run(session=self.session)

-    def create_network_from_file(self, model_from_file, clear_devices=True):
+    def create_network_from_file(self, file_name, clear_devices=True):

-        #saver = self.architecture.load(self.model_from_file, clear_devices=False)
-        self.saver = tf.train.import_meta_graph(model_from_file + ".meta", clear_devices=clear_devices)
-        self.saver.restore(self.session, model_from_file)
+        self.load_checkpoint(file_name, clear_devices=clear_devices)

         # Loading the graph from the graph pointers
         self.graph = dict()
-- 
GitLab
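A minimal sketch of how the resume behaviour introduced by this patch is exercised, mirroring the updated tests above: the same training command is run twice with the same --output-dir, and the second run restores the latest checkpoint through Trainer.load_checkpoint() instead of training from scratch. The ./bin/train.py entry point and the relative path to the example config script are assumptions taken from the test suite, not guaranteed for every install layout.

from subprocess import call

directory = "./temp/train-script"                   # output directory reused across runs
train_script = "./data/train_scripts/softmax.py"    # any of the example config scripts

# First run: no checkpoint exists yet, so the network is built from scratch
# and checkpoints are written into `directory`.
call(["./bin/train.py", "--iterations", "5", "--output-dir", directory, train_script])

# Second run: `directory` already exists, so train.py calls
# trainer.create_network_from_file(OUTPUT_DIR), which now resolves the last
# checkpoint in the directory via load_checkpoint() and continues training.
call(["./bin/train.py", "--iterations", "5", "--output-dir", directory, train_script])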