Improved the loading from the last checkpoint

parent e6341fae
Pipeline #12496 failed with stages in 29 minutes and 27 seconds
@@ -7,50 +7,45 @@
 Train a Neural network using bob.learn.tensorflow

 Usage:
-  train.py [--iterations=<arg> --validation-interval=<arg> --output-dir=<arg> --pretrained-net=<arg> --use-gpu --prefetch ] <configuration>
+  train.py [--iterations=<arg> --validation-interval=<arg> --output-dir=<arg>] <configuration>
   train.py -h | --help

 Options:
   -h --help                   Show this screen.
   --iterations=<arg>          [default: 1000]
   --validation-interval=<arg> [default: 100]
   --output-dir=<arg>          If the directory exists, will try to get the last checkpoint [default: ./logs/]
-  --pretrained-net=<arg>
 """

 from docopt import docopt
 import imp
 import bob.learn.tensorflow
 import tensorflow as tf
 import os


 def main():
     args = docopt(__doc__, version='Train Neural Net')

-    USE_GPU = args['--use-gpu']
     OUTPUT_DIR = str(args['--output-dir'])
-    PREFETCH = args['--prefetch']
     ITERATIONS = int(args['--iterations'])

-    PRETRAINED_NET = ""
-    if not args['--pretrained-net'] is None:
-        PRETRAINED_NET = str(args['--pretrained-net'])
+    #PRETRAINED_NET = ""
+    #if not args['--pretrained-net'] is None:
+    #    PRETRAINED_NET = str(args['--pretrained-net'])

     config = imp.load_source('config', args['<configuration>'])

+    # Cleaning all variables in case you are loading the checkpoint
+    tf.reset_default_graph() if os.path.exists(OUTPUT_DIR) else None

     # One graph trainer
     trainer = config.Trainer(config.train_data_shuffler,
                              iterations=ITERATIONS,
                              analizer=None,
                              temp_dir=OUTPUT_DIR)

     if os.path.exists(OUTPUT_DIR):
         print("Directory already exists, trying to get the last checkpoint")
-        import ipdb; ipdb.set_trace();
         trainer.create_network_from_file(OUTPUT_DIR)
     else:
         # Preparing the architecture
...
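Taken together, the new flow in train.py is: if the output directory already exists, reset the default graph and restore from its newest checkpoint; otherwise build the network from scratch. A minimal sketch of that flow, assuming a trainer object exposing the two create_network_* methods used above:

```python
import os
import tensorflow as tf

def build_or_resume(trainer, output_dir):
    """Resume from the newest checkpoint in output_dir, else start fresh."""
    if os.path.exists(output_dir):
        # Clear any stale graph so import_meta_graph starts from a clean slate.
        tf.reset_default_graph()
        trainer.create_network_from_file(output_dir)
    else:
        # Architecture/loss wiring elided; see the configuration scripts below.
        trainer.create_network_from_scratch()
```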
@@ -22,16 +22,16 @@ train_data_shuffler = SiameseMemory(train_data, train_labels,
                                     normalizer=ScaleFactor())

 ### ARCHITECTURE ###
-architecture = Chopra(seed=SEED, fc1_output=10, batch_norm=False)
+architecture = Chopra(seed=SEED, n_classes=10)

 ### LOSS ###
 loss = ContrastiveLoss(contrastive_margin=4.)

-### SOLVER ###
-optimizer = tf.train.GradientDescentOptimizer(0.001)
-
 ### LEARNING RATE ###
-learning_rate = constant(base_learning_rate=0.001)
+learning_rate = constant(base_learning_rate=0.01)
+
+### SOLVER ###
+optimizer = tf.train.GradientDescentOptimizer(learning_rate)

 ### Trainer ###
 trainer = Trainer
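The reordering in this configuration (and in the triplet one below) is not cosmetic: the solver now receives the learning-rate object instead of a hard-coded float, so the schedule must be defined first. A standalone TF 1.x sketch of the same pattern; the constant() helper is re-implemented here as an assumption about its behavior, not the library's actual code:

```python
import tensorflow as tf

# Assumed stand-in for bob.learn.tensorflow's constant() helper:
# a fixed learning rate exposed as a tensor the optimizer can consume.
def constant(base_learning_rate=0.01):
    return tf.constant(base_learning_rate, dtype=tf.float32)

learning_rate = constant(base_learning_rate=0.01)
# Define the schedule first, then hand it to the solver; swapping in a
# decaying schedule later requires no change to the optimizer line.
optimizer = tf.train.GradientDescentOptimizer(learning_rate)
```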
@@ -21,16 +21,19 @@ train_data_shuffler = TripletMemory(train_data, train_labels,
                                     batch_size=BATCH_SIZE)

 ### ARCHITECTURE ###
-architecture = Chopra(seed=SEED, fc1_output=10, batch_norm=False)
+architecture = Chopra(seed=SEED, n_classes=10)

 ### LOSS ###
 loss = TripletLoss(margin=4.)

-### SOLVER ###
-optimizer = tf.train.GradientDescentOptimizer(0.001)
-
 ### LEARNING RATE ###
-learning_rate = constant(base_learning_rate=0.001)
+learning_rate = constant(base_learning_rate=0.01)
+
+### SOLVER ###
+optimizer = tf.train.GradientDescentOptimizer(learning_rate)

 ### Trainer ###
 trainer = Trainer
@@ -10,22 +10,29 @@ import shutil

 def test_train_script_softmax():
     directory = "./temp/train-script"
     train_script = pkg_resources.resource_filename(__name__, './data/train_scripts/softmax.py')
-    train_script = './data/train_scripts/softmax.py'

     from subprocess import call
+    # Start the training
     call(["./bin/train.py", "--iterations", "5", "--output-dir", directory, train_script])
-    #shutil.rmtree(directory)
+
+    # Continuing from the last checkpoint
+    call(["./bin/train.py", "--iterations", "5", "--output-dir", directory, train_script])
+
+    shutil.rmtree(directory)

     assert True

 def test_train_script_triplet():
     directory = "./temp/train-script"
     train_script = pkg_resources.resource_filename(__name__, './data/train_scripts/triplet.py')
-    #train_script = './data/train_scripts/triplet.py'

-    #from subprocess import call
-    #call(["./bin/train.py", "--iterations", "5", "--output-dir", directory, train_script])
-    #shutil.rmtree(directory)
+    from subprocess import call
+    # Start the training
+    call(["./bin/train.py", "--iterations", "5", "--output-dir", directory, train_script])
+
+    # Continuing from the last checkpoint
+    call(["./bin/train.py", "--iterations", "5", "--output-dir", directory, train_script])
+
+    shutil.rmtree(directory)

     assert True
@@ -33,10 +40,14 @@ def test_train_script_triplet():

 def test_train_script_siamese():
     directory = "./temp/train-script"
     train_script = pkg_resources.resource_filename(__name__, './data/train_scripts/siamese.py')
-    #train_script = './data/train_scripts/siamese.py'

-    #from subprocess import call
-    #call(["./bin/train.py", "--iterations", "5", "--output-dir", directory, train_script])
-    #shutil.rmtree(directory)
+    from subprocess import call
+    # Start the training
+    call(["./bin/train.py", "--iterations", "5", "--output-dir", directory, train_script])
+
+    # Continuing from the last checkpoint
+    call(["./bin/train.py", "--iterations", "5", "--output-dir", directory, train_script])
+
+    shutil.rmtree(directory)

     assert True
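Each test now runs the script twice against the same --output-dir, so the second call exercises the resume path. The trailing assert True only proves the calls did not crash; a hedged sketch of a stronger check that could sit between the two calls (standard TF 1.x API, the helper name is hypothetical):

```python
import tensorflow as tf

def checkpoint_exists(directory):
    """Return True if a checkpoint was actually written to `directory`."""
    state = tf.train.get_checkpoint_state(directory)
    return state is not None and bool(state.model_checkpoint_path)

# e.g. between the two calls: assert checkpoint_exists(directory)
```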
@@ -179,9 +179,7 @@ class SiameseTrainer(Trainer):

     def create_network_from_file(self, model_from_file, clear_devices=True):
-        #saver = self.architecture.load(self.model_from_file, clear_devices=False)
-        self.saver = tf.train.import_meta_graph(model_from_file + ".meta", clear_devices=clear_devices)
-        self.saver.restore(self.session, model_from_file)
+        self.load_checkpoint(model_from_file, clear_devices=clear_devices)

         # Loading the graph from the graph pointers
         self.graph = dict()
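Both SiameseTrainer and TripletTrainer now delegate restoration to the shared Trainer.load_checkpoint and keep only their own graph-pointer rebuilding. A self-contained sketch of that shape (class names are illustrative, not the library's):

```python
import os
import tensorflow as tf

class BaseTrainer(object):
    """Owns the session and the shared checkpoint loader."""
    def __init__(self):
        self.session = tf.Session()
        self.saver = None

    def load_checkpoint(self, file_name, clear_devices=True):
        # Import the saved graph definition, then restore the latest weights.
        self.saver = tf.train.import_meta_graph(file_name, clear_devices=clear_devices)
        self.saver.restore(self.session,
                           tf.train.latest_checkpoint(os.path.dirname(file_name)))

class PairwiseTrainer(BaseTrainer):
    def create_network_from_file(self, file_name, clear_devices=True):
        # Delegate restoration; rebuild only the subclass-specific pointers.
        self.load_checkpoint(file_name, clear_devices=clear_devices)
        self.graph = dict()
```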
@@ -206,7 +204,6 @@ class SiameseTrainer(Trainer):
         self.summaries_train = tf.get_collection("summaries_train")[0]
         self.global_step = tf.get_collection("global_step")[0]
         self.from_scratch = False
-
     def get_feed_dict(self, data_shuffler):
...
@@ -122,7 +122,6 @@ class Trainer(object):
         self.session = Session.instance(new=True).session
         self.from_scratch = True
-
     def train(self):
         """
         Train the network
@@ -197,7 +196,6 @@ class Trainer(object):
         #if not isinstance(self.train_data_shuffler, TFRecord):
         #    self.thread_pool.join(threads)
-
     def create_network_from_scratch(self,
                                     graph,
                                     validation_graph=None,
@@ -222,9 +220,6 @@ class Trainer(object):
             learning_rate: Learning rate

         """
-
-        # Putting together the training data + graph + loss
-
         # Getting the pointer to the placeholders
         self.data_ph = self.train_data_shuffler("data", from_queue=True)
         self.label_ph = self.train_data_shuffler("label", from_queue=True)
@@ -243,7 +238,6 @@ class Trainer(object):
         self.optimizer_class._learning_rate = self.learning_rate
         self.optimizer = self.optimizer_class.minimize(self.predictor, global_step=self.global_step)
-
         # Saving all the variables
         self.saver = tf.train.Saver(var_list=tf.global_variables() + tf.local_variables(),
                                     keep_checkpoint_every_n_hours=self.keep_checkpoint_every_n_hours)
@@ -264,7 +258,7 @@ class Trainer(object):
         tf.add_to_collection("summaries_train", self.summaries_train)

         # Same business with the validation
-        if(self.validation_data_shuffler is not None):
+        if self.validation_data_shuffler is not None:
             self.validation_data_ph = self.validation_data_shuffler("data", from_queue=True)
             self.validation_label_ph = self.validation_data_shuffler("label", from_queue=True)
@@ -286,6 +280,24 @@ class Trainer(object):
             tf.local_variables_initializer().run(session=self.session)
             tf.global_variables_initializer().run(session=self.session)

+    def load_checkpoint(self, file_name, clear_devices=True):
+        """
+        Load a checkpoint.
+
+        **Parameters**
+
+        file_name:
+            Name of the meta file to be loaded.
+            If a directory is passed, the last checkpoint will be loaded.
+        """
+        if os.path.isdir(file_name):
+            checkpoint_path = tf.train.get_checkpoint_state(file_name).model_checkpoint_path
+            self.saver = tf.train.import_meta_graph(checkpoint_path + ".meta", clear_devices=clear_devices)
+            self.saver.restore(self.session, tf.train.latest_checkpoint(file_name))
+        else:
+            self.saver = tf.train.import_meta_graph(file_name, clear_devices=clear_devices)
+            self.saver.restore(self.session, tf.train.latest_checkpoint(os.path.dirname(file_name)))
+
     def create_network_from_file(self, file_name, clear_devices=True):
         """
@@ -295,9 +307,9 @@ class Trainer(object):

             file_name: Name of the checkpoint

         """
-        #self.saver = tf.train.import_meta_graph(file_name + ".meta", clear_devices=clear_devices)
-        self.saver = tf.train.import_meta_graph(file_name, clear_devices=clear_devices)
-        self.saver.restore(self.session, tf.train.latest_checkpoint(os.path.dirname(file_name)))
+        logger.info("Loading last checkpoint !!")
+        self.load_checkpoint(file_name, clear_devices=True)

         # Loading training graph
         self.data_ph = tf.get_collection("data_ph")[0]
@@ -314,10 +326,9 @@ class Trainer(object):
         self.from_scratch = False

         # Loading the validation bits
-        if(self.validation_data_shuffler is not None):
-            self.summaries_validation = tf.get_collection("summaries_validation")[0]
+        if self.validation_data_shuffler is not None:
             self.validation_graph = tf.get_collection("validation_graph")[0]
             self.validation_data_ph = tf.get_collection("validation_data_ph")[0]
             self.validation_label = tf.get_collection("validation_label_ph")[0]
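This restore path works because create_network_from_scratch registers every graph pointer in a named collection, and import_meta_graph brings those collections back. The round-trip in a minimal TF 1.x sketch (the placeholder shape is illustrative):

```python
import tensorflow as tf

# While building the graph: register a pointer under a string key.
data_ph = tf.placeholder(tf.float32, shape=(None, 28, 28, 1), name="data")
tf.add_to_collection("data_ph", data_ph)

# After tf.train.import_meta_graph in a fresh process, the collection
# is part of the restored graph, so the pointer comes back by key.
data_ph = tf.get_collection("data_ph")[0]
```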
@@ -325,7 +336,6 @@ class Trainer(object):
             self.validation_predictor = tf.get_collection("validation_predictor")[0]
             self.summaries_validation = tf.get_collection("summaries_validation")[0]
-
     def __del__(self):
         tf.reset_default_graph()
...
@@ -120,7 +120,6 @@ class TripletTrainer(Trainer):
         self.session = Session.instance(new=True).session
         self.from_scratch = True
-
     def create_network_from_scratch(self,
                                     graph,
                                     optimizer=tf.train.AdamOptimizer(),
@@ -177,11 +176,9 @@ class TripletTrainer(Trainer):
         # Creating the variables
         tf.global_variables_initializer().run(session=self.session)

-    def create_network_from_file(self, model_from_file, clear_devices=True):
-        #saver = self.architecture.load(self.model_from_file, clear_devices=False)
-        self.saver = tf.train.import_meta_graph(model_from_file + ".meta", clear_devices=clear_devices)
-        self.saver.restore(self.session, model_from_file)
+    def create_network_from_file(self, file_name, clear_devices=True):
+        self.load_checkpoint(file_name, clear_devices=clear_devices)

         # Loading the graph from the graph pointers
         self.graph = dict()
...