Commit ccb144b1 authored by Tiago Pereira's avatar Tiago Pereira

Fixed issue with the prefetch #27

parent 89da1f70
Pipeline #11261 passed with stages
in 24 minutes and 58 seconds
......@@ -56,7 +56,8 @@ class Base(object):
data_augmentation=None,
normalizer=Linear(),
prefetch=False,
prefetch_capacity=10):
prefetch_capacity=50,
prefetch_threads=5):
# Setting the seed for the pseudo random number generator
self.seed = seed
......@@ -90,6 +91,7 @@ class Base(object):
# Prefetch variables
self.prefetch = prefetch
self.prefetch_capacity = prefetch_capacity
self.prefetch_threads = prefetch_threads
self.data_ph_from_queue = None
self.label_ph_from_queue = None
......
......@@ -47,7 +47,11 @@ class Memory(Base):
batch_size=1,
seed=10,
data_augmentation=None,
normalizer=Linear()):
normalizer=Linear(),
prefetch=False,
prefetch_capacity=10,
prefetch_threads=5
):
super(Memory, self).__init__(
data=data,
......@@ -57,7 +61,10 @@ class Memory(Base):
batch_size=batch_size,
seed=seed,
data_augmentation=data_augmentation,
normalizer=normalizer
normalizer=normalizer,
prefetch=prefetch,
prefetch_capacity=prefetch_capacity,
prefetch_threads=prefetch_threads
)
# Seting the seed
numpy.random.seed(seed)
......
......@@ -46,7 +46,7 @@ def main():
BASE_PATH = args['<base_path>']
EXTENSION = args['--extension']
SHAPE = [3, 250, 250]
SHAPE = [1, 224, 224]
count, sum_data = process_images(BASE_PATH, EXTENSION, SHAPE)
......@@ -54,4 +54,5 @@ def main():
for s in range(SHAPE[0]):
means[s, ...] = sum_data[s, ...] / float(count)
bob.io.base.save(means, "means.hdf5")
bob.io.base.save(means, "means_casia.hdf5")
bob.io.base.save(means[0, :, :].astype("uint8"), "means_casia.png")
#!/usr/bin/env python
# vim: set fileencoding=utf-8 :
# @author: Tiago de Freitas Pereira <tiago.pereira@idiap.ch>
# @date: Thu 13 Oct 2016 13:35 CEST
import numpy
from bob.learn.tensorflow.datashuffler import Memory, SiameseMemory, TripletMemory, ImageAugmentation, ScaleFactor
from bob.learn.tensorflow.network import Chopra
from bob.learn.tensorflow.loss import BaseLoss, ContrastiveLoss, TripletLoss
from bob.learn.tensorflow.trainers import Trainer, SiameseTrainer, TripletTrainer, constant
from .test_cnn_scratch import validate_network
from bob.learn.tensorflow.network import Embedding
from bob.learn.tensorflow.utils import load_mnist
import tensorflow as tf
import bob.io.base
import shutil
from scipy.spatial.distance import cosine
import bob.measure
"""
Some unit tests for the datashuffler
"""
# Hyper-parameters shared by the trainer tests in this module.
batch_size = 32              # mini-batch size used while training
validation_batch_size = 400  # batch size used when scoring the validation set
iterations = 300             # number of training iterations per test
seed = 10                    # pseudo-random seed for reproducibility
def test_cnn_trainer():
    """Train a small Chopra CNN on MNIST with the prefetch queue enabled
    and check it reaches at least 80% accuracy on the validation set."""
    # Load MNIST and reshape the flat vectors into 28x28 single-channel images.
    train_data, train_labels, validation_data, validation_labels = load_mnist()
    train_data = numpy.reshape(train_data, (train_data.shape[0], 28, 28, 1))
    validation_data = numpy.reshape(validation_data, (validation_data.shape[0], 28, 28, 1))

    # Data shuffler with augmentation, scale normalization and a single
    # background prefetch thread feeding the queue.
    augmentation = ImageAugmentation()
    train_data_shuffler = Memory(train_data, train_labels,
                                 input_shape=[None, 28, 28, 1],
                                 batch_size=batch_size,
                                 data_augmentation=augmentation,
                                 normalizer=ScaleFactor(),
                                 prefetch=True,
                                 prefetch_threads=1)

    output_dir = "./temp/cnn"

    # Softmax cross-entropy loss, averaged over the batch.
    loss = BaseLoss(tf.nn.sparse_softmax_cross_entropy_with_logits, tf.reduce_mean)

    # Build the training graph fed from the prefetch queue.
    architecture = Chopra(seed=seed, fc1_output=10)
    input_pl = train_data_shuffler("data", from_queue=True)
    graph = architecture(input_pl)

    # Embedding fed directly from the placeholder (bypassing the queue),
    # reusing the trained variables, for validation-time scoring.
    embedding = Embedding(train_data_shuffler("data", from_queue=False),
                          architecture(train_data_shuffler("data", from_queue=False), reuse=True))

    # One-graph trainer.
    trainer = Trainer(train_data_shuffler,
                      iterations=iterations,
                      analizer=None,
                      temp_dir=output_dir)
    trainer.create_network_from_scratch(graph=graph,
                                        loss=loss,
                                        learning_rate=constant(0.01, name="regular_lr"),
                                        optimizer=tf.train.GradientDescentOptimizer(0.01))
    trainer.train()
    #trainer.train(validation_data_shuffler)

    # Score the validation set through the embedding; at least 80% expected.
    accuracy = validate_network(embedding, validation_data, validation_labels)
    assert accuracy > 80.
    shutil.rmtree(output_dir)

    # Tear down so the next test starts from a clean TF graph.
    del trainer
    del graph
    del embedding
    tf.reset_default_graph()
......@@ -56,7 +56,8 @@ def validate_network(embedding, validation_data, validation_labels):
def test_cnn_trainer_scratch():
tf.reset_default_graph()
train_data, train_labels, validation_data, validation_labels = load_mnist()
train_data = numpy.reshape(train_data, (train_data.shape[0], 28, 28, 1))
......
......@@ -243,3 +243,23 @@ class SiameseTrainer(Trainer):
#summaries = [summary_pb2.Summary.Value(tag="loss", simple_value=float(l))]
#self.validation_summary_writter.add_summary(summary_pb2.Summary(value=summaries), step)
logger.info("Loss VALIDATION set step={0} = {1}".format(step, l))
def load_and_enqueue(self):
    """
    Background-thread loop that keeps the prefetch queue filled.

    Repeatedly fetches a batch from ``self.train_data_shuffler`` and runs
    the shuffler's enqueue op through ``self.session``, until the thread
    coordinator (``self.thread_pool``) signals the threads to stop.
    Takes no arguments; intended to be run as a daemon thread target.
    """
    while not self.thread_pool.should_stop():
        [train_data, train_labels] = self.train_data_shuffler.get_batch()

        # Feed through the plain placeholders (from_queue=False); the
        # enqueue op copies the batch into the queue consumed by training.
        data_ph = self.train_data_shuffler("data", from_queue=False)
        label_ph = self.train_data_shuffler("label", from_queue=False)

        feed_dict = {data_ph: train_data,
                     label_ph: train_labels}

        self.session.run(self.train_data_shuffler.enqueue_op, feed_dict=feed_dict)
......@@ -129,11 +129,11 @@ class Trainer(object):
learning_rate: Learning rate
"""
self.data_ph = self.train_data_shuffler("data")
self.label_ph = self.train_data_shuffler("label")
self.data_ph = self.train_data_shuffler("data", from_queue=True)
self.label_ph = self.train_data_shuffler("label", from_queue=True)
self.graph = graph
self.loss = loss
self.predictor = self.loss(self.graph, self.train_data_shuffler("label", from_queue=False))
self.predictor = self.loss(self.graph, self.label_ph)
self.optimizer_class = optimizer
self.learning_rate = learning_rate
......@@ -272,7 +272,7 @@ class Trainer(object):
"""
threads = []
for n in range(3):
for n in range(self.train_data_shuffler.prefetch_threads):
t = threading.Thread(target=self.load_and_enqueue, args=())
t.daemon = True # thread will close when parent quits
t.start()
......@@ -290,10 +290,13 @@ class Trainer(object):
while not self.thread_pool.should_stop():
[train_data, train_labels] = self.train_data_shuffler.get_batch()
feed_dict = {self.data_ph: train_data,
self.label_ph: train_labels}
data_ph = self.train_data_shuffler("data", from_queue=False)
label_ph = self.train_data_shuffler("label", from_queue=False)
self.session.run(self.inputs.enqueue_op, feed_dict=feed_dict)
feed_dict = {data_ph: train_data,
label_ph: train_labels}
self.session.run(self.train_data_shuffler.enqueue_op, feed_dict=feed_dict)
def train(self, validation_data_shuffler=None):
"""
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment