Skip to content
Snippets Groups Projects
Commit 66868b60 authored by Tiago de Freitas Pereira's avatar Tiago de Freitas Pereira
Browse files

IMplementing issue 19

parent 2ab6b68d
No related branches found
No related tags found
1 merge request!4Issue 19
Pipeline #
#!/usr/bin/env python
# vim: set fileencoding=utf-8 :
# @author: Tiago de Freitas Pereira <tiago.pereira@idiap.ch>
# @date: Wed 04 Jan 2017 18:00:36 CET
"""
Train a Neural network using bob.learn.tensorflow
Usage:
train.py [--iterations=<arg> --validation-interval=<arg> --output-dir=<arg> --pretrained-net=<arg> --use-gpu --prefetch ] <configuration>
train.py -h | --help
Options:
-h --help Show this screen.
--iterations=<arg> [default: 1000]
--validation-interval=<arg> [default: 100]
--output-dir=<arg> [default: ./logs/]
--pretrained-net=<arg>
"""
from docopt import docopt
import imp
def main():
args = docopt(__doc__, version='Train Neural Net')
#ITERATIONS = int(args['--iterations'])
#VALIDATION_TEST = int(args['--validation-interval'])
#USE_GPU = args['--use-gpu']
#OUTPUT_DIR = str(args['--output-dir'])
#PREFETCH = args['--prefetch']
#if args['--pretrained-net'] is None:
# PRETRAINED_NET = ""
#else:
# PRETRAINED_NET = str(args['--pretrained-net'])
config = imp.load_source('config', args['<configuration>'])
trainer = config.Trainer(architecture=config.architecture,
loss=config.loss,
iterations=int(args['--iterations']),
analizer=None,
prefetch=args['--prefetch'],
learning_rate=config.learning_rate,
temp_dir=args['--output-dir'],
model_from_file=config.model_from_file
)
import ipdb; ipdb.set_trace();
trainer.train(config.train_data_shuffler)
...@@ -16,7 +16,11 @@ from bob.learn.tensorflow.datashuffler import OnlineSampling ...@@ -16,7 +16,11 @@ from bob.learn.tensorflow.datashuffler import OnlineSampling
from bob.learn.tensorflow.utils.session import Session from bob.learn.tensorflow.utils.session import Session
from .learning_rate import constant from .learning_rate import constant
logger = bob.core.log.setup("bob.learn.tensorflow") #logger = bob.core.log.setup("bob.learn.tensorflow")
import logging
logger = logging.getLogger("bob.learn")
class Trainer(object): class Trainer(object):
......
...@@ -47,7 +47,8 @@ setup( ...@@ -47,7 +47,8 @@ setup(
# scripts should be declared using this entry: # scripts should be declared using this entry:
'console_scripts': [ 'console_scripts': [
'compute_statistics.py = bob.learn.tensorflow.script.compute_statistics:main' 'compute_statistics.py = bob.learn.tensorflow.script.compute_statistics:main',
'train.py = bob.learn.tensorflow.script.train:main'
], ],
}, },
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment