Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
bob
bob.learn.tensorflow
Commits
1952c2f5
Commit
1952c2f5
authored
Sep 19, 2017
by
Tiago de Freitas Pereira
Browse files
New test cases
parent
5a60ec56
Changes
2
Hide whitespace changes
Inline
Side-by-side
bob/learn/tensorflow/test/test_cnn_pretrained_model.py
View file @
1952c2f5
...
...
@@ -256,4 +256,5 @@ def test_siamese_cnn_pretrained():
del
trainer
tf
.
reset_default_graph
()
assert
len
(
tf
.
global_variables
())
==
0
assert
len
(
tf
.
global_variables
())
==
0
bob/learn/tensorflow/test/test_cnn_scratch.py
View file @
1952c2f5
...
...
@@ -28,7 +28,10 @@ slim = tf.contrib.slim
def
scratch_network
(
train_data_shuffler
,
reuse
=
False
):
inputs
=
train_data_shuffler
(
"data"
,
from_queue
=
False
)
if
isinstance
(
train_data_shuffler
,
tf
.
Tensor
):
inputs
=
train_data_shuffler
else
:
inputs
=
train_data_shuffler
(
"data"
,
from_queue
=
False
)
# Creating a random network
initializer
=
tf
.
contrib
.
layers
.
xavier_initializer
(
seed
=
seed
)
...
...
@@ -150,7 +153,7 @@ def test_cnn_trainer_scratch_tfrecord():
graph
=
scratch_network
(
train_data_shuffler
)
validation_graph
=
scratch_network
(
validation_data_shuffler
,
reuse
=
True
)
# Setting the placeholders
# Loss for the softmax
loss
=
MeanSoftMaxLoss
()
...
...
@@ -174,9 +177,27 @@ def test_cnn_trainer_scratch_tfrecord():
trainer
.
train
()
os
.
remove
(
tfrecords_filename
)
os
.
remove
(
tfrecords_filename_val
)
os
.
remove
(
tfrecords_filename_val
)
assert
True
tf
.
reset_default_graph
()
del
trainer
assert
len
(
tf
.
global_variables
())
==
0
# Inference. TODO: Wrap this in a package
file_name
=
os
.
path
.
join
(
directory
,
"model.ckp.meta"
)
images
=
tf
.
placeholder
(
tf
.
float32
,
shape
=
(
None
,
28
,
28
,
1
))
graph
=
scratch_network
(
images
,
reuse
=
False
)
session
=
tf
.
Session
()
session
.
run
(
tf
.
global_variables_initializer
())
saver
=
tf
.
train
.
import_meta_graph
(
file_name
,
clear_devices
=
True
)
saver
.
restore
(
session
,
tf
.
train
.
latest_checkpoint
(
os
.
path
.
dirname
(
"./temp/cnn_scratch/"
)))
data
=
numpy
.
random
.
rand
(
2
,
28
,
28
,
1
).
astype
(
"float32"
)
assert
session
.
run
(
graph
,
feed_dict
=
{
images
:
data
}).
shape
==
(
2
,
10
)
tf
.
reset_default_graph
()
shutil
.
rmtree
(
directory
)
assert
len
(
tf
.
global_variables
())
==
0
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment