Added center loss

parent b2932fdb
Pipeline #12939 failed with stages in 3 minutes and 25 seconds
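Context for reviewers: center loss (Wen et al., "A Discriminative Feature Learning Approach for Deep Face Recognition", ECCV 2016) keeps one running center per class and penalizes the squared distance between each feature vector and its class center; the objective below is softmax cross-entropy plus `factor` times the center loss, with centers updated at rate (1 - alpha). A minimal NumPy sketch of that bookkeeping (illustrative names, not part of the diff):

import numpy as np

def center_loss_step(features, labels, centers, alpha=0.9):
    """One center-loss step: compute the penalty and update `centers` in place.

    features: (batch, n_features) prelogits; labels: (batch,) int class ids;
    centers:  (n_classes, n_features).
    """
    centers_batch = centers[labels]                  # center of each sample's class
    loss = np.mean((features - centers_batch) ** 2)  # mean squared distance
    # Each used center moves a fraction (1 - alpha) towards its batch features;
    # np.subtract.at accumulates over duplicate labels, like tf.scatter_sub
    np.subtract.at(centers, labels, (1 - alpha) * (centers_batch - features))
    return loss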
@@ -84,6 +84,7 @@ class TFRecord(object):
        # Convert the image data from string back to numbers
        image = tf.decode_raw(features['train/data'], tf.float32)
        #image = tf.decode_raw(features['train/data'], tf.uint8)

        # Cast label data into int64
        label = tf.cast(features['train/label'], tf.int64)
......
@@ -55,3 +55,64 @@ class MeanSoftMaxLoss(object):
            return tf.add_n([loss] + regularization_losses, name='total_loss')
        else:
            return loss
class MeanSoftMaxLossCenterLoss(object):
    """
    Mean softmax loss combined with center loss.

    Wraps the function ``tf.nn.sparse_softmax_cross_entropy_with_logits`` and adds
    the center-loss term from Wen et al., "A Discriminative Feature Learning
    Approach for Deep Face Recognition" (ECCV 2016).
    """

    def __init__(self, name="loss", add_regularization_losses=True, alpha=0.9, factor=0.01, n_classes=10):
        """
        Constructor

        **Parameters**

        name:
            Scope name

        add_regularization_losses:
            If True, add the collected regularization losses (including the
            weighted center loss) to the returned total loss

        alpha:
            Center update rate; each used center moves a fraction (1 - alpha)
            towards the features of its class in the batch

        factor:
            Weight of the center-loss term in the total loss

        n_classes:
            Number of classes (one center per class)
        """
        self.name = name
        self.add_regularization_losses = add_regularization_losses
        self.n_classes = n_classes
        self.alpha = alpha
        self.factor = factor

    def append_center_loss(self, features, label):
        nrof_features = features.get_shape()[1]
        centers = tf.get_variable('centers', [self.n_classes, nrof_features], dtype=tf.float32,
                                  initializer=tf.constant_initializer(0), trainable=False)
        label = tf.reshape(label, [-1])
        centers_batch = tf.gather(centers, label)
        diff = (1 - self.alpha) * (centers_batch - features)
        centers = tf.scatter_sub(centers, label, diff)

        # The control dependency forces the center update to run whenever the loss
        # is evaluated; without it the `tf.scatter_sub` op is dead and never executes
        with tf.control_dependencies([centers]):
            loss = tf.reduce_mean(tf.square(features - centers_batch))
        return loss

    def __call__(self, logits_prelogits, label):
        # TODO: Test the dictionary
        logits = logits_prelogits['logits']

        # Cross entropy
        loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(
            logits=logits, labels=label), name=self.name)

        # Appending the weighted center loss via the regularization collection.
        # Note that it only reaches the returned loss when
        # `add_regularization_losses` is True.
        prelogits = logits_prelogits['prelogits']
        center_loss = self.append_center_loss(prelogits, label)
        tf.add_to_collection(tf.GraphKeys.REGULARIZATION_LOSSES, center_loss * self.factor)

        # Adding the regularizers to the loss
        if self.add_regularization_losses:
            regularization_losses = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
            loss = tf.add_n([loss] + regularization_losses, name='total_loss')

        return loss
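A hedged usage sketch of the new class (`my_logits`, `my_prelogits` and `labels` are placeholders for whatever tensors the network produces; the dict keys are the contract the class expects):

loss_fn = MeanSoftMaxLossCenterLoss(n_classes=10, alpha=0.9, factor=0.01)
total_loss = loss_fn({'logits': my_logits, 'prelogits': my_prelogits}, labels)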
from .BaseLoss import BaseLoss, MeanSoftMaxLoss, MeanSoftMaxLossCenterLoss
from .ContrastiveLoss import ContrastiveLoss
from .TripletLoss import TripletLoss
from .TripletAverageLoss import TripletAverageLoss
......
@@ -78,6 +78,7 @@ def main():

    if os.path.exists(output_dir):
        tf.reset_default_graph()

    # Run validation with embeddings
    validate_with_embeddings = False
    if hasattr(config, 'validate_with_embeddings'):
        validate_with_embeddings = config.validate_with_embeddings
......
#!/usr/bin/env python
# vim: set fileencoding=utf-8 :
# @author: Tiago de Freitas Pereira <tiago.pereira@idiap.ch>
# @date: Thu 13 Oct 2016 13:35 CEST
import numpy
from bob.learn.tensorflow.datashuffler import TFRecord
from bob.learn.tensorflow.loss import MeanSoftMaxLossCenterLoss, MeanSoftMaxLoss
from bob.learn.tensorflow.trainers import Trainer, constant
from bob.learn.tensorflow.utils import load_mnist
import tensorflow as tf
import shutil
import os
"""
Some unit tests that create networks on the fly
"""
batch_size = 16
validation_batch_size = 400
iterations = 200
seed = 10
directory = "./temp/cnn_scratch"
slim = tf.contrib.slim
def scratch_network_embeding_example(train_data_shuffler, reuse=False, get_embedding=False):

    if isinstance(train_data_shuffler, tf.Tensor):
        inputs = train_data_shuffler
    else:
        inputs = train_data_shuffler("data", from_queue=False)

    # Creating a random network
    initializer = tf.contrib.layers.xavier_initializer(seed=seed)
    graph = slim.conv2d(inputs, 10, [3, 3], activation_fn=tf.nn.relu, stride=1, scope='conv1',
                        weights_initializer=initializer, reuse=reuse)
    graph = slim.max_pool2d(graph, [4, 4], scope='pool1')
    graph = slim.flatten(graph, scope='flatten1')
    prelogits = slim.fully_connected(graph, 30, activation_fn=None, scope='fc1',
                                     weights_initializer=initializer, reuse=reuse)

    if get_embedding:
        embedding = tf.nn.l2_normalize(prelogits, dim=1, name="embedding")
        return embedding
    else:
        logits = slim.fully_connected(prelogits, 10, activation_fn=None, scope='logits',
                                      weights_initializer=initializer, reuse=reuse)
        logits_prelogits = dict()
        logits_prelogits['logits'] = logits
        logits_prelogits['prelogits'] = prelogits
        return logits_prelogits
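The dict return value is what wires this network to the new loss: 'logits' feeds the softmax cross-entropy term and 'prelogits' (the 30-dimensional fc1 output) feeds the center loss. A shape sketch, assuming the 28x28x1 MNIST inputs used below (`images` is an illustrative placeholder):

images = tf.placeholder(tf.float32, shape=(None, 28, 28, 1))
outputs = scratch_network_embeding_example(images)
# outputs['logits']    -> (batch, 10): softmax cross-entropy term
# outputs['prelogits'] -> (batch, 30): center-loss term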
def test_cnn_tfrecord_embedding_validation():
    tf.reset_default_graph()

    train_data, train_labels, validation_data, validation_labels = load_mnist()

    # Scale pixel values from [0, 255] to [0, 1) (0.00390625 == 1/256)
    train_data = train_data.astype("float32") * 0.00390625
    validation_data = validation_data.astype("float32") * 0.00390625

    def _bytes_feature(value):
        return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

    def _int64_feature(value):
        return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

    def create_tf_record(tfrecords_filename, data, labels):
        writer = tf.python_io.TFRecordWriter(tfrecords_filename)

        #for i in range(train_data.shape[0]):
        for i in range(6000):
            img = data[i]
            img_raw = img.tostring()
            feature = {'train/data': _bytes_feature(img_raw),
                       'train/label': _int64_feature(labels[i])}
            example = tf.train.Example(features=tf.train.Features(feature=feature))
            writer.write(example.SerializeToString())
        writer.close()

    tf.reset_default_graph()

    # Creating the tf record
    tfrecords_filename = "mnist_train.tfrecords"
    create_tf_record(tfrecords_filename, train_data, train_labels)
    filename_queue = tf.train.string_input_producer([tfrecords_filename], num_epochs=55, name="input")

    tfrecords_filename_val = "mnist_validation.tfrecords"
    create_tf_record(tfrecords_filename_val, validation_data, validation_labels)
    filename_queue_val = tf.train.string_input_producer([tfrecords_filename_val], num_epochs=55,
                                                        name="input_validation")
    # Creating the CNN using the TFRecord as input
    train_data_shuffler = TFRecord(filename_queue=filename_queue,
                                   batch_size=batch_size)

    validation_data_shuffler = TFRecord(filename_queue=filename_queue_val,
                                        batch_size=2000)

    graph = scratch_network_embeding_example(train_data_shuffler)
    validation_graph = scratch_network_embeding_example(validation_data_shuffler, reuse=True, get_embedding=True)

    # Loss: softmax cross-entropy plus center loss. NB: with
    # add_regularization_losses=False the weighted center-loss term stays in the
    # regularization collection and does not enter the optimized loss
    loss = MeanSoftMaxLossCenterLoss(n_classes=10, add_regularization_losses=False)

    # One graph trainer
    trainer = Trainer(train_data_shuffler,
                      validation_data_shuffler=validation_data_shuffler,
                      validate_with_embeddings=True,
                      iterations=iterations,  # it is super fast
                      analizer=None,
                      temp_dir=directory)

    learning_rate = constant(0.01, name="regular_lr")
    trainer.create_network_from_scratch(graph=graph,
                                        validation_graph=validation_graph,
                                        loss=loss,
                                        learning_rate=learning_rate,
                                        optimizer=tf.train.GradientDescentOptimizer(learning_rate))
    trainer.train()

    os.remove(tfrecords_filename)
    os.remove(tfrecords_filename_val)

    tf.reset_default_graph()
    del trainer
    assert len(tf.global_variables()) == 0

    # Inference. TODO: Wrap this in a package
    file_name = os.path.join(directory, "model.ckp.meta")
    images = tf.placeholder(tf.float32, shape=(None, 28, 28, 1))
    graph = scratch_network_embeding_example(images, reuse=False)

    session = tf.Session()
    session.run(tf.global_variables_initializer())
    saver = tf.train.import_meta_graph(file_name, clear_devices=True)
    saver.restore(session, tf.train.latest_checkpoint(directory))
    data = numpy.random.rand(2, 28, 28, 1).astype("float32")
    assert session.run(graph['logits'], feed_dict={images: data}).shape == (2, 10)

    tf.reset_default_graph()
    shutil.rmtree(directory)
    assert len(tf.global_variables()) == 0
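For reference, a minimal sketch of reading back records serialized with the 'train/data'/'train/label' keys above, matching the decode_raw/cast lines in the TFRecord hunk at the top (the TFRecord shuffler does the equivalent internally; this standalone snippet is illustrative):

def read_mnist_record(filename_queue):
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)
    features = tf.parse_single_example(
        serialized_example,
        features={'train/data': tf.FixedLenFeature([], tf.string),
                  'train/label': tf.FixedLenFeature([], tf.int64)})
    # Raw bytes back to float32 pixels, then back to the original image shape
    image = tf.reshape(tf.decode_raw(features['train/data'], tf.float32), [28, 28, 1])
    label = tf.cast(features['train/label'], tf.int64)
    return image, label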
@@ -50,6 +50,7 @@ class SiameseTrainer(Trainer):

    def __init__(self,
                 train_data_shuffler,
                 validation_data_shuffler=None,
                 validate_with_embeddings=False,

                 ###### training options ##########
                 iterations=5000,
@@ -84,6 +85,7 @@ class SiameseTrainer(Trainer):

        self.validation_summary_writter = None
        self.summaries_validation = None
        self.validation_data_shuffler = validation_data_shuffler
        self.validate_with_embeddings = validate_with_embeddings

        # Analizer
        self.analizer = analizer
......
@@ -256,7 +256,13 @@ class Trainer(object):

        # Saving some variables
        tf.add_to_collection("global_step", self.global_step)

        if isinstance(self.graph, dict):
            tf.add_to_collection("graph", self.graph['logits'])
            tf.add_to_collection("prelogits", self.graph['prelogits'])
        else:
            tf.add_to_collection("graph", self.graph)

        tf.add_to_collection("predictor", self.predictor)
        tf.add_to_collection("data_ph", self.data_ph)
@@ -445,7 +451,11 @@ class Trainer(object):
        tf.summary.scalar('lr', self.learning_rate)

        # Computing accuracy
        if isinstance(output, dict):
            correct_prediction = tf.equal(tf.argmax(output['logits'], 1), label)
        else:
            correct_prediction = tf.equal(tf.argmax(output, 1), label)
        accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

        tf.summary.scalar('accuracy', accuracy)
        return tf.summary.merge_all()
......
@@ -51,6 +51,7 @@ class TripletTrainer(Trainer):

    def __init__(self,
                 train_data_shuffler,
                 validation_data_shuffler=None,
                 validate_with_embeddings=False,

                 ###### training options ##########
                 iterations=5000,
@@ -85,6 +86,7 @@ class TripletTrainer(Trainer):

        self.validation_summary_writter = None
        self.summaries_validation = None
        self.validation_data_shuffler = validation_data_shuffler
        self.validate_with_embeddings = validate_with_embeddings

        # Analizer
        self.analizer = analizer
......