Commit 9f6a3721 authored by Tiago de Freitas Pereira

Merge branch 'updates' into 'master'

gridtk integration

See merge request !14
parents 6a4603ce 22df299b
Pipeline #12612 passed with stages in 19 minutes and 38 seconds
@@ -7,58 +7,103 @@
Train a neural network using bob.learn.tensorflow
Usage:
train.py [--iterations=<arg> --validation-interval=<arg> --output-dir=<arg> --pretrained-net=<arg> --use-gpu --prefetch ] <configuration>
train.py [--iterations=<arg> --validation-interval=<arg> --output-dir=<arg> ] <configuration> [grid --n-jobs=<arg> --job-name=<job-name> --queue=<arg>]
train.py -h | --help
Options:
-h --help Show this screen.
--iterations=<arg> [default: 1000]
--validation-interval=<arg> [default: 100]
--output-dir=<arg> [default: ./logs/]
--pretrained-net=<arg>
-h --help Show this screen.
--iterations=<arg>                  Number of iterations [default: 1000]
--validation-interval=<arg>         Validate every n iterations [default: 500]
--output-dir=<arg>                  If the directory exists, the last checkpoint will be loaded [default: ./logs/]
--n-jobs=<arg> Number of jobs submitted to the grid [default: 3]
--job-name=<arg> Job name [default: TF]
--queue=<arg> SGE queue name [default: q_gpu]
"""
from docopt import docopt
import imp
import bob.learn.tensorflow
import tensorflow as tf
import os
import sys
import logging
logger = logging.getLogger("bob.learn")
def dump_commandline():
command_line = []
for command in sys.argv:
if command == "grid":
break
command_line.append(command)
return command_line
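# A sketch of dump_commandline()'s behavior: it copies sys.argv up to, but
# not including, the "grid" sub-command, so each submitted job re-runs the
# training itself rather than re-submitting to the grid. With a hypothetical
# command line:
#
#   sys.argv = ["train.py", "--iterations=5000", "config.py", "grid", "--n-jobs=5"]
#   dump_commandline()  # -> ["train.py", "--iterations=5000", "config.py"]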
def main():
args = docopt(__doc__, version='Train Neural Net')
output_dir = str(args['--output-dir'])
iterations = int(args['--iterations'])
USE_GPU = args['--use-gpu']
OUTPUT_DIR = str(args['--output-dir'])
PREFETCH = args['--prefetch']
ITERATIONS = int(args['--iterations'])
grid = int(args['grid'])
if grid:
# Submitting jobs to SGE
jobs = int(args['--n-jobs'])
job_name = args['--job-name']
queue = args['--queue']
import gridtk
PRETRAINED_NET = ""
if not args['--pretrained-net'] is None:
PRETRAINED_NET = str(args['--pretrained-net'])
job_manager = gridtk.sge.JobManagerSGE()
command = dump_commandline()
dependencies = []
total_jobs = []
kwargs = {"env": ["LD_LIBRARY_PATH=/idiap/user/tpereira/cuda/cuda-8.0/lib64:/idiap/user/tpereira/cuda/cudnn-8.0-linux-x64-v5.1/lib64:/idiap/user/tpereira/cuda/cuda-8.0/bin"]}
for i in range(jobs):
job_id = job_manager.submit(command, queue=queue, dependencies=dependencies,
name=job_name + "{0}".format(i), **kwargs)
dependencies = [job_id]
total_jobs.append(job_id)
logger.info("Submitted the jobs {0}".format(total_jobs))
return True
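# Note on the submission loop above: each job is submitted with the previous
# job id as its dependency (dependencies = [job_id]), so the --n-jobs jobs
# run sequentially and every continuation resumes from the last checkpoint
# written to --output-dir.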
config = imp.load_source('config', args['<configuration>'])
# Cleaning all variables in case you are loading the checkpoint
if os.path.exists(output_dir):
    tf.reset_default_graph()
# One graph trainer
trainer = config.Trainer(config.train_data_shuffler,
iterations=ITERATIONS,
iterations=iterations,
analizer=None,
temp_dir=OUTPUT_DIR)
# Preparing the architecture
input_pl = config.train_data_shuffler("data", from_queue=False)
if isinstance(trainer, bob.learn.tensorflow.trainers.SiameseTrainer):
graph = dict()
graph['left'] = config.architecture(input_pl['left'])
graph['right'] = config.architecture(input_pl['right'], reuse=True)
elif isinstance(trainer, bob.learn.tensorflow.trainers.TripletTrainer):
graph = dict()
graph['anchor'] = config.architecture(input_pl['anchor'])
graph['positive'] = config.architecture(input_pl['positive'], reuse=True)
graph['negative'] = config.architecture(input_pl['negative'], reuse=True)
temp_dir=output_dir)
if os.path.exists(output_dir):
logger.info("Directory already exists, trying to get the last checkpoint")
trainer.create_network_from_file(output_dir)
else:
graph = config.architecture(input_pl)
trainer.create_network_from_scratch(graph, loss=config.loss,
learning_rate=config.learning_rate,
optimizer=config.optimizer)
trainer.train(config.train_data_shuffler)
# Preparing the architecture
input_pl = config.train_data_shuffler("data", from_queue=False)
if isinstance(trainer, bob.learn.tensorflow.trainers.SiameseTrainer):
graph = dict()
graph['left'] = config.architecture(input_pl['left'])
graph['right'] = config.architecture(input_pl['right'], reuse=True)
elif isinstance(trainer, bob.learn.tensorflow.trainers.TripletTrainer):
graph = dict()
graph['anchor'] = config.architecture(input_pl['anchor'])
graph['positive'] = config.architecture(input_pl['positive'], reuse=True)
graph['negative'] = config.architecture(input_pl['negative'], reuse=True)
else:
graph = config.architecture(input_pl)
trainer.create_network_from_scratch(graph, loss=config.loss,
learning_rate=config.learning_rate,
optimizer=config.optimizer)
trainer.train()
@@ -22,16 +22,16 @@ train_data_shuffler = SiameseMemory(train_data, train_labels,
normalizer=ScaleFactor())
### ARCHITECTURE ###
architecture = Chopra(seed=SEED, fc1_output=10, batch_norm=False)
architecture = Chopra(seed=SEED, n_classes=10)
### LOSS ###
loss = ContrastiveLoss(contrastive_margin=4.)
### SOLVER ###
optimizer = tf.train.GradientDescentOptimizer(0.001)
### LEARNING RATE ###
learning_rate = constant(base_learning_rate=0.001)
learning_rate = constant(base_learning_rate=0.01)
### SOLVER ###
optimizer = tf.train.GradientDescentOptimizer(learning_rate)
### Trainer ###
trainer = Trainer
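# For reference, a minimal sketch of how train.py consumes a configuration
# module like this one ('siamese.py' is a hypothetical file name):
#
#   import imp
#   config = imp.load_source('config', 'siamese.py')
#   # train.py then reads config.Trainer, config.train_data_shuffler,
#   # config.architecture, config.loss, config.learning_rate and
#   # config.optimizer, as in the script above.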
from bob.learn.tensorflow.datashuffler import Memory
from bob.learn.tensorflow.datashuffler import Memory, ScaleFactor
from bob.learn.tensorflow.network import Chopra
from bob.learn.tensorflow.trainers import Trainer, constant
from bob.learn.tensorflow.loss import BaseLoss
from bob.learn.tensorflow.loss import MeanSoftMaxLoss
from bob.learn.tensorflow.utils import load_mnist
import tensorflow as tf
import numpy
@@ -18,19 +18,23 @@ train_data = numpy.reshape(train_data, (train_data.shape[0], 28, 28, 1))
train_data_shuffler = Memory(train_data, train_labels,
input_shape=INPUT_SHAPE,
batch_size=BATCH_SIZE)
batch_size=BATCH_SIZE,
normalizer=ScaleFactor())
### ARCHITECTURE ###
architecture = Chopra(seed=SEED, fc1_output=10, batch_norm=False)
architecture = Chopra(seed=SEED, n_classes=10)
### LOSS ###
loss = BaseLoss(tf.nn.sparse_softmax_cross_entropy_with_logits, tf.reduce_mean)
### SOLVER ###
optimizer = tf.train.GradientDescentOptimizer(0.001)
loss = MeanSoftMaxLoss()
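# MeanSoftMaxLoss() replaces the explicit
# BaseLoss(tf.nn.sparse_softmax_cross_entropy_with_logits, tf.reduce_mean)
# construction above, i.e. the mean softmax cross-entropy over the batch.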
### LEARNING RATE ###
learning_rate = constant(base_learning_rate=0.001)
learning_rate = constant(base_learning_rate=0.01)
### SOLVER ###
optimizer = tf.train.GradientDescentOptimizer(learning_rate)
### Trainer ###
trainer = Trainer
@@ -21,16 +21,19 @@ train_data_shuffler = TripletMemory(train_data, train_labels,
batch_size=BATCH_SIZE)
### ARCHITECTURE ###
architecture = Chopra(seed=SEED, fc1_output=10, batch_norm=False)
architecture = Chopra(seed=SEED, n_classes=10)
### LOSS ###
loss = TripletLoss(margin=4.)
### SOLVER ###
optimizer = tf.train.GradientDescentOptimizer(0.001)
### LEARNING RATE ###
learning_rate = constant(base_learning_rate=0.001)
learning_rate = constant(base_learning_rate=0.01)
### SOLVER ###
optimizer = tf.train.GradientDescentOptimizer(learning_rate)
### Trainer ###
trainer = Trainer
@@ -173,7 +173,7 @@ def test_triplet_cnn_pretrained():
analizer=None,
temp_dir=directory)
trainer.create_network_from_file(os.path.join(directory, "model.ckp"))
trainer.create_network_from_file(os.path.join(directory, "model.ckp.meta"))
trainer.train()
embedding = Embedding(trainer.data_ph['anchor'], trainer.graph['anchor'])
@@ -189,6 +189,7 @@ def test_triplet_cnn_pretrained():
def test_siamese_cnn_pretrained():
tf.reset_default_graph()
train_data, train_labels, validation_data, validation_labels = load_mnist()
@@ -244,7 +245,7 @@ def test_siamese_cnn_pretrained():
analizer=None,
temp_dir=directory)
trainer.create_network_from_file(os.path.join(directory, "model.ckp"))
trainer.create_network_from_file(os.path.join(directory, "model.ckp.meta"))
trainer.train()
#embedding = Embedding(train_data_shuffler("data", from_queue=False)['left'], trainer.graph['left'])
@@ -5,38 +5,60 @@
import pkg_resources
import shutil
import tensorflow as tf
def test_train_script_softmax():
tf.reset_default_graph()
directory = "./temp/train-script"
train_script = pkg_resources.resource_filename(__name__, './data/train_scripts/softmax.py')
#train_script = './data/train_scripts/softmax.py'
#from subprocess import call
#call(["./bin/train.py", "--iterations", "5", "--output-dir", directory, train_script])
#shutil.rmtree(directory)
assert True
from subprocess import call
# Start the training
call(["./bin/train.py", "--iterations", "5", "--output-dir", directory, train_script])
# Continuing from the last checkpoint
call(["./bin/train.py", "--iterations", "5", "--output-dir", directory, train_script])
shutil.rmtree(directory)
tf.reset_default_graph()
assert len(tf.global_variables()) == 0
def test_train_script_triplet():
tf.reset_default_graph()
directory = "./temp/train-script"
train_script = pkg_resources.resource_filename(__name__, './data/train_scripts/triplet.py')
#train_script = './data/train_scripts/triplet.py'
#from subprocess import call
#call(["./bin/train.py", "--iterations", "5", "--output-dir", directory, train_script])
#shutil.rmtree(directory)
from subprocess import call
# Start the training
call(["./bin/train.py", "--iterations", "5", "--output-dir", directory, train_script])
# Continuing from the last checkpoint
call(["./bin/train.py", "--iterations", "5", "--output-dir", directory, train_script])
assert True
shutil.rmtree(directory)
tf.reset_default_graph()
assert len(tf.global_variables()) == 0
def test_train_script_siamese():
tf.reset_default_graph()
directory = "./temp/train-script"
train_script = pkg_resources.resource_filename(__name__, './data/train_scripts/siamese.py')
#train_script = './data/train_scripts/siamese.py'
#from subprocess import call
#call(["./bin/train.py", "--iterations", "5", "--output-dir", directory, train_script])
#shutil.rmtree(directory)
from subprocess import call
# Start the training
call(["./bin/train.py", "--iterations", "5", "--output-dir", directory, train_script])
# Continuing from the last checkpoint
call(["./bin/train.py", "--iterations", "5", "--output-dir", directory, train_script])
shutil.rmtree(directory)
assert True
tf.reset_default_graph()
assert len(tf.global_variables()) == 0
@@ -179,9 +179,7 @@ class SiameseTrainer(Trainer):
def create_network_from_file(self, model_from_file, clear_devices=True):
#saver = self.architecture.load(self.model_from_file, clear_devices=False)
self.saver = tf.train.import_meta_graph(model_from_file + ".meta", clear_devices=clear_devices)
self.saver.restore(self.session, model_from_file)
self.load_checkpoint(model_from_file, clear_devices=clear_devices)
# Loading the graph from the graph pointers
self.graph = dict()
@@ -206,7 +204,6 @@ class SiameseTrainer(Trainer):
self.summaries_train = tf.get_collection("summaries_train")[0]
self.global_step = tf.get_collection("global_step")[0]
self.from_scratch = False
def get_feed_dict(self, data_shuffler):
@@ -122,7 +122,6 @@ class Trainer(object):
self.session = Session.instance(new=True).session
self.from_scratch = True
def train(self):
"""
Train the network
@@ -197,7 +196,6 @@ class Trainer(object):
#if not isinstance(self.train_data_shuffler, TFRecord):
# self.thread_pool.join(threads)
def create_network_from_scratch(self,
graph,
validation_graph=None,
@@ -222,9 +220,6 @@ class Trainer(object):
learning_rate: Learning rate
"""
# Putting together the training data + graph + loss
# Getting the pointer to the placeholders
self.data_ph = self.train_data_shuffler("data", from_queue=True)
self.label_ph = self.train_data_shuffler("label", from_queue=True)
@@ -243,7 +238,6 @@ class Trainer(object):
self.optimizer_class._learning_rate = self.learning_rate
self.optimizer = self.optimizer_class.minimize(self.predictor, global_step=self.global_step)
# Saving all the variables
self.saver = tf.train.Saver(var_list=tf.global_variables() + tf.local_variables(),
keep_checkpoint_every_n_hours=self.keep_checkpoint_every_n_hours)
@@ -264,7 +258,7 @@ class Trainer(object):
tf.add_to_collection("summaries_train", self.summaries_train)
# Same business with the validation
if(self.validation_data_shuffler is not None):
if self.validation_data_shuffler is not None:
self.validation_data_ph = self.validation_data_shuffler("data", from_queue=True)
self.validation_label_ph = self.validation_data_shuffler("label", from_queue=True)
@@ -286,6 +280,24 @@ class Trainer(object):
tf.local_variables_initializer().run(session=self.session)
tf.global_variables_initializer().run(session=self.session)
def load_checkpoint(self, file_name, clear_devices=True):
"""
Load a checkpoint
** Parameters **
file_name:
Name of the metafile to be loaded.
If a directory is passed, the last checkpoint will be loaded
"""
if os.path.isdir(file_name):
checkpoint_path = tf.train.get_checkpoint_state(file_name).model_checkpoint_path
self.saver = tf.train.import_meta_graph(checkpoint_path + ".meta", clear_devices=clear_devices)
self.saver.restore(self.session, tf.train.latest_checkpoint(file_name))
else:
self.saver = tf.train.import_meta_graph(file_name, clear_devices=clear_devices)
self.saver.restore(self.session, tf.train.latest_checkpoint(os.path.dirname(file_name)))
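# Usage sketch (paths are hypothetical); both forms restore the same state:
#
#   trainer.load_checkpoint("./logs/")                # directory: loads the last checkpoint
#   trainer.load_checkpoint("./logs/model.ckp.meta")  # explicit meta-graph file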
def create_network_from_file(self, file_name, clear_devices=True):
"""
@@ -295,9 +307,9 @@ class Trainer(object):
file_name: Name of the checkpoint
"""
#self.saver = tf.train.import_meta_graph(file_name + ".meta", clear_devices=clear_devices)
self.saver = tf.train.import_meta_graph(file_name, clear_devices=clear_devices)
self.saver.restore(self.session, tf.train.latest_checkpoint(os.path.dirname(file_name)))
logger.info("Loading last checkpoint !!")
self.load_checkpoint(file_name, clear_devices=True)
# Loading training graph
self.data_ph = tf.get_collection("data_ph")[0]
@@ -314,10 +326,9 @@ class Trainer(object):
self.from_scratch = False
# Loading the validation bits
if(self.validation_data_shuffler is not None):
if self.validation_data_shuffler is not None:
self.summaries_validation = tf.get_collection("summaries_validation")[0]
self.validation_graph = tf.get_collection("validation_graph")[0]
self.validation_data_ph = tf.get_collection("validation_data_ph")[0]
self.validation_label = tf.get_collection("validation_label_ph")[0]
@@ -325,7 +336,6 @@ class Trainer(object):
self.validation_predictor = tf.get_collection("validation_predictor")[0]
self.summaries_validation = tf.get_collection("summaries_validation")[0]
def __del__(self):
tf.reset_default_graph()
@@ -120,7 +120,6 @@ class TripletTrainer(Trainer):
self.session = Session.instance(new=True).session
self.from_scratch = True
def create_network_from_scratch(self,
graph,
optimizer=tf.train.AdamOptimizer(),
@@ -177,11 +176,9 @@ class TripletTrainer(Trainer):
# Creating the variables
tf.global_variables_initializer().run(session=self.session)
def create_network_from_file(self, model_from_file, clear_devices=True):
def create_network_from_file(self, file_name, clear_devices=True):
#saver = self.architecture.load(self.model_from_file, clear_devices=False)
self.saver = tf.train.import_meta_graph(model_from_file + ".meta", clear_devices=clear_devices)
self.saver.restore(self.session, model_from_file)
self.load_checkpoint(file_name, clear_devices=clear_devices)
# Loading the graph from the graph pointers
self.graph = dict()