Commit 15580121 authored Oct 03, 2017 by Tiago de Freitas Pereira

Created mechanism to amend networks

parent 422d8e02
Pipeline #12947 failed with stages in 6 minutes and 21 seconds
Changes: 2 | Pipelines: 1
bob/learn/tensorflow/test/test_cnn_trainable_variables_select.py
@@ -23,7 +23,7 @@ step2_path = os.path.join(directory, "step2")
 slim = tf.contrib.slim

-def base_network(train_data_shuffler, reuse=False, get_embedding=False):
+def base_network(train_data_shuffler, reuse=False, get_embedding=False, trainable=True):

     if isinstance(train_data_shuffler, tf.Tensor):
         inputs = train_data_shuffler
@@ -33,11 +33,11 @@ def base_network(train_data_shuffler, reuse=False, get_embedding=False):
     # Creating a random network
     initializer = tf.contrib.layers.xavier_initializer(seed=seed)
     graph = slim.conv2d(inputs, 10, [3, 3], activation_fn=tf.nn.relu, stride=1, scope='conv1',
-                        weights_initializer=initializer, reuse=reuse)
+                        weights_initializer=initializer, reuse=reuse, trainable=trainable)
     graph = slim.max_pool2d(graph, [4, 4], scope='pool1')
     graph = slim.flatten(graph, scope='flatten1')
     graph = slim.fully_connected(graph, 30, activation_fn=None, scope='fc1',
-                                 weights_initializer=initializer, reuse=reuse)
+                                 weights_initializer=initializer, reuse=reuse, trainable=trainable)

     if get_embedding:
         graph = graph
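[Editor's note] The new `trainable` flag rides on standard TF 1.x behavior: variables created with `trainable=False` are kept out of the `tf.GraphKeys.TRAINABLE_VARIABLES` collection, so any optimizer built with the default variable list never updates them. A minimal sketch of that mechanism (illustrative, not from this repository; shapes mimic the MNIST data used in the test):

import tensorflow as tf

slim = tf.contrib.slim

# Illustrative input: 28x28 grayscale images, like the MNIST data in the test.
inputs = tf.placeholder(tf.float32, [None, 28, 28, 1])

# A frozen conv layer: its kernel and bias are created with trainable=False,
# so they are excluded from tf.GraphKeys.TRAINABLE_VARIABLES.
net = slim.conv2d(inputs, 10, [3, 3], scope='conv1', trainable=False)

# A trainable head stacked on top of the frozen layer.
net = slim.flatten(net)
net = slim.fully_connected(net, 30, scope='fc1', trainable=True)

# Only fc1's variables show up here; an optimizer's minimize() with the
# default var_list would therefore update fc1 but never touch conv1.
print([v.name for v in tf.trainable_variables()])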
@@ -95,17 +95,11 @@ def test_trainable_variables():
     writer.close()

-    # 1 - Create
-    # 2 - Initialize
-    # 3 - Minimize with certain variables
-    # 4 - Load the last checkpoint

-    ######## 1 - BASE NETWORK #########
+    ######## BASE NETWORK #########
     tfrecords_filename = "mnist_train.tfrecords"
-    #create_tf_record(tfrecords_filename, train_data, train_labels)
+    create_tf_record(tfrecords_filename, train_data, train_labels)
     filename_queue = tf.train.string_input_producer([tfrecords_filename], num_epochs=1, name="input")

     # Doing the first training
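[Editor's note] The test now writes the MNIST TFRecord itself instead of assuming the file exists. `create_tf_record` is a helper defined elsewhere in the test module; a hypothetical minimal version could look like this in TF 1.x (the `data`/`label` feature names are assumptions, not taken from the repository):

import tensorflow as tf

def create_tf_record(tfrecords_filename, data, labels):
    # Hypothetical sketch; the real helper lives in the test module.
    # Expects `data` as an iterable of numpy arrays and `labels` as integers.
    writer = tf.python_io.TFRecordWriter(tfrecords_filename)
    for img, label in zip(data, labels):
        example = tf.train.Example(features=tf.train.Features(feature={
            'data': tf.train.Feature(
                bytes_list=tf.train.BytesList(value=[img.tostring()])),
            'label': tf.train.Feature(
                int64_list=tf.train.Int64List(value=[int(label)])),
        }))
        writer.write(example.SerializeToString())
    writer.close()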
@@ -128,21 +122,22 @@ def test_trainable_variables():
     )

     trainer.train()
-    conv1_trained = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='conv1')[0].eval(session=trainer.session)[0]
+
+    # Saving the conv1 weights after the first training
+    conv1_after_first_train = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='conv1')[0].eval(session=trainer.session)[0]

     del trainer
     del filename_queue
     del train_data_shuffler
     tf.reset_default_graph()

-    ##### Creating an amendment network
+    ######## 2 - AMENDING NETWORK #########
     filename_queue = tf.train.string_input_producer([tfrecords_filename], num_epochs=1, name="input")
     train_data_shuffler = TFRecord(filename_queue=filename_queue, batch_size=batch_size)

-    graph = base_network(train_data_shuffler, get_embedding=True)
+    # Here I'm creating the base network not trainable
+    graph = base_network(train_data_shuffler, get_embedding=True, trainable=False)
     graph = amendment_network(graph)
     loss = MeanSoftMaxLoss(add_regularization_losses=False)
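[Editor's note] `amendment_network` is the other half of the mechanism this commit introduces: new trainable layers stacked on top of the frozen base network's embedding. Its definition is not part of this diff; a plausible sketch, under the assumption that it mirrors the slim style of `base_network` (layer sizes and scope names `fc2`/`logits` are illustrative):

import tensorflow as tf

slim = tf.contrib.slim

def amendment_network(graph, reuse=False):
    # Hypothetical sketch; the real amendment_network is defined elsewhere
    # in the test module. The idea: take the frozen base network's embedding
    # and append fresh, trainable layers that the optimizer will update.
    initializer = tf.contrib.layers.xavier_initializer()
    graph = slim.fully_connected(graph, 30, activation_fn=tf.nn.relu, scope='fc2',
                                 weights_initializer=initializer, reuse=reuse)
    graph = slim.fully_connected(graph, 10, activation_fn=None, scope='logits',
                                 weights_initializer=initializer, reuse=reuse)
    return graph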
@@ -151,7 +146,6 @@ def test_trainable_variables():
         analizer=None,
         temp_dir=step2_path)

-    learning_rate = constant(0.01, name="regular_lr")
     trainer.create_network_from_scratch(graph=graph,
                                         loss=loss,
@@ -159,45 +153,18 @@ def test_trainable_variables():
                                         optimizer=tf.train.GradientDescentOptimizer(learning_rate),
                                         )

-    conv1_before_load = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='conv1')[0].eval(session=trainer.session)[0]
-    var_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='conv1') + \
-               tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='fc1')
-    saver = tf.train.Saver(var_list)
-    saver.restore(trainer.session, os.path.join(step1_path, "model.ckp"))
+    # Loading two layers from the "old" model
+    external_model = os.path.join(step1_path, "model.ckp")
+    trainer.load_variables_from_external_model(external_model, var_list=['conv1', 'fc1'])

     conv1_restored = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='conv1')[0].eval(session=trainer.session)[0]
+    assert numpy.allclose(conv1_after_first_train, conv1_restored)

+    # Second round of training
     trainer.train()
-    conv1_after_train = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='conv1')[0].eval(session=trainer.session)[0]
-    print(conv1_trained - conv1_before_load)
-    print(conv1_trained - conv1_restored)
-    print(conv1_trained - conv1_after_train)
-    import ipdb; ipdb.set_trace(); x = 0
+    conv1_after_second_train = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='conv1')[0].eval(session=trainer.session)[0]

-    #var_list = tf.get_collection(tf.GraphKeys.VARIABLES, scope='fc1') + tf.get_collection(tf.GraphKeys.VARIABLES, scope='logits')
-    #optimizer = tf.train.GradientDescentOptimizer(0.01).minimize(loss, global_step=global_step, var_list=var_list)
-    #print("Go ...")
-    """
-    last_iteration = numpy.sum(tf.trainable_variables()[0].eval(session=session)[0])
-    for i in range(10):
-        _, l = session.run([optimizer, loss])
-        current_iteration = numpy.sum(tf.trainable_variables()[0].eval(session=session)[0])
-        print numpy.abs(current_iteration - last_iteration)
-        current_iteration = last_iteration
-        print l
-    thread_pool.request_stop()
-    """
-    #x = 0
+    # Since conv1 was set as NON TRAINABLE, both have to match
+    assert numpy.allclose(conv1_after_first_train, conv1_after_second_train)
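[Editor's note] The final assertion works because freezing happens at variable-creation time: `trainable=False` keeps `conv1` out of the gradient update entirely, so its weights after the second training round must equal their restored values bit for bit. An equivalent, cheaper sanity check is that the frozen scope simply never appears in the trainable collection. A sketch, assuming the graph was built with `base_network(..., trainable=False)`:

import tensorflow as tf

# All variables, frozen or not, live in GLOBAL_VARIABLES ...
conv1_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='conv1')
assert len(conv1_vars) > 0

# ... but a scope created with trainable=False contributes nothing to
# TRAINABLE_VARIABLES, so gradient updates can never touch it.
assert len(tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='conv1')) == 0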
bob/learn/tensorflow/trainers/Trainer.py
@@ -317,6 +317,28 @@ class Trainer(object):
         else:
             self.saver = tf.train.import_meta_graph(file_name, clear_devices=clear_devices)

         self.saver.restore(self.session, tf.train.latest_checkpoint(os.path.dirname(file_name)))

+    def load_variables_from_external_model(self, file_name, var_list):
+        """
+        Load a set of variables from a given model and update them in the current one
+
+        **Parameters**
+
+        file_name:
+            Name of the tensorflow model to be loaded
+        var_list:
+            List of variable scopes to be loaded. A tensorflow exception will be raised in case a variable does not exist
+        """
+
+        assert len(var_list) > 0
+
+        tf_varlist = []
+        for v in var_list:
+            tf_varlist += tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=v)
+
+        saver = tf.train.Saver(tf_varlist)
+        saver.restore(self.session, file_name)
+
     def create_network_from_file(self, file_name, clear_devices=True):
         """
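[Editor's note] Usage mirrors the test above: point the method at a checkpoint and name the variable scopes to pull in; every variable outside `var_list` keeps its current (for example, freshly initialized) value, since internally the method builds a `tf.train.Saver` over only the collected variables. A sketch, assuming a `trainer` whose graph already contains `conv1` and `fc1` scopes and a `step1_path` checkpoint directory as in the test:

import os

# Restore only conv1 and fc1 from the step-1 checkpoint; any amendment
# layers (e.g. fc2, logits) keep their fresh initialization.
external_model = os.path.join(step1_path, "model.ckp")
trainer.load_variables_from_external_model(external_model, var_list=['conv1', 'fc1'])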