Fixed test units

parent ccc7f48a
......@@ -8,7 +8,7 @@ from .Base import Base
from bob.learn.tensorflow.network import SequenceNetwork
class OnlineSampling(Base):
class OnlineSampling(object):
"""
This data shuffler uses the current state of the network to select the samples.
This class is not meant to be used, but extended.
......
......@@ -43,6 +43,9 @@ class SequenceNetwork(six.with_metaclass(abc.ABCMeta, object)):
self.inference_graph = None
self.inference_placeholder = None
def __del__(self):
tf.reset_default_graph()
def add(self, layer):
"""
Add a :py:class:`bob.learn.tensorflow.layers.Layer` in the sequence network
......
......@@ -102,7 +102,8 @@ def test_cnn_trainer():
prefetch=False,
learning_rate=constant(0.05, name="regular_lr"),
optimizer=tf.train.AdamOptimizer(name="adam_softmax"),
temp_dir=directory)
temp_dir=directory
)
trainer.train(train_data_shuffler)
......@@ -144,7 +145,8 @@ def test_siamesecnn_trainer():
analizer=None,
learning_rate=constant(0.05, name="siamese_lr"),
optimizer=tf.train.AdamOptimizer(name="adam_siamese"),
temp_dir=directory)
temp_dir=directory
)
trainer.train(train_data_shuffler)
......@@ -187,7 +189,8 @@ def test_tripletcnn_trainer():
analizer=None,
learning_rate=constant(0.05, name="triplet_lr"),
optimizer=tf.train.AdamOptimizer(name="adam_triplet"),
temp_dir=directory)
temp_dir=directory
)
trainer.train(train_data_shuffler)
......
......@@ -83,8 +83,9 @@ def test_cnn_pretrained():
prefetch=False,
learning_rate=constant(0.05, name="regular_lr"),
optimizer=tf.train.AdamOptimizer(name="adam_pretrained_model"),
temp_dir=directory)
import ipdb; ipdb.set_trace();
temp_dir=directory
)
trainer.train(train_data_shuffler)
accuracy = validate_network(validation_data, validation_labels, scratch)
assert accuracy > 85
......@@ -103,7 +104,8 @@ def test_cnn_pretrained():
prefetch=False,
learning_rate=None,
temp_dir=directory2,
model_from_file=os.path.join(directory, "model.ckp"))
model_from_file=os.path.join(directory, "model.ckp")
)
trainer.train(train_data_shuffler)
......
......@@ -79,7 +79,8 @@ def test_cnn_trainer_scratch():
iterations=iterations,
analizer=None,
prefetch=False,
temp_dir=directory)
temp_dir=directory
)
trainer.train(train_data_shuffler)
......
......@@ -59,7 +59,8 @@ def test_dnn_trainer():
prefetch=False,
learning_rate=constant(0.05, name="dnn_lr"),
optimizer=tf.train.AdamOptimizer(name="adam_dnn"),
temp_dir=directory)
temp_dir=directory
)
trainer.train(train_data_shuffler)
accuracy = validate_network(validation_data, validation_labels, architecture)
......
......@@ -56,6 +56,7 @@ class SiameseTrainer(Trainer):
verbosity_level:
"""
def __init__(self,
......@@ -80,7 +81,8 @@ class SiameseTrainer(Trainer):
model_from_file="",
verbosity_level=2):
verbosity_level=2
):
super(SiameseTrainer, self).__init__(
architecture=architecture,
......
......@@ -202,7 +202,7 @@ class Trainer(object):
if self.prefetch:
_, l, lr, summary = self.session.run([self.optimizer, self.training_graph,
self.learning_rate, self.summaries_train])
self.learning_rate, self.summaries_train])
else:
feed_dict = self.get_feed_dict(self.train_data_shuffler)
_, l, lr, summary = self.session.run([self.optimizer, self.training_graph,
......
......@@ -85,7 +85,8 @@ class TripletTrainer(Trainer):
model_from_file="",
verbosity_level=2):
verbosity_level=2
):
super(TripletTrainer, self).__init__(
architecture=architecture,
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment