Improved loading from the last checkpoint

parent e6341fae
Pipeline #12496 failed with stages in 29 minutes and 27 seconds
@@ -7,50 +7,45 @@
Train a Neural network using bob.learn.tensorflow
Usage:
train.py [--iterations=<arg> --validation-interval=<arg> --output-dir=<arg> --pretrained-net=<arg> --use-gpu --prefetch ] <configuration>
train.py [--iterations=<arg> --validation-interval=<arg> --output-dir=<arg> ] <configuration>
train.py -h | --help
Options:
-h --help Show this screen.
--iterations=<arg> [default: 1000]
--validation-interval=<arg> [default: 100]
--output-dir=<arg> If the directory exists, will try to get the last checkpoint [default: ./logs/]
--pretrained-net=<arg>
"""
from docopt import docopt
import imp
import bob.learn.tensorflow
import tensorflow as tf
import os
def main():
args = docopt(__doc__, version='Train Neural Net')
USE_GPU = args['--use-gpu']
OUTPUT_DIR = str(args['--output-dir'])
PREFETCH = args['--prefetch']
ITERATIONS = int(args['--iterations'])
PRETRAINED_NET = ""
if not args['--pretrained-net'] is None:
PRETRAINED_NET = str(args['--pretrained-net'])
#PRETRAINED_NET = ""
#if not args['--pretrained-net'] is None:
# PRETRAINED_NET = str(args['--pretrained-net'])
config = imp.load_source('config', args['<configuration>'])
# Cleaning all variables in case you are loading the checkpoint
if os.path.exists(OUTPUT_DIR):
    tf.reset_default_graph()
# One graph trainer
trainer = config.Trainer(config.train_data_shuffler,
iterations=ITERATIONS,
analizer=None,
temp_dir=OUTPUT_DIR)
if os.path.exists(OUTPUT_DIR):
print("Directory already exists, trying to get the last checkpoint")
trainer.create_network_from_file(OUTPUT_DIR)
else:
# Preparing the architecture
......
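The Usage and Options block above makes one command cover both cases: if --output-dir does not yet exist, training starts from scratch; if it already exists, train.py resets the default graph and reloads the last checkpoint through trainer.create_network_from_file(). A minimal sketch of that flow, mirroring the test scripts further down and assuming a hypothetical output directory ./logs/example:

from subprocess import call

# First run: builds the network from scratch and writes checkpoints
# into ./logs/example (hypothetical directory)
call(["./bin/train.py", "--iterations", "5",
      "--output-dir", "./logs/example",
      "./data/train_scripts/softmax.py"])

# Second run with the same --output-dir: the directory exists, so the
# script resumes from the last checkpoint instead of starting over
call(["./bin/train.py", "--iterations", "5",
      "--output-dir", "./logs/example",
      "./data/train_scripts/softmax.py"])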
@@ -22,16 +22,16 @@ train_data_shuffler = SiameseMemory(train_data, train_labels,
normalizer=ScaleFactor())
### ARCHITECTURE ###
architecture = Chopra(seed=SEED, fc1_output=10, batch_norm=False)
architecture = Chopra(seed=SEED, n_classes=10)
### LOSS ###
loss = ContrastiveLoss(contrastive_margin=4.)
### SOLVER ###
optimizer = tf.train.GradientDescentOptimizer(0.001)
### LEARNING RATE ###
learning_rate = constant(base_learning_rate=0.001)
learning_rate = constant(base_learning_rate=0.01)
### SOLVER ###
optimizer = tf.train.GradientDescentOptimizer(learning_rate)
### Trainer ###
trainer = Trainer
@@ -21,16 +21,19 @@ train_data_shuffler = TripletMemory(train_data, train_labels,
batch_size=BATCH_SIZE)
### ARCHITECTURE ###
architecture = Chopra(seed=SEED, fc1_output=10, batch_norm=False)
architecture = Chopra(seed=SEED, n_classes=10)
### LOSS ###
loss = TripletLoss(margin=4.)
### SOLVER ###
optimizer = tf.train.GradientDescentOptimizer(0.001)
### LEARNING RATE ###
learning_rate = constant(base_learning_rate=0.001)
learning_rate = constant(base_learning_rate=0.01)
### SOLVER ###
optimizer = tf.train.GradientDescentOptimizer(learning_rate)
### Trainer ###
trainer = Trainer
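Both configuration hunks above follow the same contract: a configuration file is a plain Python module that train.py loads with imp.load_source, and only its module-level names are consumed (train_data_shuffler and Trainer in the first hunk); the architecture, loss, learning rate and optimizer are wired together inside the configuration itself. A rough sketch of that consumption path, with config_path standing in for the <configuration> argument:

import imp

# Load the user-supplied configuration as an ad-hoc module named 'config'
config = imp.load_source('config', config_path)  # config_path: placeholder

# train.py only touches module-level attributes of the configuration
trainer = config.Trainer(config.train_data_shuffler,
                         iterations=1000,     # from --iterations
                         analizer=None,
                         temp_dir='./logs/')  # from --output-dir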
@@ -10,22 +10,29 @@ import shutil
def test_train_script_softmax():
directory = "./temp/train-script"
train_script = pkg_resources.resource_filename(__name__, './data/train_scripts/softmax.py')
train_script = './data/train_scripts/softmax.py'
from subprocess import call
# Start the training
call(["./bin/train.py", "--iterations", "5", "--output-dir", directory, train_script])
#shutil.rmtree(directory)
# Continuing from the last checkpoint
call(["./bin/train.py", "--iterations", "5", "--output-dir", directory, train_script])
shutil.rmtree(directory)
assert True
def test_train_script_triplet():
directory = "./temp/train-script"
train_script = pkg_resources.resource_filename(__name__, './data/train_scripts/triplet.py')
#train_script = './data/train_scripts/triplet.py'
#from subprocess import call
#call(["./bin/train.py", "--iterations", "5", "--output-dir", directory, train_script])
#shutil.rmtree(directory)
from subprocess import call
# Start the training
call(["./bin/train.py", "--iterations", "5", "--output-dir", directory, train_script])
# Continuing from the last checkpoint
call(["./bin/train.py", "--iterations", "5", "--output-dir", directory, train_script])
shutil.rmtree(directory)
assert True
@@ -33,10 +40,14 @@ def test_train_script_triplet():
def test_train_script_siamese():
directory = "./temp/train-script"
train_script = pkg_resources.resource_filename(__name__, './data/train_scripts/siamese.py')
#train_script = './data/train_scripts/siamese.py'
#from subprocess import call
#call(["./bin/train.py", "--iterations", "5", "--output-dir", directory, train_script])
#shutil.rmtree(directory)
from subprocess import call
# Start the training
call(["./bin/train.py", "--iterations", "5", "--output-dir", directory, train_script])
# Continuing from the last checkpoint
call(["./bin/train.py", "--iterations", "5", "--output-dir", directory, train_script])
shutil.rmtree(directory)
assert True
@@ -179,9 +179,7 @@ class SiameseTrainer(Trainer):
def create_network_from_file(self, model_from_file, clear_devices=True):
#saver = self.architecture.load(self.model_from_file, clear_devices=False)
self.saver = tf.train.import_meta_graph(model_from_file + ".meta", clear_devices=clear_devices)
self.saver.restore(self.session, model_from_file)
self.load_checkpoint(model_from_file, clear_devices=clear_devices)
# Loading the graph from the graph pointers
self.graph = dict()
@@ -206,7 +204,6 @@ class SiameseTrainer(Trainer):
self.summaries_train = tf.get_collection("summaries_train")[0]
self.global_step = tf.get_collection("global_step")[0]
self.from_scratch = False
def get_feed_dict(self, data_shuffler):
......
@@ -122,7 +122,6 @@ class Trainer(object):
self.session = Session.instance(new=True).session
self.from_scratch = True
def train(self):
"""
Train the network
@@ -197,7 +196,6 @@ class Trainer(object):
#if not isinstance(self.train_data_shuffler, TFRecord):
# self.thread_pool.join(threads)
def create_network_from_scratch(self,
graph,
validation_graph=None,
@@ -222,9 +220,6 @@
learning_rate: Learning rate
"""
# Putting together the training data + graph + loss
# Getting the pointer to the placeholders
self.data_ph = self.train_data_shuffler("data", from_queue=True)
self.label_ph = self.train_data_shuffler("label", from_queue=True)
@@ -243,7 +238,6 @@
self.optimizer_class._learning_rate = self.learning_rate
self.optimizer = self.optimizer_class.minimize(self.predictor, global_step=self.global_step)
# Saving all the variables
self.saver = tf.train.Saver(var_list=tf.global_variables() + tf.local_variables(),
keep_checkpoint_every_n_hours=self.keep_checkpoint_every_n_hours)
@@ -264,7 +258,7 @@
tf.add_to_collection("summaries_train", self.summaries_train)
# Same business with the validation
if(self.validation_data_shuffler is not None):
if self.validation_data_shuffler is not None:
self.validation_data_ph = self.validation_data_shuffler("data", from_queue=True)
self.validation_label_ph = self.validation_data_shuffler("label", from_queue=True)
@@ -286,6 +280,24 @@
tf.local_variables_initializer().run(session=self.session)
tf.global_variables_initializer().run(session=self.session)
def load_checkpoint(self, file_name, clear_devices=True):
"""
Load a checkpoint
** Parameters **
file_name:
Name of the metafile to be loaded.
If a directory is passed, the last checkpoint will be loaded
"""
if os.path.isdir(file_name):
checkpoint_path = tf.train.get_checkpoint_state(file_name).model_checkpoint_path
self.saver = tf.train.import_meta_graph(checkpoint_path + ".meta", clear_devices=clear_devices)
self.saver.restore(self.session, tf.train.latest_checkpoint(file_name))
else:
self.saver = tf.train.import_meta_graph(file_name, clear_devices=clear_devices)
self.saver.restore(self.session, tf.train.latest_checkpoint(os.path.dirname(file_name)))
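load_checkpoint therefore accepts either a checkpoint directory or an explicit metafile. A small usage sketch of both call styles, assuming an already constructed trainer and a hypothetical checkpoint directory ./logs/ containing model.ckp.meta:

# Directory form: tf.train.get_checkpoint_state() locates the newest
# checkpoint; its ".meta" graph is imported and its variables restored
trainer.load_checkpoint("./logs/")

# Metafile form (hypothetical file name): the graph comes from this file,
# the variable values from the latest checkpoint in the same directory
trainer.load_checkpoint("./logs/model.ckp.meta")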
def create_network_from_file(self, file_name, clear_devices=True):
"""
@@ -295,9 +307,9 @@
file_name: Name of the checkpoint
"""
#self.saver = tf.train.import_meta_graph(file_name + ".meta", clear_devices=clear_devices)
self.saver = tf.train.import_meta_graph(file_name, clear_devices=clear_devices)
self.saver.restore(self.session, tf.train.latest_checkpoint(os.path.dirname(file_name)))
logger.info("Loading last checkpoint !!")
self.load_checkpoint(file_name, clear_devices=True)
# Loading training graph
self.data_ph = tf.get_collection("data_ph")[0]
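These pointers survive the checkpoint round-trip because the graph objects are registered in named TensorFlow collections when the network is first built (as with "summaries_train" a few hunks above) and looked up again once the meta graph has been re-imported. A compressed sketch of that pairing, assuming the placeholders are stored under the same collection names used in the lookups:

# At build time (create_network_from_scratch): register each pointer
tf.add_to_collection("data_ph", self.data_ph)
tf.add_to_collection("summaries_train", self.summaries_train)
tf.add_to_collection("global_step", self.global_step)

# After load_checkpoint() has re-imported the meta graph
# (create_network_from_file): recover them by name
self.data_ph = tf.get_collection("data_ph")[0]
self.summaries_train = tf.get_collection("summaries_train")[0]
self.global_step = tf.get_collection("global_step")[0]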
@@ -314,10 +326,9 @@
self.from_scratch = False
# Loading the validation bits
if(self.validation_data_shuffler is not None):
if self.validation_data_shuffler is not None:
self.summaries_validation = tf.get_collection("summaries_validation")[0]
self.validation_graph = tf.get_collection("validation_graph")[0]
self.validation_data_ph = tf.get_collection("validation_data_ph")[0]
self.validation_label = tf.get_collection("validation_label_ph")[0]
@@ -325,7 +336,6 @@
self.validation_predictor = tf.get_collection("validation_predictor")[0]
self.summaries_validation = tf.get_collection("summaries_validation")[0]
def __del__(self):
tf.reset_default_graph()
......
@@ -120,7 +120,6 @@ class TripletTrainer(Trainer):
self.session = Session.instance(new=True).session
self.from_scratch = True
def create_network_from_scratch(self,
graph,
optimizer=tf.train.AdamOptimizer(),
@@ -177,11 +176,9 @@
# Creating the variables
tf.global_variables_initializer().run(session=self.session)
def create_network_from_file(self, model_from_file, clear_devices=True):
def create_network_from_file(self, file_name, clear_devices=True):
#saver = self.architecture.load(self.model_from_file, clear_devices=False)
self.saver = tf.train.import_meta_graph(model_from_file + ".meta", clear_devices=clear_devices)
self.saver.restore(self.session, model_from_file)
self.load_checkpoint(file_name, clear_devices=clear_devices)
# Loading the graph from the graph pointers
self.graph = dict()
......