From 66868b60b51f9e57427620c03657c493a2da499a Mon Sep 17 00:00:00 2001 From: Tiago Freitas Pereira <tiagofrepereira@gmail.com> Date: Wed, 4 Jan 2017 19:17:20 +0100 Subject: [PATCH] IMplementing issue 19 --- bob/learn/tensorflow/script/train.py | 52 ++++++++++++++++++++++++ bob/learn/tensorflow/trainers/Trainer.py | 6 ++- setup.py | 3 +- 3 files changed, 59 insertions(+), 2 deletions(-) create mode 100644 bob/learn/tensorflow/script/train.py diff --git a/bob/learn/tensorflow/script/train.py b/bob/learn/tensorflow/script/train.py new file mode 100644 index 00000000..949b1a66 --- /dev/null +++ b/bob/learn/tensorflow/script/train.py @@ -0,0 +1,52 @@ +#!/usr/bin/env python +# vim: set fileencoding=utf-8 : +# @author: Tiago de Freitas Pereira <tiago.pereira@idiap.ch> +# @date: Wed 04 Jan 2017 18:00:36 CET + +""" +Train a Neural network using bob.learn.tensorflow + +Usage: + train.py [--iterations=<arg> --validation-interval=<arg> --output-dir=<arg> --pretrained-net=<arg> --use-gpu --prefetch ] <configuration> + train.py -h | --help +Options: + -h --help Show this screen. + --iterations=<arg> [default: 1000] + --validation-interval=<arg> [default: 100] + --output-dir=<arg> [default: ./logs/] + --pretrained-net=<arg> +""" + + +from docopt import docopt +import imp + + +def main(): + args = docopt(__doc__, version='Train Neural Net') + + #ITERATIONS = int(args['--iterations']) + #VALIDATION_TEST = int(args['--validation-interval']) + #USE_GPU = args['--use-gpu'] + #OUTPUT_DIR = str(args['--output-dir']) + #PREFETCH = args['--prefetch'] + #if args['--pretrained-net'] is None: + # PRETRAINED_NET = "" + #else: + # PRETRAINED_NET = str(args['--pretrained-net']) + + config = imp.load_source('config', args['<configuration>']) + + trainer = config.Trainer(architecture=config.architecture, + loss=config.loss, + iterations=int(args['--iterations']), + analizer=None, + prefetch=args['--prefetch'], + learning_rate=config.learning_rate, + temp_dir=args['--output-dir'], + model_from_file=config.model_from_file + ) + + import ipdb; ipdb.set_trace(); + trainer.train(config.train_data_shuffler) + diff --git a/bob/learn/tensorflow/trainers/Trainer.py b/bob/learn/tensorflow/trainers/Trainer.py index f31c9535..3e3a6002 100644 --- a/bob/learn/tensorflow/trainers/Trainer.py +++ b/bob/learn/tensorflow/trainers/Trainer.py @@ -16,7 +16,11 @@ from bob.learn.tensorflow.datashuffler import OnlineSampling from bob.learn.tensorflow.utils.session import Session from .learning_rate import constant -logger = bob.core.log.setup("bob.learn.tensorflow") +#logger = bob.core.log.setup("bob.learn.tensorflow") + +import logging +logger = logging.getLogger("bob.learn") + class Trainer(object): diff --git a/setup.py b/setup.py index 3d6fcd45..c626f6e8 100644 --- a/setup.py +++ b/setup.py @@ -47,7 +47,8 @@ setup( # scripts should be declared using this entry: 'console_scripts': [ - 'compute_statistics.py = bob.learn.tensorflow.script.compute_statistics:main' + 'compute_statistics.py = bob.learn.tensorflow.script.compute_statistics:main', + 'train.py = bob.learn.tensorflow.script.train:main' ], }, -- GitLab