Skip to content
Snippets Groups Projects
Commit 85089584 authored by Tiago de Freitas Pereira's avatar Tiago de Freitas Pereira
Browse files

Integrating with gridtk

parent 2e5982a7
Branches
Tags
1 merge request!14gridtk integration
Pipeline #
......@@ -7,13 +7,14 @@
Train a Neural network using bob.learn.tensorflow
Usage:
train.py [--iterations=<arg> --validation-interval=<arg> --output-dir=<arg> ] <configuration>
train.py [--iterations=<arg> --validation-interval=<arg> --output-dir=<arg> ] <configuration> [grid <jobs>]
train.py -h | --help
Options:
-h --help Show this screen.
--iterations=<arg> [default: 1000]
--validation-interval=<arg> [default: 100]
--output-dir=<arg> If the directory exists, will try to get the last checkpoint [default: ./logs/]
-h --help Show this screen.
--iterations=<arg> Number of iteratiosn [default: 1000]
--validation-interval=<arg> Validata every n iteratiosn [default: 500]
--output-dir=<arg> If the directory exists, will try to get the last checkpoint [default: ./logs/]
"""
from docopt import docopt
......@@ -21,31 +22,51 @@ import imp
import bob.learn.tensorflow
import tensorflow as tf
import os
import sys
def dump_commandline():
command_line = ""
for command in sys.argv:
if command == "grid":
break
command_line += "{0} ".format(command)
return command_line
def main():
args = docopt(__doc__, version='Train Neural Net')
OUTPUT_DIR = str(args['--output-dir'])
ITERATIONS = int(args['--iterations'])
output_dir = str(args['--output-dir'])
iterations = int(args['--iterations'])
grid = int(args['grid'])
if grid:
jobs = int(args['<jobs>'])
import gridtk
job_manager = gridtk.sge.JobManagerSGE()
command = dump_commandline()
dependencies = []
#PRETRAINED_NET = ""
#if not args['--pretrained-net'] is None:
# PRETRAINED_NET = str(args['--pretrained-net'])
for i in range(jobs):
job_id = job_manager.submit(command, dependencies=dependencies)
dependencies = [job_id]
config = imp.load_source('config', args['<configuration>'])
# Cleaning all variables in case you are loading the checkpoint
tf.reset_default_graph() if os.path.exists(OUTPUT_DIR) else None
tf.reset_default_graph() if os.path.exists(output_dir) else None
# One graph trainer
trainer = config.Trainer(config.train_data_shuffler,
iterations=ITERATIONS,
iterations=iterations,
analizer=None,
temp_dir=OUTPUT_DIR)
if os.path.exists(OUTPUT_DIR):
temp_dir=output_dir)
if os.path.exists(output_dir):
print("Directory already exists, trying to get the last checkpoint")
trainer.create_network_from_file(OUTPUT_DIR)
trainer.create_network_from_file(output_dir)
else:
# Preparing the architecture
......
......@@ -5,6 +5,7 @@
import pkg_resources
import shutil
import tensorflow as tf
def test_train_script_softmax():
......@@ -18,7 +19,9 @@ def test_train_script_softmax():
# Continuing from the last checkpoint
call(["./bin/train.py", "--iterations", "5", "--output-dir", directory, train_script])
shutil.rmtree(directory)
assert True
tf.reset_default_graph()
assert len(tf.global_variables()) == 0
def test_train_script_triplet():
......@@ -34,7 +37,8 @@ def test_train_script_triplet():
shutil.rmtree(directory)
assert True
tf.reset_default_graph()
assert len(tf.global_variables()) == 0
def test_train_script_siamese():
......@@ -50,4 +54,5 @@ def test_train_script_siamese():
shutil.rmtree(directory)
assert True
tf.reset_default_graph()
assert len(tf.global_variables()) == 0
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment