Skip to content
GitLab
Menu
Projects
Groups
Snippets
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
bob
bob.learn.tensorflow
Commits
b2d5c736
Commit
b2d5c736
authored
Apr 22, 2017
by
Tiago Pereira
Browse files
Set the training from file
parent
47480241
Changes
2
Hide whitespace changes
Inline
Side-by-side
bob/learn/tensorflow/test/test_cnn_pretrained_model.py
View file @
b2d5c736
...
...
@@ -6,7 +6,7 @@
import
numpy
import
bob.io.base
import
os
from
bob.learn.tensorflow.datashuffler
import
Memory
,
ImageAugmentation
,
TripletMemory
,
SiameseMemory
from
bob.learn.tensorflow.datashuffler
import
Memory
,
ImageAugmentation
,
TripletMemory
,
SiameseMemory
,
ScaleFactor
from
bob.learn.tensorflow.loss
import
BaseLoss
,
TripletLoss
,
ContrastiveLoss
from
bob.learn.tensorflow.trainers
import
Trainer
,
constant
,
TripletTrainer
,
SiameseTrainer
from
bob.learn.tensorflow.utils
import
load_mnist
...
...
@@ -33,17 +33,19 @@ def scratch_network(input_pl):
# Creating a random network
slim
=
tf
.
contrib
.
slim
initializer
=
tf
.
contrib
.
layers
.
xavier_initializer
(
uniform
=
False
,
dtype
=
tf
.
float32
,
seed
=
10
)
with
tf
.
device
(
"/cpu:0"
):
initializer
=
tf
.
contrib
.
layers
.
xavier_initializer
(
uniform
=
False
,
dtype
=
tf
.
float32
,
seed
=
10
)
scratch
=
slim
.
conv2d
(
input_pl
,
10
,
3
,
activation_fn
=
tf
.
nn
.
tanh
,
stride
=
1
,
weights_initializer
=
initializer
,
scope
=
'conv1'
)
scratch
=
slim
.
flatten
(
scratch
,
scope
=
'flatten1'
)
scratch
=
slim
.
fully_connected
(
scratch
,
10
,
weights_initializer
=
initializer
,
activation_fn
=
None
,
scope
=
'fc1'
)
scratch
=
slim
.
conv2d
(
input_pl
,
16
,
[
3
,
3
],
activation_fn
=
tf
.
nn
.
relu
,
stride
=
1
,
weights_initializer
=
initializer
,
scope
=
'conv1'
)
scratch
=
slim
.
max_pool2d
(
scratch
,
kernel_size
=
[
2
,
2
],
scope
=
'pool1'
)
scratch
=
slim
.
flatten
(
scratch
,
scope
=
'flatten1'
)
scratch
=
slim
.
fully_connected
(
scratch
,
10
,
weights_initializer
=
initializer
,
activation_fn
=
None
,
scope
=
'fc1'
)
return
scratch
...
...
@@ -58,7 +60,8 @@ def test_cnn_pretrained():
train_data_shuffler
=
Memory
(
train_data
,
train_labels
,
input_shape
=
[
None
,
28
,
28
,
1
],
batch_size
=
batch_size
,
data_augmentation
=
data_augmentation
)
data_augmentation
=
data_augmentation
,
normalizer
=
ScaleFactor
())
validation_data
=
numpy
.
reshape
(
validation_data
,
(
validation_data
.
shape
[
0
],
28
,
28
,
1
))
directory
=
"./temp/cnn"
...
...
@@ -81,39 +84,35 @@ def test_cnn_pretrained():
)
trainer
.
create_network_from_scratch
(
graph
=
graph
,
loss
=
loss
,
learning_rate
=
constant
(
0.
0
1
,
name
=
"regular_lr"
),
optimizer
=
tf
.
train
.
GradientDescentOptimizer
(
0.
0
1
),
learning_rate
=
constant
(
0.1
,
name
=
"regular_lr"
),
optimizer
=
tf
.
train
.
GradientDescentOptimizer
(
0.1
),
)
trainer
.
train
()
accuracy
=
validate_network
(
embedding
,
validation_data
,
validation_labels
)
assert
accuracy
>
80
tf
.
reset_default_graph
()
del
graph
del
loss
del
trainer
del
embedding
# Training the network using a pre trained model
loss
=
BaseLoss
(
tf
.
nn
.
sparse_softmax_cross_entropy_with_logits
,
tf
.
reduce_mean
,
name
=
"loss"
)
graph
=
scratch_network
(
input_pl
)
# One graph trainer
trainer
=
Trainer
(
train_data_shuffler
,
iterations
=
iterations
,
iterations
=
iterations
*
3
,
analizer
=
None
,
temp_dir
=
directory
)
trainer
.
create_network_from_file
(
os
.
path
.
join
(
directory
,
"model.ckp"
))
import
ipdb
;
ipdb
.
set_trace
()
trainer
.
train
()
embedding
=
Embedding
(
trainer
.
data_ph
,
trainer
.
graph
)
accuracy
=
validate_network
(
embedding
,
validation_data
,
validation_labels
)
assert
accuracy
>
90
shutil
.
rmtree
(
directory
)
shutil
.
rmtree
(
directory2
)
del
graph
del
loss
del
trainer
...
...
bob/learn/tensorflow/trainers/Trainer.py
View file @
b2d5c736
...
...
@@ -131,8 +131,6 @@ class Trainer(object):
learning_rate
=
None
,
):
self
.
saver
=
tf
.
train
.
Saver
(
var_list
=
tf
.
global_variables
())
self
.
data_ph
=
self
.
train_data_shuffler
(
"data"
)
self
.
label_ph
=
self
.
train_data_shuffler
(
"label"
)
self
.
graph
=
graph
...
...
@@ -144,6 +142,10 @@ class Trainer(object):
# TODO: find an elegant way to provide this as a parameter of the trainer
self
.
global_step
=
tf
.
Variable
(
0
,
trainable
=
False
,
name
=
"global_step"
)
# Saving all the variables
self
.
saver
=
tf
.
train
.
Saver
(
var_list
=
tf
.
global_variables
())
tf
.
add_to_collection
(
"global_step"
,
self
.
global_step
)
tf
.
add_to_collection
(
"graph"
,
self
.
graph
)
...
...
@@ -161,6 +163,7 @@ class Trainer(object):
self
.
summaries_train
=
self
.
create_general_summary
()
tf
.
add_to_collection
(
"summaries_train"
,
self
.
summaries_train
)
# Creating the variables
tf
.
global_variables_initializer
().
run
(
session
=
self
.
session
)
...
...
@@ -173,15 +176,14 @@ class Trainer(object):
train_data_shuffler: Data shuffler for training
validation_data_shuffler: Data shuffler for validation
"""
#saver = self.architecture.load(self.model_from_file, clear_devices=False)
self
.
saver
=
tf
.
train
.
import_meta_graph
(
model_from_file
+
".meta"
)
self
.
saver
.
restore
(
self
.
session
,
model_from_file
)
# Loading training graph
self
.
data_ph
=
tf
.
get_collection
(
"data_ph"
)
self
.
label_ph
=
tf
.
get_collection
(
"label_ph"
)
self
.
data_ph
=
tf
.
get_collection
(
"data_ph"
)
[
0
]
self
.
label_ph
=
tf
.
get_collection
(
"label_ph"
)
[
0
]
self
.
graph
=
tf
.
get_collection
(
"graph"
)[
0
]
self
.
predictor
=
tf
.
get_collection
(
"predictor"
)[
0
]
...
...
@@ -194,10 +196,7 @@ class Trainer(object):
self
.
from_scratch
=
False
# Creating the variables
tf
.
global_variables_initializer
().
run
(
session
=
self
.
session
)
import
ipdb
;
ipdb
.
set_trace
()
x
=
0
#tf.global_variables_initializer().run(session=self.session)
def
__del__
(
self
):
tf
.
reset_default_graph
()
...
...
@@ -356,7 +355,7 @@ class Trainer(object):
if
step
%
self
.
snapshot
==
0
:
logger
.
info
(
"Taking snapshot"
)
path
=
os
.
path
.
join
(
self
.
temp_dir
,
'model_snapshot{0}.ckp'
.
format
(
step
))
self
.
saver
.
save
(
self
.
session
,
path
)
self
.
saver
.
save
(
self
.
session
,
path
,
global_step
=
step
)
#self.architecture.save(saver, path)
logger
.
info
(
"Training finally finished"
)
...
...
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment