bob / bob.learn.tensorflow / Commits

Commit 2e5982a7, authored Sep 23, 2017 by Tiago de Freitas Pereira

    Improved the loading from the last checkpoint

Parent: e6341fae
Pipeline #12496 failed, in 29 minutes and 27 seconds
Changes: 7 files
bob/learn/tensorflow/script/train.py

@@ -7,50 +7,45 @@
 """
 Train a Neural network using bob.learn.tensorflow

 Usage:
-  train.py [--iterations=<arg> --validation-interval=<arg> --output-dir=<arg>
-            --pretrained-net=<arg> --use-gpu --prefetch ] <configuration>
+  train.py [--iterations=<arg> --validation-interval=<arg> --output-dir=<arg> ] <configuration>
   train.py -h | --help

 Options:
   -h --help                     Show this screen.
   --iterations=<arg>            [default: 1000]
   --validation-interval=<arg>   [default: 100]
+  --output-dir=<arg>            If the directory exists, will try to get the last checkpoint [default: ./logs/]
-  --pretrained-net=<arg>
 """

 from docopt import docopt
 import imp
 import bob.learn.tensorflow
 import tensorflow as tf
 import os


 def main():
     args = docopt(__doc__, version='Train Neural Net')

-    USE_GPU = args['--use-gpu']
     OUTPUT_DIR = str(args['--output-dir'])
-    PREFETCH = args['--prefetch']
     ITERATIONS = int(args['--iterations'])

-    PRETRAINED_NET = ""
-    if not args['--pretrained-net'] is None:
-        PRETRAINED_NET = str(args['--pretrained-net'])
+    #PRETRAINED_NET = ""
+    #if not args['--pretrained-net'] is None:
+    #    PRETRAINED_NET = str(args['--pretrained-net'])

     config = imp.load_source('config', args['<configuration>'])

+    # Cleaning all variables in case you are loading the checkpoint
+    tf.reset_default_graph() if os.path.exists(OUTPUT_DIR) else None

     # One graph trainer
     trainer = config.Trainer(config.train_data_shuffler,
                              iterations=ITERATIONS,
                              analizer=None,
                              temp_dir=OUTPUT_DIR)

+    if os.path.exists(OUTPUT_DIR):
+        print("Directory already exists, trying to get the last checkpoint")
+        import ipdb; ipdb.set_trace();
+        trainer.create_network_from_file(OUTPUT_DIR)
+    else:
         # Preparing the architecture
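With this change, pointing --output-dir at a directory that already exists makes train.py pick up the newest checkpoint there instead of starting over. A minimal sketch of the intended round trip, mirroring the updated tests (the config path and output directory are assumptions):

    from subprocess import call

    # First run: train for 5 iterations and write checkpoints to the output dir.
    call(["./bin/train.py", "--iterations", "5",
          "--output-dir", "./temp/train-script",
          "./data/train_scripts/softmax.py"])

    # Second run: the directory now exists, so train.py reloads the last
    # checkpoint and continues training rather than starting from scratch.
    call(["./bin/train.py", "--iterations", "5",
          "--output-dir", "./temp/train-script",
          "./data/train_scripts/softmax.py"])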
bob/learn/tensorflow/test/data/train_scripts/siamese.py

@@ -22,16 +22,16 @@ train_data_shuffler = SiameseMemory(train_data, train_labels,
                                     normalizer=ScaleFactor())

 ### ARCHITECTURE ###
-architecture = Chopra(seed=SEED, fc1_output=10, batch_norm=False)
+architecture = Chopra(seed=SEED, n_classes=10)

 ### LOSS ###
 loss = ContrastiveLoss(contrastive_margin=4.)

-### SOLVER ###
-optimizer = tf.train.GradientDescentOptimizer(0.001)

 ### LEARNING RATE ###
-learning_rate = constant(base_learning_rate=0.001)
+learning_rate = constant(base_learning_rate=0.01)

+### SOLVER ###
+optimizer = tf.train.GradientDescentOptimizer(learning_rate)

 ### Trainer ###
 trainer = Trainer
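Note the reordering in the config above: the solver is now constructed after the learning rate, so it receives the learning_rate object rather than a hard-coded 0.001. tf.train.GradientDescentOptimizer accepts either a Python float or a tensor, which is what makes this work. A minimal TF1 sketch of the same pattern, substituting the stock tf.train.exponential_decay schedule for bob's constant helper (an illustrative assumption):

    import tensorflow as tf

    global_step = tf.Variable(0, trainable=False, name="global_step")

    # Any schedule producing a scalar tensor will do; the solver just consumes it.
    learning_rate = tf.train.exponential_decay(learning_rate=0.01,
                                               global_step=global_step,
                                               decay_steps=1000,
                                               decay_rate=0.95)

    # Define the schedule first, then hand the tensor to the optimizer.
    optimizer = tf.train.GradientDescentOptimizer(learning_rate)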
bob/learn/tensorflow/test/data/train_scripts/triplet.py

@@ -21,16 +21,19 @@ train_data_shuffler = TripletMemory(train_data, train_labels,
                                     batch_size=BATCH_SIZE)

 ### ARCHITECTURE ###
-architecture = Chopra(seed=SEED, fc1_output=10, batch_norm=False)
+architecture = Chopra(seed=SEED, n_classes=10)

 ### LOSS ###
 loss = TripletLoss(margin=4.)

-### SOLVER ###
-optimizer = tf.train.GradientDescentOptimizer(0.001)

 ### LEARNING RATE ###
-learning_rate = constant(base_learning_rate=0.001)
+learning_rate = constant(base_learning_rate=0.01)

+### SOLVER ###
+optimizer = tf.train.GradientDescentOptimizer(learning_rate)

 ### Trainer ###
 trainer = Trainer
bob/learn/tensorflow/test/test_train_script.py

@@ -10,22 +10,29 @@ import shutil

 def test_train_script_softmax():
     directory = "./temp/train-script"
     train_script = pkg_resources.resource_filename(__name__, './data/train_scripts/softmax.py')
+    train_script = './data/train_scripts/softmax.py'

     from subprocess import call

+    # Start the training
     call(["./bin/train.py", "--iterations", "5", "--output-dir", directory, train_script])
+    #shutil.rmtree(directory)

+    # Continuing from the last checkpoint
+    call(["./bin/train.py", "--iterations", "5", "--output-dir", directory, train_script])

     shutil.rmtree(directory)
     assert True


 def test_train_script_triplet():
     directory = "./temp/train-script"
     train_script = pkg_resources.resource_filename(__name__, './data/train_scripts/triplet.py')
-    #train_script = './data/train_scripts/triplet.py'
-    #from subprocess import call
-    #call(["./bin/train.py", "--iterations", "5", "--output-dir", directory, train_script])
-    #shutil.rmtree(directory)
+    from subprocess import call

+    # Start the training
+    call(["./bin/train.py", "--iterations", "5", "--output-dir", directory, train_script])

+    # Continuing from the last checkpoint
+    call(["./bin/train.py", "--iterations", "5", "--output-dir", directory, train_script])

+    shutil.rmtree(directory)
     assert True

@@ -33,10 +40,14 @@ def test_train_script_triplet():

 def test_train_script_siamese():
     directory = "./temp/train-script"
     train_script = pkg_resources.resource_filename(__name__, './data/train_scripts/siamese.py')
-    #train_script = './data/train_scripts/siamese.py'
-    #from subprocess import call
-    #call(["./bin/train.py", "--iterations", "5", "--output-dir", directory, train_script])
-    #shutil.rmtree(directory)
+    from subprocess import call

+    # Start the training
+    call(["./bin/train.py", "--iterations", "5", "--output-dir", directory, train_script])

+    # Continuing from the last checkpoint
+    call(["./bin/train.py", "--iterations", "5", "--output-dir", directory, train_script])

+    shutil.rmtree(directory)
     assert True
bob/learn/tensorflow/trainers/SiameseTrainer.py

@@ -179,9 +179,7 @@ class SiameseTrainer(Trainer):

     def create_network_from_file(self, model_from_file, clear_devices=True):
-        #saver = self.architecture.load(self.model_from_file, clear_devices=False)
-        self.saver = tf.train.import_meta_graph(model_from_file + ".meta", clear_devices=clear_devices)
-        self.saver.restore(self.session, model_from_file)
+        self.load_checkpoint(model_from_file, clear_devices=clear_devices)

         # Loading the graph from the graph pointers
         self.graph = dict()

@@ -206,7 +204,6 @@ class SiameseTrainer(Trainer):
         self.summaries_train = tf.get_collection("summaries_train")[0]
         self.global_step = tf.get_collection("global_step")[0]

         self.from_scratch = False

     def get_feed_dict(self, data_shuffler):
bob/learn/tensorflow/trainers/Trainer.py

@@ -122,7 +122,6 @@ class Trainer(object):
         self.session = Session.instance(new=True).session

         self.from_scratch = True

     def train(self):
         """
         Train the network

@@ -197,7 +196,6 @@ class Trainer(object):
         #if not isinstance(self.train_data_shuffler, TFRecord):
         #    self.thread_pool.join(threads)

     def create_network_from_scratch(self, graph, validation_graph=None,

@@ -222,9 +220,6 @@ class Trainer(object):
             learning_rate: Learning rate
         """

         # Putting together the training data + graph + loss

         # Getting the pointer to the placeholders
         self.data_ph = self.train_data_shuffler("data", from_queue=True)
         self.label_ph = self.train_data_shuffler("label", from_queue=True)

@@ -243,7 +238,6 @@ class Trainer(object):
         self.optimizer_class._learning_rate = self.learning_rate
         self.optimizer = self.optimizer_class.minimize(self.predictor, global_step=self.global_step)

         # Saving all the variables
         self.saver = tf.train.Saver(var_list=tf.global_variables() + tf.local_variables(),
                                     keep_checkpoint_every_n_hours=self.keep_checkpoint_every_n_hours)

@@ -264,7 +258,7 @@ class Trainer(object):
         tf.add_to_collection("summaries_train", self.summaries_train)

         # Same business with the validation
-        if(self.validation_data_shuffler is not None):
+        if self.validation_data_shuffler is not None:
             self.validation_data_ph = self.validation_data_shuffler("data", from_queue=True)
             self.validation_label_ph = self.validation_data_shuffler("label", from_queue=True)

@@ -286,6 +280,24 @@ class Trainer(object):
         tf.local_variables_initializer().run(session=self.session)
         tf.global_variables_initializer().run(session=self.session)

+    def load_checkpoint(self, file_name, clear_devices=True):
+        """
+        Load a checkpoint
+
+        ** Parameters **
+
+        file_name:
+            Name of the metafile to be loaded.
+            If a directory is passed, the last checkpoint will be loaded
+        """
+        if os.path.isdir(file_name):
+            checkpoint_path = tf.train.get_checkpoint_state(file_name).model_checkpoint_path
+            self.saver = tf.train.import_meta_graph(checkpoint_path + ".meta", clear_devices=clear_devices)
+            self.saver.restore(self.session, tf.train.latest_checkpoint(file_name))
+        else:
+            self.saver = tf.train.import_meta_graph(file_name, clear_devices=clear_devices)
+            self.saver.restore(self.session, tf.train.latest_checkpoint(os.path.dirname(file_name)))

     def create_network_from_file(self, file_name, clear_devices=True):
         """

@@ -295,9 +307,9 @@ class Trainer(object):
         file_name: Name of the checkpoint
         """
-        #self.saver = tf.train.import_meta_graph(file_name + ".meta", clear_devices=clear_devices)
-        self.saver = tf.train.import_meta_graph(file_name, clear_devices=clear_devices)
-        self.saver.restore(self.session, tf.train.latest_checkpoint(os.path.dirname(file_name)))
+        logger.info("Loading last checkpoint !!")
+        self.load_checkpoint(file_name, clear_devices=True)

         # Loading training graph
         self.data_ph = tf.get_collection("data_ph")[0]

@@ -314,10 +326,9 @@ class Trainer(object):
         self.from_scratch = False

         # Loading the validation bits
-        if(self.validation_data_shuffler is not None):
+        if self.validation_data_shuffler is not None:
             self.summaries_validation = tf.get_collection("summaries_validation")[0]
             self.validation_graph = tf.get_collection("validation_graph")[0]
             self.validation_data_ph = tf.get_collection("validation_data_ph")[0]
             self.validation_label = tf.get_collection("validation_label_ph")[0]

@@ -325,7 +336,6 @@ class Trainer(object):
             self.validation_predictor = tf.get_collection("validation_predictor")[0]
             self.summaries_validation = tf.get_collection("summaries_validation")[0]

     def __del__(self):
         tf.reset_default_graph()
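The new load_checkpoint method above is the core of this commit: it accepts either a directory or an explicit .meta file and resolves the newest checkpoint in both cases. A standalone sketch of the same resolution logic, assuming an open session (the helper name restore_latest is hypothetical, not part of the library):

    import os
    import tensorflow as tf

    def restore_latest(session, path, clear_devices=True):
        # Hypothetical helper restating Trainer.load_checkpoint's branching.
        if os.path.isdir(path):
            # Directory: ask TF for the newest checkpoint prefix, then import
            # the graph stored next to it in the matching .meta file.
            prefix = tf.train.get_checkpoint_state(path).model_checkpoint_path
            meta, ckpt_dir = prefix + ".meta", path
        else:
            # Explicit .meta file: the checkpoint data lives in its directory.
            meta, ckpt_dir = path, os.path.dirname(path)
        saver = tf.train.import_meta_graph(meta, clear_devices=clear_devices)
        saver.restore(session, tf.train.latest_checkpoint(ckpt_dir))
        return saver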
bob/learn/tensorflow/trainers/TripletTrainer.py

@@ -120,7 +120,6 @@ class TripletTrainer(Trainer):
         self.session = Session.instance(new=True).session

         self.from_scratch = True

     def create_network_from_scratch(self, graph, optimizer=tf.train.AdamOptimizer(),

@@ -177,11 +176,9 @@ class TripletTrainer(Trainer):
         # Creating the variables
         tf.global_variables_initializer().run(session=self.session)

-    def create_network_from_file(self, model_from_file, clear_devices=True):
-        #saver = self.architecture.load(self.model_from_file, clear_devices=False)
-        self.saver = tf.train.import_meta_graph(model_from_file + ".meta", clear_devices=clear_devices)
-        self.saver.restore(self.session, model_from_file)
+    def create_network_from_file(self, file_name, clear_devices=True):
+        self.load_checkpoint(file_name, clear_devices=clear_devices)

         # Loading the graph from the graph pointers
         self.graph = dict()
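With both subclasses delegating to the shared helper, resuming any trainer flavour is now uniform; train.py simply hands the output directory to create_network_from_file. A hedged sketch of that resume flow (config stands for a train-script module as loaded by train.py, and the directory name is an assumption):

    import tensorflow as tf

    # The output directory already holds checkpoints from an earlier run.
    tf.reset_default_graph()
    trainer = config.Trainer(config.train_data_shuffler,
                             iterations=5,
                             analizer=None,
                             temp_dir="./temp/train-script")
    # Works the same for Trainer, SiameseTrainer and TripletTrainer: all of
    # them now funnel through the shared Trainer.load_checkpoint helper.
    trainer.create_network_from_file("./temp/train-script")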