Skip to content
GitLab
Menu
Projects
Groups
Snippets
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
bob
bob.learn.tensorflow
Commits
66868b60
Commit
66868b60
authored
Jan 04, 2017
by
Tiago de Freitas Pereira
Browse files
IMplementing issue 19
parent
2ab6b68d
Pipeline
#6205
failed with stages
in 26 minutes and 54 seconds
Changes
3
Pipelines
1
Hide whitespace changes
Inline
Side-by-side
bob/learn/tensorflow/script/train.py
0 → 100644
View file @
66868b60
#!/usr/bin/env python
# vim: set fileencoding=utf-8 :
# @author: Tiago de Freitas Pereira <tiago.pereira@idiap.ch>
# @date: Wed 04 Jan 2017 18:00:36 CET
"""
Train a Neural network using bob.learn.tensorflow
Usage:
train.py [--iterations=<arg> --validation-interval=<arg> --output-dir=<arg> --pretrained-net=<arg> --use-gpu --prefetch ] <configuration>
train.py -h | --help
Options:
-h --help Show this screen.
--iterations=<arg> [default: 1000]
--validation-interval=<arg> [default: 100]
--output-dir=<arg> [default: ./logs/]
--pretrained-net=<arg>
"""
from
docopt
import
docopt
import
imp
def
main
():
args
=
docopt
(
__doc__
,
version
=
'Train Neural Net'
)
#ITERATIONS = int(args['--iterations'])
#VALIDATION_TEST = int(args['--validation-interval'])
#USE_GPU = args['--use-gpu']
#OUTPUT_DIR = str(args['--output-dir'])
#PREFETCH = args['--prefetch']
#if args['--pretrained-net'] is None:
# PRETRAINED_NET = ""
#else:
# PRETRAINED_NET = str(args['--pretrained-net'])
config
=
imp
.
load_source
(
'config'
,
args
[
'<configuration>'
])
trainer
=
config
.
Trainer
(
architecture
=
config
.
architecture
,
loss
=
config
.
loss
,
iterations
=
int
(
args
[
'--iterations'
]),
analizer
=
None
,
prefetch
=
args
[
'--prefetch'
],
learning_rate
=
config
.
learning_rate
,
temp_dir
=
args
[
'--output-dir'
],
model_from_file
=
config
.
model_from_file
)
import
ipdb
;
ipdb
.
set_trace
();
trainer
.
train
(
config
.
train_data_shuffler
)
bob/learn/tensorflow/trainers/Trainer.py
View file @
66868b60
...
...
@@ -16,7 +16,11 @@ from bob.learn.tensorflow.datashuffler import OnlineSampling
from
bob.learn.tensorflow.utils.session
import
Session
from
.learning_rate
import
constant
logger
=
bob
.
core
.
log
.
setup
(
"bob.learn.tensorflow"
)
#logger = bob.core.log.setup("bob.learn.tensorflow")
import
logging
logger
=
logging
.
getLogger
(
"bob.learn"
)
class
Trainer
(
object
):
...
...
setup.py
View file @
66868b60
...
...
@@ -47,7 +47,8 @@ setup(
# scripts should be declared using this entry:
'console_scripts'
:
[
'compute_statistics.py = bob.learn.tensorflow.script.compute_statistics:main'
'compute_statistics.py = bob.learn.tensorflow.script.compute_statistics:main'
,
'train.py = bob.learn.tensorflow.script.train:main'
],
},
...
...
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment