From 66868b60b51f9e57427620c03657c493a2da499a Mon Sep 17 00:00:00 2001
From: Tiago Freitas Pereira <tiagofrepereira@gmail.com>
Date: Wed, 4 Jan 2017 19:17:20 +0100
Subject: [PATCH] IMplementing issue 19

---
 bob/learn/tensorflow/script/train.py     | 52 ++++++++++++++++++++++++
 bob/learn/tensorflow/trainers/Trainer.py |  6 ++-
 setup.py                                 |  3 +-
 3 files changed, 59 insertions(+), 2 deletions(-)
 create mode 100644 bob/learn/tensorflow/script/train.py

diff --git a/bob/learn/tensorflow/script/train.py b/bob/learn/tensorflow/script/train.py
new file mode 100644
index 00000000..949b1a66
--- /dev/null
+++ b/bob/learn/tensorflow/script/train.py
@@ -0,0 +1,52 @@
+#!/usr/bin/env python
+# vim: set fileencoding=utf-8 :
+# @author: Tiago de Freitas Pereira <tiago.pereira@idiap.ch>
+# @date: Wed 04 Jan 2017 18:00:36 CET
+
+"""
+Train a Neural network using bob.learn.tensorflow
+
+Usage:
+  train.py [--iterations=<arg> --validation-interval=<arg> --output-dir=<arg> --pretrained-net=<arg> --use-gpu --prefetch ] <configuration>
+  train.py -h | --help
+Options:
+  -h --help     Show this screen.
+  --iterations=<arg>   [default: 1000]
+  --validation-interval=<arg>   [default: 100]
+  --output-dir=<arg>    [default: ./logs/]
+  --pretrained-net=<arg>
+"""
+
+
+from docopt import docopt
+import imp
+
+
+def main():
+    args = docopt(__doc__, version='Train Neural Net')
+
+    #ITERATIONS = int(args['--iterations'])
+    #VALIDATION_TEST = int(args['--validation-interval'])
+    #USE_GPU = args['--use-gpu']
+    #OUTPUT_DIR = str(args['--output-dir'])
+    #PREFETCH = args['--prefetch']
+    #if args['--pretrained-net'] is None:
+    #    PRETRAINED_NET = ""
+    #else:
+    #    PRETRAINED_NET = str(args['--pretrained-net'])
+
+    config = imp.load_source('config', args['<configuration>'])
+
+    trainer = config.Trainer(architecture=config.architecture,
+                             loss=config.loss,
+                             iterations=int(args['--iterations']),
+                             analizer=None,
+                             prefetch=args['--prefetch'],
+                             learning_rate=config.learning_rate,
+                             temp_dir=args['--output-dir'],
+                             model_from_file=config.model_from_file
+                             )
+
+    import ipdb; ipdb.set_trace();
+    trainer.train(config.train_data_shuffler)
+
diff --git a/bob/learn/tensorflow/trainers/Trainer.py b/bob/learn/tensorflow/trainers/Trainer.py
index f31c9535..3e3a6002 100644
--- a/bob/learn/tensorflow/trainers/Trainer.py
+++ b/bob/learn/tensorflow/trainers/Trainer.py
@@ -16,7 +16,11 @@ from bob.learn.tensorflow.datashuffler import OnlineSampling
 from bob.learn.tensorflow.utils.session import Session
 from .learning_rate import constant
 
-logger = bob.core.log.setup("bob.learn.tensorflow")
+#logger = bob.core.log.setup("bob.learn.tensorflow")
+
+import logging
+logger = logging.getLogger("bob.learn")
+
 
 
 class Trainer(object):
diff --git a/setup.py b/setup.py
index 3d6fcd45..c626f6e8 100644
--- a/setup.py
+++ b/setup.py
@@ -47,7 +47,8 @@ setup(
 
         # scripts should be declared using this entry:
         'console_scripts': [
-            'compute_statistics.py = bob.learn.tensorflow.script.compute_statistics:main'
+            'compute_statistics.py = bob.learn.tensorflow.script.compute_statistics:main',
+            'train.py = bob.learn.tensorflow.script.train:main'
         ],
 
     },
-- 
GitLab