diff --git a/bob/learn/tensorflow/test/test_utils.py b/bob/learn/tensorflow/test/test_utils.py
index c256f71736efb9c39ff67c6d5b6f527dc4d830e1..966fd66c0036ce00feba0b9c82d411b8615b0f26 100755
--- a/bob/learn/tensorflow/test/test_utils.py
+++ b/bob/learn/tensorflow/test/test_utils.py
@@ -3,7 +3,9 @@
 # @author: Tiago de Freitas Pereira <tiago.pereira@idiap.ch>
 
 import numpy
-from bob.learn.tensorflow.utils import compute_embedding_accuracy
+from bob.learn.tensorflow.utils import compute_embedding_accuracy, cdist, compute_embedding_accuracy_tensors
+
+import tensorflow as tf
 
 """
 Some unit tests for the datashuffler
@@ -33,4 +35,26 @@ def test_embedding_accuracy():
 
     labels = numpy.concatenate((labels, noise_labels))
 
     assert compute_embedding_accuracy(data, labels) == 10 / 15.
+
+
+def test_embedding_accuracy_tensors():
+
+    numpy.random.seed(10)
+    samples_per_class = 5
+
+    class_a = numpy.random.normal(loc=0, scale=0.1, size=(samples_per_class, 2))
+    labels_a = numpy.zeros(samples_per_class)
+    class_b = numpy.random.normal(loc=10, scale=0.1, size=(samples_per_class, 2))
+    labels_b = numpy.ones(samples_per_class)
+
+    data = numpy.vstack((class_a, class_b))
+    labels = numpy.concatenate((labels_a, labels_b))
+
+    data = tf.convert_to_tensor(data.astype("float32"))
+    labels = tf.convert_to_tensor(labels.astype("int64"))
+
+    sess = tf.Session()
+    accuracy = sess.run(compute_embedding_accuracy_tensors(data, labels))
+    assert accuracy == 1.
+
diff --git a/bob/learn/tensorflow/trainers/__init__.py b/bob/learn/tensorflow/trainers/__init__.py
index ee0a98197281f67b7defe48a59d91ff672bae6fa..6f22f211e205f0f4a7a5605db261382673ef4025 100755
--- a/bob/learn/tensorflow/trainers/__init__.py
+++ b/bob/learn/tensorflow/trainers/__init__.py
@@ -3,6 +3,7 @@ from .Trainer import Trainer
 from .SiameseTrainer import SiameseTrainer
 from .TripletTrainer import TripletTrainer
 from .learning_rate import exponential_decay, constant
+from .LogitsTrainer import LogitsTrainer
 import numpy
diff --git a/bob/learn/tensorflow/utils/util.py b/bob/learn/tensorflow/utils/util.py
index 982bcb56b813c7944f22e0ee81b371aef1c24bc8..7aae42fa20e19b5aebad1748d809350ec5156e4c 100755
--- a/bob/learn/tensorflow/utils/util.py
+++ b/bob/learn/tensorflow/utils/util.py
@@ -37,95 +37,37 @@ def load_mnist(perc_train=0.9):
     n_train = int(perc_train*indexes.shape[0])
     n_validation = total_samples - n_train
 
-    train_data = data[0:n_train, :]
+    train_data = data[0:n_train, :].astype("float32") * 0.00390625  # scale pixels by 1/256
     train_labels = labels[0:n_train]
 
-    validation_data = data[n_train:n_train+n_validation, :]
+    validation_data = data[n_train:n_train+n_validation, :].astype("float32") * 0.00390625
     validation_labels = labels[n_train:n_train+n_validation]
 
     return train_data, train_labels, validation_data, validation_labels
 
 
-def plot_embedding_pca(features, labels):
-    """
-
-    Trains a PCA using bob, reducing the features to dimension 2 and plot it the possible clusters
-
-    :param features:
-    :param labels:
-    :return:
-    """
-
-    import bob.learn.linear
-    import matplotlib.pyplot as mpl
-
-    colors = ['#FF0000', '#FFFF00', '#FF00FF', '#00FFFF', '#000000',
-              '#AA0000', '#AAAA00', '#AA00AA', '#00AAAA', '#330000']
-
-    # Training PCA
-    trainer = bob.learn.linear.PCATrainer()
-    machine, lamb = trainer.train(features.astype("float64"))
-
-    # Getting the first two most relevant features
-    projected_features = machine(features.astype("float64"))[:, 0:2]
-
-    # Plotting the classes
-    n_classes = max(labels)+1
-    fig = mpl.figure()
-
-    for i in range(n_classes):
-        indexes = numpy.where(labels == i)[0]
-
-        selected_features = projected_features[indexes,:]
-        mpl.scatter(selected_features[:, 0], selected_features[:, 1],
-                    marker='.', c=colors[i], linewidths=0, label=str(i))
-    mpl.legend()
-    return fig
-
-def plot_embedding_lda(features, labels):
-    """
-
-    Trains a LDA using bob, reducing the features to dimension 2 and plot it the possible clusters
-
-    :param features:
-    :param labels:
-    :return:
-    """
-
-    import bob.learn.linear
-    import matplotlib.pyplot as mpl
-
-    colors = ['#FF0000', '#FFFF00', '#FF00FF', '#00FFFF', '#000000',
-              '#AA0000', '#AAAA00', '#AA00AA', '#00AAAA', '#330000']
-    n_classes = max(labels)+1
+def create_mnist_tfrecord(tfrecords_filename, data, labels, n_samples=6000):
 
-    # Training PCA
-    trainer = bob.learn.linear.FisherLDATrainer(use_pinv=True)
-    lda_features = []
-    for i in range(n_classes):
-        indexes = numpy.where(labels == i)[0]
-        lda_features.append(features[indexes, :].astype("float64"))
-
-    machine, lamb = trainer.train(lda_features)
-
-    #import ipdb; ipdb.set_trace();
-
-
-    # Getting the first two most relevant features
-    projected_features = machine(features.astype("float64"))[:, 0:2]
+    def _bytes_feature(value):
+        return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
 
-    # Plotting the classes
-    fig = mpl.figure()
+    def _int64_feature(value):
+        return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
 
-    for i in range(n_classes):
-        indexes = numpy.where(labels == i)[0]
+    writer = tf.python_io.TFRecordWriter(tfrecords_filename)
 
-        selected_features = projected_features[indexes,:]
-        mpl.scatter(selected_features[:, 0], selected_features[:, 1],
-                    marker='.', c=colors[i], linewidths=0, label=str(i))
-    mpl.legend()
-    return fig
+    for i in range(n_samples):
+        img = data[i]
+        img_raw = img.tostring()
+
+        feature = {'train/data': _bytes_feature(img_raw),
+                   'train/label': _int64_feature(labels[i])
+                   }
+
+        example = tf.train.Example(features=tf.train.Features(feature=feature))
+        writer.write(example.SerializeToString())
+    writer.close()
 
 
 def compute_eer(data_train, labels_train, data_validation, labels_validation, n_classes):
@@ -209,6 +151,45 @@ def debug_embbeding(image, architecture, embbeding_dim=2, feature_layer="fc3"):
 
     return embeddings
 
+
+def cdist(A):
+    """
+    Computes all the pairwise euclidean distances between the rows of the 2D tensor A
+    """
+    with tf.variable_scope('Pairwisedistance'):
+        p1 = tf.matmul(
+            tf.expand_dims(tf.reduce_sum(tf.square(A), 1), 1),
+            tf.ones(shape=(1, A.shape.as_list()[0]))
+        )
+        p2 = tf.transpose(tf.matmul(
+            tf.reshape(tf.reduce_sum(tf.square(A), 1), shape=[-1, 1]),
+            tf.ones(shape=(A.shape.as_list()[0], 1)),
+            transpose_b=True
+        ))
+
+        # ||a - b||^2 = ||a||^2 + ||b||^2 - 2<a, b>
+        return tf.sqrt(tf.add(p1, p2) - 2 * tf.matmul(A, A, transpose_b=True))
+
+
+def compute_embedding_accuracy_tensors(embedding, labels):
+    """
+    Compute the accuracy through exhaustive comparisons between the embeddings using tensors
+    """
+
+    distances = cdist(embedding)
+
+    # Filling the main diagonal with infs (removing comparisons with the same sample)
+    n_samples = embedding.get_shape().as_list()[0]
+    inf = (numpy.ones(n_samples) * numpy.inf).astype("float32")
+    distances = tf.matrix_set_diag(distances, inf)
+
+    indexes = tf.argmin(distances, axis=1)
+
+    # A sample counts as a hit when its nearest neighbour carries the same label
+    matching = [tf.equal(labels[i], labels[j])
+                for i, j in zip(range(n_samples), tf.unstack(indexes))]
+    return tf.reduce_sum(tf.cast(matching, tf.uint8)) / n_samples
+
 
 def compute_embedding_accuracy(embedding, labels):
     """
@@ -224,7 +205,9 @@
 
     # Fitting the main diagonal with infs (removing comparisons with the same sample)
     numpy.fill_diagonal(distances, numpy.inf)
 
-    indexes = distances.argmin(axis=1)
+    # Computing the argmin excluding comparisons with the same samples;
+    # basically, we are excluding the main diagonal
+    indexes = distances.argmin(axis=1)
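
Note: the sketch below shows how the helpers introduced by this patch could be driven end to end. It is illustrative, not part of the patch: it assumes the TF1-era API the diff targets (tf.Session, tf.python_io) and that create_mnist_tfrecord is re-exported from bob.learn.tensorflow.utils like the other helpers; the output path is hypothetical and the 10-sample "embedding" is synthetic.

    import tensorflow as tf
    from bob.learn.tensorflow.utils import (load_mnist, create_mnist_tfrecord,
                                            compute_embedding_accuracy_tensors)

    # load_mnist now returns pixels already scaled to [0, 1)
    train_data, train_labels, _, _ = load_mnist()

    # Serialize the first 6000 training samples as 'train/data'/'train/label' features
    create_mnist_tfrecord("/tmp/mnist_train.tfrecords", train_data,
                          train_labels.astype("int64"), n_samples=6000)

    # compute_embedding_accuracy_tensors needs a statically known number of samples,
    # since it unstacks the argmin result; here, a synthetic 10-sample 2D "embedding"
    embeddings = tf.convert_to_tensor(train_data[0:10, 0:2].astype("float32"))
    labels = tf.convert_to_tensor(train_labels[0:10].astype("int64"))

    with tf.Session() as sess:
        accuracy = sess.run(compute_embedding_accuracy_tensors(embeddings, labels))
        print("nearest-neighbour accuracy: %.2f" % accuracy)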