Skip to content
Snippets Groups Projects
Commit 067c701f authored by Olivier Canévet's avatar Olivier Canévet
Browse files

Add data_dir to load_mnist to avoid downloading data all the time

parent 345a39bd
Branches
No related tags found
No related merge requests found
#!/usr/bin/env python
# vim: set fileencoding=utf-8 :
# @author: Tiago de Freitas Pereira <tiago.pereira@idiap.ch>
# @date: Wed 11 May 2016 09:39:36 CEST
# @date: Wed 11 May 2016 09:39:36 CEST
import numpy
import tensorflow as tf
......@@ -18,10 +18,10 @@ def compute_euclidean_distance(x, y):
return d
def load_mnist(perc_train=0.9):
def load_mnist(perc_train=0.9, data_dir=None):
import bob.db.mnist
db = bob.db.mnist.Database()
db = bob.db.mnist.Database(data_dir=data_dir)
raw_data = db.data()
# data = raw_data[0].astype(numpy.float64)
......@@ -189,15 +189,15 @@ def compute_accuracy(data_train, labels_train, data_validation, labels_validatio
tp += 1
return (float(tp) / data_validation.shape[0]) * 100
def debug_embbeding(image, architecture, embbeding_dim=2, feature_layer="fc3"):
"""
"""
import tensorflow as tf
from bob.learn.tensorflow.utils.session import Session
session = Session.instance(new=False).session
session = Session.instance(new=False).session
inference_graph = architecture.compute_graph(architecture.inference_placeholder, feature_layer=feature_layer, training=False)
embeddings = numpy.zeros(shape=(image.shape[0], embbeding_dim))
......@@ -208,4 +208,3 @@ def debug_embbeding(image, architecture, embbeding_dim=2, feature_layer="fc3"):
embeddings[i] = embedding
return embeddings
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment