Fixed some unit tests

parent c87be09c
model_checkpoint_path: "model.ckp"
all_model_checkpoint_paths: "model.ckp"
ccollections
OrderedDict
p0
((lp1
(lp2
S'conv1'
p3
accopy_reg
_reconstructor
p4
(cbob.learn.tensorflow.layers
Conv2D
p5
c__builtin__
object
p6
Ntp7
Rp8
(dp9
S'batch_var'
p10
NsS'name'
p11
g3
sS'filters'
p12
I10
sS'use_gpu'
p13
I00
sS'activation'
p14
ctensorflow.python.ops.math_ops
tanh
p15
sS'W'
p16
NsS'stride'
p17
(lp18
I1
aI1
aI1
aI1
asS'beta'
p19
NsS'b'
p20
NsS'weights_initialization'
p21
g4
(cbob.learn.tensorflow.initialization
Xavier
p22
g6
Ntp23
Rp24
(dp25
g13
I00
sS'seed'
p26
F10.0
sbsS'input_layer'
p27
NsS'batch_mean'
p28
NsS'bias_initialization'
p29
g4
(cbob.learn.tensorflow.initialization
Constant
p30
g6
Ntp31
Rp32
(dp33
g13
I00
sg26
NsS'constant_value'
p34
F0.1
sbsS'kernel_size'
p35
I3
sS'gamma'
p36
NsS'batch_norm'
p37
I00
sbaa(lp38
S'fc1'
p39
ag4
(cbob.learn.tensorflow.layers
FullyConnected
p40
g6
Ntp41
Rp42
(dp43
g10
Nsg11
g39
sg13
I00
sg14
NsS'shape'
p44
Nsg16
Nsg19
Nsg20
Nsg21
g4
(g22
g6
Ntp45
Rp46
(dp47
g13
I00
sg26
F10.0
sbsg27
Nsg28
Nsg29
g4
(g30
g6
Ntp48
Rp49
(dp50
g13
I00
sg26
Nsg34
F0.1
sbsS'output_dim'
p51
I10
sg36
Nsg37
I00
sbaatp52
Rp53
.
\ No newline at end of file
......@@ -14,6 +14,7 @@ import pkg_resources
from bob.learn.tensorflow.utils import load_mnist
from bob.learn.tensorflow.network import SequenceNetwork
from bob.learn.tensorflow.datashuffler import Memory
import tensorflow as tf
def validate_network(validation_data, validation_labels, network):
......@@ -28,8 +29,9 @@ def validate_network(validation_data, validation_labels, network):
return accuracy
"""
def test_load_test_cnn():
tf.reset_default_graph()
_, _, validation_data, validation_labels = load_mnist()
......@@ -41,4 +43,4 @@ def test_load_test_cnn():
accuracy = validate_network(validation_data, validation_labels, network)
assert accuracy > 80
del network
"""
......@@ -102,8 +102,6 @@ def test_cnn_pretrained():
learning_rate=constant(0.05, name="lr2"),
temp_dir=directory2,
model_from_file=os.path.join(directory, "model.ckp"))
#import ipdb; ipdb.set_trace();
trainer.train(train_data_shuffler)
accuracy = validate_network(validation_data, validation_labels, scratch)
......
......@@ -391,7 +391,7 @@ class Trainer(object):
logger.info("Loading pretrained model from {0}".format(self.model_from_file))
saver = self.bootstrap_graphs_fromfile(train_data_shuffler, validation_data_shuffler)
start_step = self.global_step.eval(self.session)
start_step = self.global_step.eval(session=self.session)
else:
start_step = 0
......@@ -400,6 +400,7 @@ class Trainer(object):
# TODO: find an elegant way to provide this as a parameter of the trainer
self.global_step = tf.Variable(0, trainable=False, name="global_step")
tf.add_to_collection("global_step", self.global_step)
# Preparing the optimizer
self.optimizer_class._learning_rate = self.learning_rate
......@@ -411,6 +412,8 @@ class Trainer(object):
self.summaries_train = self.create_general_summary()
tf.add_to_collection("summaries_train", self.summaries_train)
tf.add_to_collection("summaries_train", self.summaries_train)
tf.initialize_all_variables().run(session=self.session)
# Original tensorflow saver object
......@@ -422,8 +425,8 @@ class Trainer(object):
# Start a thread to enqueue data asynchronously, and hide I/O latency.
if self.prefetch:
self.thread_pool = tf.train.Coordinator()
tf.train.start_queue_runners(coord=self.thread_pool)
threads = self.start_thread(self.session)
tf.train.start_queue_runners(coord=self.thread_pool, sess=self.session)
threads = self.start_thread()
# TENSOR BOARD SUMMARY
self.train_summary_writter = tf.train.SummaryWriter(os.path.join(self.temp_dir, 'train'), self.session.graph)
......
......@@ -62,7 +62,6 @@ class TripletTrainer(Trainer):
"""
def __init__(self,
architecture,
optimizer=tf.train.AdamOptimizer(),
......
......@@ -43,7 +43,7 @@ setup(
# information before releasing code publicly.
name = 'bob.learn.tensorflow',
version = open("version.txt").read().rstrip(),
description = 'Hands on with tensor flow',
description = 'Bob bindings for tensorflow',
url = '',
license = 'BSD',
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment