Implemented embedding validation

parent 79cb3334
@@ -260,5 +260,6 @@ def inception_resnet_v2(inputs, is_training=True,
    net = slim.fully_connected(net, bottleneck_layer_size, activation_fn=None,
                               scope='Bottleneck', reuse=False)
    return net, end_points
@@ -45,6 +45,33 @@ def scratch_network(train_data_shuffler, reuse=False):
    return graph


def scratch_network_embeding_example(train_data_shuffler, reuse=False, get_embedding=False):
    if isinstance(train_data_shuffler, tf.Tensor):
        inputs = train_data_shuffler
    else:
        inputs = train_data_shuffler("data", from_queue=False)

    # Creating a random network
    initializer = tf.contrib.layers.xavier_initializer(seed=seed)
    graph = slim.conv2d(inputs, 10, [3, 3], activation_fn=tf.nn.relu, stride=1, scope='conv1',
                        weights_initializer=initializer, reuse=reuse)
    graph = slim.max_pool2d(graph, [4, 4], scope='pool1')
    graph = slim.flatten(graph, scope='flatten1')
    graph = slim.fully_connected(graph, 30, activation_fn=None, scope='fc1',
                                 weights_initializer=initializer, reuse=reuse)

    if get_embedding:
        # L2-normalize the last fully connected layer: the output is an embedding
        graph = tf.nn.l2_normalize(graph, dim=1, name="embedding")
    else:
        # Otherwise attach the 10-way classification head
        graph = slim.fully_connected(graph, 10, activation_fn=None, scope='logits',
                                     weights_initializer=initializer, reuse=reuse)

    return graph
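This helper is built twice over shared weights in the tests below: once with the
logits head for training and once with the embedding head for validation. A
minimal usage sketch (the `images` placeholder is illustrative, not part of
this commit):

    images = tf.placeholder(tf.float32, shape=(None, 28, 28, 1))
    logits = scratch_network_embeding_example(images, reuse=False)
    embeddings = scratch_network_embeding_example(images, reuse=True, get_embedding=True)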
def validate_network(embedding, validation_data, validation_labels, input_shape=[None, 28, 28, 1], normalizer=ScaleFactor()):
    # Testing
    validation_data_shuffler = Memory(validation_data, validation_labels,
@@ -104,7 +131,7 @@ def test_cnn_trainer_scratch():
    assert len(tf.global_variables()) == 0


def test_cnn_tfrecord():
    tf.reset_default_graph()
    train_data, train_labels, validation_data, validation_labels = load_mnist()
@@ -201,3 +228,101 @@ def test_cnn_trainer_scratch_tfrecord():
    assert len(tf.global_variables()) == 0


def test_cnn_tfrecord_embedding_validation():
    tf.reset_default_graph()

    train_data, train_labels, validation_data, validation_labels = load_mnist()
    # Scale the pixel values to [0, 1) (0.00390625 == 1/256)
    train_data = train_data.astype("float32") * 0.00390625
    validation_data = validation_data.astype("float32") * 0.00390625

    def _bytes_feature(value):
        return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

    def _int64_feature(value):
        return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
    def create_tf_record(tfrecords_filename, data, labels):
        writer = tf.python_io.TFRecordWriter(tfrecords_filename)
        # for i in range(train_data.shape[0]):
        # Only the first 6000 samples, to keep the test fast
        for i in range(6000):
            img = data[i]
            img_raw = img.tostring()
            feature = {'train/data': _bytes_feature(img_raw),
                       'train/label': _int64_feature(labels[i])}
            example = tf.train.Example(features=tf.train.Features(feature=feature))
            writer.write(example.SerializeToString())
        writer.close()
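The reading side of these records lives in the TFRecord data shuffler and is
not part of this diff. For reference, a minimal sketch of how one serialized
example could be parsed back with the same TF1 queue API (the function name
read_and_decode is hypothetical; the feature keys mirror create_tf_record):

    def read_and_decode(filename_queue):
        reader = tf.TFRecordReader()
        _, serialized_example = reader.read(filename_queue)
        features = tf.parse_single_example(serialized_example,
                                           features={'train/data': tf.FixedLenFeature([], tf.string),
                                                     'train/label': tf.FixedLenFeature([], tf.int64)})
        # Raw bytes back to float32 pixels, restored to the MNIST shape
        image = tf.reshape(tf.decode_raw(features['train/data'], tf.float32), [28, 28, 1])
        label = tf.cast(features['train/label'], tf.int64)
        return image, label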
    tf.reset_default_graph()

    # Creating the tf record
    tfrecords_filename = "mnist_train.tfrecords"
    create_tf_record(tfrecords_filename, train_data, train_labels)
    filename_queue = tf.train.string_input_producer([tfrecords_filename], num_epochs=55, name="input")

    tfrecords_filename_val = "mnist_validation.tfrecords"
    create_tf_record(tfrecords_filename_val, validation_data, validation_labels)
    filename_queue_val = tf.train.string_input_producer([tfrecords_filename_val], num_epochs=55,
                                                        name="input_validation")
    # Creating the CNN using the TFRecord as input
    train_data_shuffler = TFRecord(filename_queue=filename_queue,
                                   batch_size=batch_size)

    validation_data_shuffler = TFRecord(filename_queue=filename_queue_val,
                                        batch_size=2000)

    graph = scratch_network_embeding_example(train_data_shuffler)
    validation_graph = scratch_network_embeding_example(validation_data_shuffler, reuse=True,
                                                        get_embedding=True)
    # Loss for the softmax
    loss = MeanSoftMaxLoss()
    # One graph trainer
    trainer = Trainer(train_data_shuffler,
                      validation_data_shuffler=validation_data_shuffler,
                      validate_with_embeddings=True,
                      iterations=iterations,  # It is super fast
                      analizer=None,
                      temp_dir=directory)
    learning_rate = constant(0.01, name="regular_lr")
    trainer.create_network_from_scratch(graph=graph,
                                        validation_graph=validation_graph,
                                        loss=loss,
                                        learning_rate=learning_rate,
                                        optimizer=tf.train.GradientDescentOptimizer(learning_rate))
    trainer.train()
    os.remove(tfrecords_filename)
    os.remove(tfrecords_filename_val)
    assert True

    tf.reset_default_graph()
    del trainer
    assert len(tf.global_variables()) == 0
    # Inference. TODO: Wrap this in a package
    file_name = os.path.join(directory, "model.ckp.meta")
    images = tf.placeholder(tf.float32, shape=(None, 28, 28, 1))
    graph = scratch_network_embeding_example(images, reuse=False)

    session = tf.Session()
    session.run(tf.global_variables_initializer())
    saver = tf.train.import_meta_graph(file_name, clear_devices=True)
    saver.restore(session, tf.train.latest_checkpoint(os.path.dirname("./temp/cnn_scratch/")))

    data = numpy.random.rand(2, 28, 28, 1).astype("float32")
    assert session.run(graph, feed_dict={images: data}).shape == (2, 10)

    tf.reset_default_graph()
    shutil.rmtree(directory)
    assert len(tf.global_variables()) == 0
@@ -62,3 +62,4 @@ def test_train_script_siamese():
    tf.reset_default_graph()
    assert len(tf.global_variables()) == 0
@@ -13,6 +13,7 @@ from tensorflow.core.framework import summary_pb2
import time
from bob.learn.tensorflow.datashuffler import OnlineSampling, TFRecord
from bob.learn.tensorflow.utils.session import Session
from bob.learn.tensorflow.utils import compute_embedding_accuracy
from .learning_rate import constant
import time
......@@ -54,6 +55,7 @@ class Trainer(object):
def __init__(self,
train_data_shuffler,
validation_data_shuffler=None,
validate_with_embeddings=False,
###### training options ##########
iterations=5000,
@@ -100,7 +102,8 @@ class Trainer(object):
        self.loss = None

        self.predictor = None
        self.validation_predictor = None
        self.validate_with_embeddings = validate_with_embeddings

        self.optimizer_class = None
        self.learning_rate = None
@@ -168,7 +171,10 @@ class Trainer(object):

            # Running validation
            if self.validation_data_shuffler is not None and step % self.validation_snapshot == 0:
                if self.validate_with_embeddings:
                    self.compute_validation_embeddings(step)
                else:
                    self.compute_validation(step)

            # Taking snapshot
            if step % self.snapshot == 0:
@@ -178,7 +184,11 @@ class Trainer(object):

        # Running validation for the last time
        if self.validation_data_shuffler is not None:
            if self.validate_with_embeddings:
                self.compute_validation_embeddings(step)
            else:
                self.compute_validation(step)

        logger.info("Training finally finished")
@@ -264,7 +274,10 @@ class Trainer(object):

            self.validation_graph = validation_graph

            if self.validate_with_embeddings:
                # Validate directly on the embeddings (accuracy), not on the loss
                self.validation_predictor = self.validation_graph
            else:
                self.validation_predictor = self.loss(self.validation_graph, self.validation_label_ph)

            self.summaries_validation = self.create_general_summary(self.validation_predictor,
                                                                    self.validation_graph,
                                                                    self.validation_label_ph)
            tf.add_to_collection("summaries_validation", self.summaries_validation)
@@ -398,6 +411,30 @@ class Trainer(object):
        logger.info("Loss VALIDATION set step={0} = {1}".format(step, l))
        self.validation_summary_writter.add_summary(summary, step)
    def compute_validation_embeddings(self, step):
        """
        Computes the accuracy on the validation set from the embeddings

        ** Parameters **

            session: Tensorflow session
            data_shuffler: The data shuffler to be used
            step: Iteration number
        """

        if self.validation_data_shuffler.prefetch:
            embedding, labels = self.session.run([self.validation_predictor, self.validation_label_ph])
        else:
            feed_dict = self.get_feed_dict(self.validation_data_shuffler)
            embedding, labels = self.session.run([self.validation_predictor, self.validation_label_ph],
                                                 feed_dict=feed_dict)

        accuracy = compute_embedding_accuracy(embedding, labels)

        # The accuracy is computed outside the graph, so the summary is built by hand
        summary = summary_pb2.Summary.Value(tag="accuracy", simple_value=accuracy)
        logger.info("Accuracy VALIDATION set step={0} = {1}".format(step, accuracy))
        self.validation_summary_writter.add_summary(summary_pb2.Summary(value=[summary]), step)
    def create_general_summary(self, average_loss, output, label):
        """
        Creates a simple tensorboard summary with the value of the loss and learning rate
@@ -208,4 +208,28 @@ def debug_embbeding(image, architecture, embbeding_dim=2, feature_layer="fc3"):
        embeddings[i] = embedding

    return embeddings
def compute_embedding_accuracy(embedding, labels):
    """
    Computes the nearest-neighbour accuracy through exhaustive comparisons between the embeddings
    """
    from scipy.spatial.distance import cdist

    distances = cdist(embedding, embedding)
    n_samples = embedding.shape[0]

    # Computing the argmin excluding comparisons with the same samples
    # Basically, we are excluding the main diagonal, assuming all
    # off-diagonal distances are strictly positive
    valid_indexes = distances[distances > 0].reshape(n_samples, n_samples - 1).argmin(axis=1)

    # Getting the original positions of the indexes in the 1-axis:
    # entries at or past the diagonal were shifted left by one
    corrected_indexes = [i if i < j else i + 1 for i, j in zip(valid_indexes, range(n_samples))]

    matching = [labels[i] == labels[j] for i, j in zip(range(n_samples), corrected_indexes)]
    accuracy = sum(matching) / float(n_samples)

    return accuracy
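A small worked example of this nearest-neighbour accuracy (toy values chosen
for illustration, not taken from the test suite): with two well-separated
clusters, every sample's closest neighbour shares its label, so the accuracy
is 1.0; flipping one label breaks both matches within that pair.

    import numpy

    embedding = numpy.array([[0.0, 0.0],
                             [0.1, 0.0],
                             [5.0, 5.0],
                             [5.1, 5.0]])

    labels = numpy.array([0, 0, 1, 1])
    assert compute_embedding_accuracy(embedding, labels) == 1.0

    # samples 0 and 1 are each other's nearest neighbours, so mislabeling
    # sample 1 breaks both of those matches: 2 of 4 remain correct
    labels = numpy.array([0, 1, 1, 1])
    assert compute_embedding_accuracy(embedding, labels) == 0.5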