Skip to content
Snippets Groups Projects
Commit e6864f8a authored by Tiago de Freitas Pereira's avatar Tiago de Freitas Pereira
Browse files

gridtk integration

parent 28e38b2b
No related branches found
No related tags found
1 merge request: !14 gridtk integration
Pipeline #
......@@ -7,7 +7,7 @@
Train a Neural network using bob.learn.tensorflow
Usage:
train.py [--iterations=<arg> --validation-interval=<arg> --output-dir=<arg> ] <configuration> [grid <jobs> <job-name> <queue>]
train.py [--iterations=<arg> --validation-interval=<arg> --output-dir=<arg> ] <configuration> [grid --n-jobs=<arg> --job-name=<job-name> --queue=<arg>]
train.py -h | --help
Options:
......@@ -15,6 +15,9 @@ Options:
--iterations=<arg> Number of iterations [default: 1000]
--validation-interval=<arg> Validate every n iterations [default: 500]
--output-dir=<arg> If the directory exists, will try to get the last checkpoint [default: ./logs/]
--n-jobs=<arg> Number of jobs submitted to the grid [default: 3]
--job-name=<arg> Job name [default: TF]
--queue=<arg> SGE queue name [default: q_gpu]
"""
from docopt import docopt
......@@ -24,6 +27,9 @@ import tensorflow as tf
import os
import sys
import logging
logger = logging.getLogger("bob.learn")
def dump_commandline():
......@@ -37,28 +43,33 @@ def dump_commandline():
def main():
args = docopt(__doc__, version='Train Neural Net')
output_dir = str(args['--output-dir'])
iterations = int(args['--iterations'])
grid = int(args['grid'])
if grid:
jobs = int(args['<jobs>'])
job_name = args['<job-name>']
queue = args['<queue>']
# Submitting jobs to SGE
jobs = int(args['--n-jobs'])
job_name = args['--job-name']
queue = args['--queue']
import gridtk
job_manager = gridtk.sge.JobManagerSGE()
command = dump_commandline()
dependencies = []
total_jobs = []
kwargs = {"env": ["LD_LIBRARY_PATH=/idiap/user/tpereira/cuda/cuda-8.0/lib64:/idiap/user/tpereira/cuda/cudnn-8.0-linux-x64-v5.1/lib64:/idiap/user/tpereira/cuda/cuda-8.0/bin"]}
for i in range(jobs):
job_id = job_manager.submit(command, queue=queue, dependencies=dependencies,
name=job_name)
name=job_name + "{0}".format(i), **kwargs)
dependencies = [job_id]
total_jobs.append(job_id)
print("Submitted the jobs {0}".format(total_jobs))
logger.info("Submitted the jobs {0}".format(total_jobs))
return True
config = imp.load_source('config', args['<configuration>'])
......@@ -72,7 +83,7 @@ def main():
analizer=None,
temp_dir=output_dir)
if os.path.exists(output_dir):
print("Directory already exists, trying to get the last checkpoint")
logger.info("Directory already exists, trying to get the last checkpoint")
trainer.create_network_from_file(output_dir)
else:
......
......@@ -9,6 +9,8 @@ import tensorflow as tf
def test_train_script_softmax():
tf.reset_default_graph()
directory = "./temp/train-script"
train_script = pkg_resources.resource_filename(__name__, './data/train_scripts/softmax.py')
......@@ -25,6 +27,8 @@ def test_train_script_softmax():
def test_train_script_triplet():
tf.reset_default_graph()
directory = "./temp/train-script"
train_script = pkg_resources.resource_filename(__name__, './data/train_scripts/triplet.py')
......@@ -42,6 +46,8 @@ def test_train_script_triplet():
def test_train_script_siamese():
tf.reset_default_graph()
directory = "./temp/train-script"
train_script = pkg_resources.resource_filename(__name__, './data/train_scripts/siamese.py')
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment