Commit 1952c2f5 authored by Tiago de Freitas Pereira's avatar Tiago de Freitas Pereira
Browse files

New test cases

parent 5a60ec56
......@@ -256,4 +256,5 @@ def test_siamese_cnn_pretrained():
del trainer
tf.reset_default_graph()
assert len(tf.global_variables())==0
assert len(tf.global_variables())==0
......@@ -28,7 +28,10 @@ slim = tf.contrib.slim
def scratch_network(train_data_shuffler, reuse=False):
inputs = train_data_shuffler("data", from_queue=False)
if isinstance(train_data_shuffler, tf.Tensor):
inputs = train_data_shuffler
else:
inputs = train_data_shuffler("data", from_queue=False)
# Creating a random network
initializer = tf.contrib.layers.xavier_initializer(seed=seed)
......@@ -150,7 +153,7 @@ def test_cnn_trainer_scratch_tfrecord():
graph = scratch_network(train_data_shuffler)
validation_graph = scratch_network(validation_data_shuffler, reuse=True)
# Setting the placeholders
# Loss for the softmax
loss = MeanSoftMaxLoss()
......@@ -174,9 +177,27 @@ def test_cnn_trainer_scratch_tfrecord():
trainer.train()
os.remove(tfrecords_filename)
os.remove(tfrecords_filename_val)
os.remove(tfrecords_filename_val)
assert True
tf.reset_default_graph()
del trainer
assert len(tf.global_variables())==0
# Inference. TODO: Wrap this in a package
file_name = os.path.join(directory, "model.ckp.meta")
images = tf.placeholder(tf.float32, shape=(None, 28, 28, 1))
graph = scratch_network(images, reuse=False)
session = tf.Session()
session.run(tf.global_variables_initializer())
saver = tf.train.import_meta_graph(file_name, clear_devices=True)
saver.restore(session, tf.train.latest_checkpoint(os.path.dirname("./temp/cnn_scratch/")))
data = numpy.random.rand(2, 28, 28, 1).astype("float32")
assert session.run(graph, feed_dict={images: data}).shape == (2, 10)
tf.reset_default_graph()
shutil.rmtree(directory)
assert len(tf.global_variables())==0
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment