Skip to content
GitLab
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
bob
bob.learn.tensorflow
Commits
4402ab49
Commit
4402ab49
authored
Aug 30, 2016
by
Tiago de Freitas Pereira
Browse files
Defined a load and save functions
parent
21fbd15b
Changes
2
Hide whitespace changes
Inline
Side-by-side
bob/learn/tensorflow/network/SequenceNetwork.py
View file @
4402ab49
...
...
@@ -31,7 +31,9 @@ class SequenceNetwork(six.with_metaclass(abc.ABCMeta, object)):
self
.
sequence_net
=
OrderedDict
()
self
.
feature_layer
=
feature_layer
self
.
saver
=
None
self
.
input_divide
=
1.
self
.
input_subtract
=
0.
#self.saver = None
def
add
(
self
,
layer
):
"""
...
...
@@ -92,16 +94,87 @@ class SequenceNetwork(six.with_metaclass(abc.ABCMeta, object)):
return
variables
def
save
(
self
,
hdf5
,
step
=
None
):
"""
Save the state of the network in HDF5 format
:param session:
:param hdf5:
:param step:
:return:
"""
# Directory that stores the tensorflow variables
hdf5
.
create_group
(
'/tensor_flow'
)
hdf5
.
cd
(
'/tensor_flow'
)
if
step
is
not
None
:
group_name
=
'/step_{0}'
.
format
(
step
)
hdf5
.
create_group
(
group_name
)
hdf5
.
cd
(
group_name
)
# Iterating the variables of the model
for
v
in
self
.
dump_variables
().
keys
():
hdf5
.
set
(
v
,
self
.
dump_variables
()[
v
].
eval
())
hdf5
.
cd
(
'..'
)
if
step
is
not
None
:
hdf5
.
cd
(
'..'
)
hdf5
.
set
(
'input_divide'
,
self
.
input_divide
)
hdf5
.
set
(
'input_subtract'
,
self
.
input_subtract
)
def
load
(
self
,
hdf5
,
shape
,
session
=
None
):
if
session
is
None
:
session
=
tf
.
Session
()
# Loading the normalization parameters
self
.
input_divide
=
hdf5
.
read
(
'input_divide'
)
self
.
input_subtract
=
hdf5
.
read
(
'input_subtract'
)
# Loading variables
place_holder
=
tf
.
placeholder
(
tf
.
float32
,
shape
=
shape
,
name
=
"load"
)
self
.
compute_graph
(
place_holder
)
tf
.
initialize_all_variables
().
run
(
session
=
session
)
hdf5
.
cd
(
'/tensor_flow'
)
for
k
in
self
.
sequence_net
:
# TODO: IT IS NOT SMART TESTING ALONG THIS PAGE
if
not
isinstance
(
self
.
sequence_net
[
k
],
MaxPooling
):
#self.sequence_net[k].W.assign(hdf5.read(self.sequence_net[k].W.name))
self
.
sequence_net
[
k
].
W
.
assign
(
hdf5
.
read
(
self
.
sequence_net
[
k
].
W
.
name
)).
eval
(
session
=
session
)
session
.
run
(
self
.
sequence_net
[
k
].
W
)
self
.
sequence_net
[
k
].
b
.
assign
(
hdf5
.
read
(
self
.
sequence_net
[
k
].
b
.
name
)).
eval
(
session
=
session
)
session
.
run
(
self
.
sequence_net
[
k
].
b
)
#if self.saver is None:
# variables = self.dump_variables()
# variables['input_divide'] = self.input_divide
# variables['input_subtract'] = self.input_subtract
# self.saver = tf.train.Saver(variables)
#self.saver.restore(session, path)
"""
def save(self, session, path, step=None):
if self.saver is None:
self
.
saver
=
tf
.
train
.
Saver
(
self
.
dump_variables
())
variables = self.dump_variables()
variables['mean'] = tf.Variable(10.0)
#import ipdb; ipdb.set_trace()
tf.initialize_all_variables().run()
self.saver = tf.train.Saver(variables)
if step is None:
return self.saver.save(session, os.path.join(path, "model.ckpt"))
else:
return self.saver.save(session, os.path.join(path, "model" + str(step) + ".ckpt"))
def load(self, path, shape, session=None):
if session is None:
...
...
@@ -113,6 +186,10 @@ class SequenceNetwork(six.with_metaclass(abc.ABCMeta, object)):
tf.initialize_all_variables().run(session=session)
if self.saver is None:
self
.
saver
=
tf
.
train
.
Saver
(
self
.
dump_variables
())
variables = self.dump_variables()
variables['input_divide'] = self.input_divide
variables['input_subtract'] = self.input_subtract
self.saver = tf.train.Saver(variables)
self.saver.restore(session, path)
"""
bob/learn/tensorflow/trainers/SiameseTrainer.py
View file @
4402ab49
...
...
@@ -92,6 +92,7 @@ class SiameseTrainer(object):
print
(
"Initializing !!"
)
# Training
hdf5
=
bob
.
io
.
base
.
HDF5File
(
os
.
path
.
join
(
self
.
temp_dir
,
'model.hdf5'
),
'w'
)
with
tf
.
Session
()
as
session
:
analizer
=
Analizer
(
data_shuffler
,
self
.
architecture
,
session
)
...
...
@@ -119,9 +120,9 @@ class SiameseTrainer(object):
if
step
%
self
.
snapshot
==
0
:
analizer
()
if
self
.
save_intermediate
:
self
.
architecture
.
save
(
session
,
os
.
path
.
join
(
self
.
temp_dir
,
'OUTPUT'
),
step
)
self
.
architecture
.
save
(
hdf5
,
step
)
print
str
(
step
)
+
" - "
+
str
(
analizer
.
eer
[
-
1
])
self
.
architecture
.
save
(
session
,
os
.
path
.
join
(
self
.
temp_dir
,
'OUTPUT'
))
self
.
architecture
.
save
(
hdf5
)
del
hdf5
train_writer
.
close
()
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment