Skip to content
Snippets Groups Projects
Commit f04d1ae5 authored by Olivier Canévet's avatar Olivier Canévet
Browse files

[utils] Add load_real_mnist function

parent 7aac86e2
No related branches found
No related tags found
No related merge requests found
......@@ -17,6 +17,34 @@ def compute_euclidean_distance(x, y):
d = tf.sqrt(tf.reduce_sum(tf.square(tf.subtract(x, y)), 1))
return d
def dense_to_one_hot(labels_dense, num_classes):
"""
Convert class labels from scalars to one-hot vectors.
Taken from tensorflow/contrib/learn/python/learn/datasets/mnist.py
"""
num_labels = labels_dense.shape[0]
index_offset = numpy.arange(num_labels) * num_classes
labels_one_hot = numpy.zeros((num_labels, num_classes))
labels_one_hot.flat[index_offset + labels_dense.ravel()] = 1
return labels_one_hot
def load_real_mnist(one_hot=False, data_dir=None):
"""
Return the original train and test data
"""
import bob.db.mnist
db = bob.db.mnist.Database(data_dir=data_dir)
train_data, train_labels = db.data(groups=["train"])
test_data, test_labels = db.data(groups=["test"])
if one_hot:
train_labels = dense_to_one_hot(train_labels, 10)
test_labels = dense_to_one_hot(test_labels, 10)
return train_data, train_labels, test_data, test_labels
def load_mnist(perc_train=0.9, data_dir=None):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment