Commit ccb144b1 authored by Tiago Pereira's avatar Tiago Pereira

Fixed issue with the prefetch #27

parent 89da1f70
Pipeline #11261 passed with stages
in 24 minutes and 58 seconds
...@@ -56,7 +56,8 @@ class Base(object): ...@@ -56,7 +56,8 @@ class Base(object):
data_augmentation=None, data_augmentation=None,
normalizer=Linear(), normalizer=Linear(),
prefetch=False, prefetch=False,
prefetch_capacity=10): prefetch_capacity=50,
prefetch_threads=5):
# Setting the seed for the pseudo random number generator # Setting the seed for the pseudo random number generator
self.seed = seed self.seed = seed
...@@ -90,6 +91,7 @@ class Base(object): ...@@ -90,6 +91,7 @@ class Base(object):
# Prefetch variables # Prefetch variables
self.prefetch = prefetch self.prefetch = prefetch
self.prefetch_capacity = prefetch_capacity self.prefetch_capacity = prefetch_capacity
self.prefetch_threads = prefetch_threads
self.data_ph_from_queue = None self.data_ph_from_queue = None
self.label_ph_from_queue = None self.label_ph_from_queue = None
......
...@@ -47,7 +47,11 @@ class Memory(Base): ...@@ -47,7 +47,11 @@ class Memory(Base):
batch_size=1, batch_size=1,
seed=10, seed=10,
data_augmentation=None, data_augmentation=None,
normalizer=Linear()): normalizer=Linear(),
prefetch=False,
prefetch_capacity=10,
prefetch_threads=5
):
super(Memory, self).__init__( super(Memory, self).__init__(
data=data, data=data,
...@@ -57,7 +61,10 @@ class Memory(Base): ...@@ -57,7 +61,10 @@ class Memory(Base):
batch_size=batch_size, batch_size=batch_size,
seed=seed, seed=seed,
data_augmentation=data_augmentation, data_augmentation=data_augmentation,
normalizer=normalizer normalizer=normalizer,
prefetch=prefetch,
prefetch_capacity=prefetch_capacity,
prefetch_threads=prefetch_threads
) )
# Setting the seed # Setting the seed
numpy.random.seed(seed) numpy.random.seed(seed)
......
...@@ -46,7 +46,7 @@ def main(): ...@@ -46,7 +46,7 @@ def main():
BASE_PATH = args['<base_path>'] BASE_PATH = args['<base_path>']
EXTENSION = args['--extension'] EXTENSION = args['--extension']
SHAPE = [3, 250, 250] SHAPE = [1, 224, 224]
count, sum_data = process_images(BASE_PATH, EXTENSION, SHAPE) count, sum_data = process_images(BASE_PATH, EXTENSION, SHAPE)
...@@ -54,4 +54,5 @@ def main(): ...@@ -54,4 +54,5 @@ def main():
for s in range(SHAPE[0]): for s in range(SHAPE[0]):
means[s, ...] = sum_data[s, ...] / float(count) means[s, ...] = sum_data[s, ...] / float(count)
bob.io.base.save(means, "means.hdf5") bob.io.base.save(means, "means_casia.hdf5")
bob.io.base.save(means[0, :, :].astype("uint8"), "means_casia.png")
#!/usr/bin/env python
# vim: set fileencoding=utf-8 :
# @author: Tiago de Freitas Pereira <tiago.pereira@idiap.ch>
# @date: Thu 13 Oct 2016 13:35 CEST
import numpy
from bob.learn.tensorflow.datashuffler import Memory, SiameseMemory, TripletMemory, ImageAugmentation, ScaleFactor
from bob.learn.tensorflow.network import Chopra
from bob.learn.tensorflow.loss import BaseLoss, ContrastiveLoss, TripletLoss
from bob.learn.tensorflow.trainers import Trainer, SiameseTrainer, TripletTrainer, constant
from .test_cnn_scratch import validate_network
from bob.learn.tensorflow.network import Embedding
from bob.learn.tensorflow.utils import load_mnist
import tensorflow as tf
import bob.io.base
import shutil
from scipy.spatial.distance import cosine
import bob.measure
"""
Some unit tests for the datashuffler
"""
# Shared hyper-parameters for the tests in this module.
batch_size = 32              # training mini-batch size fed to the data shufflers
validation_batch_size = 400  # batch size used when scoring the validation set
iterations = 300             # number of training steps (kept small to keep the test fast)
seed = 10                    # RNG seed for reproducible initialization/shuffling
def test_cnn_trainer():
    """End-to-end smoke test: train the Chopra CNN on MNIST through a
    prefetch-enabled Memory data shuffler and require > 80% validation
    accuracy.

    Exercises the queue-based input path (``prefetch=True``) fixed in
    issue #27; a single prefetch thread is used to keep the test simple.
    """
    # Loading data
    train_data, train_labels, validation_data, validation_labels = load_mnist()
    # Reshape flat MNIST vectors into NHWC image tensors (28x28, 1 channel).
    train_data = numpy.reshape(train_data, (train_data.shape[0], 28, 28, 1))
    validation_data = numpy.reshape(validation_data, (validation_data.shape[0], 28, 28, 1))

    # Creating datashufflers
    data_augmentation = ImageAugmentation()
    train_data_shuffler = Memory(train_data, train_labels,
                                 input_shape=[None, 28, 28, 1],
                                 batch_size=batch_size,
                                 data_augmentation=data_augmentation,
                                 normalizer=ScaleFactor(),
                                 prefetch=True,
                                 prefetch_threads=1)
    # Temp dir for checkpoints/summaries; removed at the end of the test.
    directory = "./temp/cnn"

    # Loss for the softmax
    loss = BaseLoss(tf.nn.sparse_softmax_cross_entropy_with_logits, tf.reduce_mean)

    # Preparing the architecture
    architecture = Chopra(seed=seed,
                          fc1_output=10)

    # Training graph consumes the dequeued tensor (from_queue=True), since
    # prefetching is enabled on the shuffler.
    input_pl = train_data_shuffler("data", from_queue=True)
    graph = architecture(input_pl)

    # Validation embedding reuses the trained weights (reuse=True) but reads
    # from the plain feed-dict placeholder (from_queue=False).
    embedding = Embedding(train_data_shuffler("data", from_queue=False), architecture(train_data_shuffler("data", from_queue=False), reuse=True))

    # One graph trainer
    trainer = Trainer(train_data_shuffler,
                      iterations=iterations,
                      analizer=None,
                      temp_dir=directory
                      )
    trainer.create_network_from_scratch(graph=graph,
                                        loss=loss,
                                        learning_rate=constant(0.01, name="regular_lr"),
                                        optimizer=tf.train.GradientDescentOptimizer(0.01),
                                        )
    trainer.train()
    #trainer.train(validation_data_shuffler)

    # Using embedding to compute the accuracy
    accuracy = validate_network(embedding, validation_data, validation_labels)

    # At least 80% of accuracy
    assert accuracy > 80.
    shutil.rmtree(directory)

    # Drop references and reset the default graph so subsequent tests start
    # from clean TF1 global state.
    del trainer
    del graph
    del embedding
    tf.reset_default_graph()
...@@ -56,7 +56,8 @@ def validate_network(embedding, validation_data, validation_labels): ...@@ -56,7 +56,8 @@ def validate_network(embedding, validation_data, validation_labels):
def test_cnn_trainer_scratch(): def test_cnn_trainer_scratch():
tf.reset_default_graph()
train_data, train_labels, validation_data, validation_labels = load_mnist() train_data, train_labels, validation_data, validation_labels = load_mnist()
train_data = numpy.reshape(train_data, (train_data.shape[0], 28, 28, 1)) train_data = numpy.reshape(train_data, (train_data.shape[0], 28, 28, 1))
......
...@@ -243,3 +243,23 @@ class SiameseTrainer(Trainer): ...@@ -243,3 +243,23 @@ class SiameseTrainer(Trainer):
#summaries = [summary_pb2.Summary.Value(tag="loss", simple_value=float(l))] #summaries = [summary_pb2.Summary.Value(tag="loss", simple_value=float(l))]
#self.validation_summary_writter.add_summary(summary_pb2.Summary(value=summaries), step) #self.validation_summary_writter.add_summary(summary_pb2.Summary(value=summaries), step)
logger.info("Loss VALIDATION set step={0} = {1}".format(step, l)) logger.info("Loss VALIDATION set step={0} = {1}".format(step, l))
def load_and_enqueue(self):
    """Feed mini-batches from the training data shuffler into the prefetch queue.

    Runs until ``self.thread_pool.should_stop()`` reports shutdown: each
    iteration pulls one ``(data, labels)`` batch from the shuffler and runs
    the shuffler's ``enqueue_op`` on ``self.session``. Intended to be the
    target of a background prefetch thread (see the trainer's thread start-up
    loop).

    Takes no arguments besides ``self``; the TensorFlow session is read from
    ``self.session``. Returns ``None``.
    """
    # The feed placeholders are fixed graph nodes, so resolve them once
    # instead of re-querying the shuffler on every iteration.
    data_ph = self.train_data_shuffler("data", from_queue=False)
    label_ph = self.train_data_shuffler("label", from_queue=False)

    while not self.thread_pool.should_stop():
        [train_data, train_labels] = self.train_data_shuffler.get_batch()
        feed_dict = {data_ph: train_data,
                     label_ph: train_labels}
        self.session.run(self.train_data_shuffler.enqueue_op, feed_dict=feed_dict)
...@@ -129,11 +129,11 @@ class Trainer(object): ...@@ -129,11 +129,11 @@ class Trainer(object):
learning_rate: Learning rate learning_rate: Learning rate
""" """
self.data_ph = self.train_data_shuffler("data") self.data_ph = self.train_data_shuffler("data", from_queue=True)
self.label_ph = self.train_data_shuffler("label") self.label_ph = self.train_data_shuffler("label", from_queue=True)
self.graph = graph self.graph = graph
self.loss = loss self.loss = loss
self.predictor = self.loss(self.graph, self.train_data_shuffler("label", from_queue=False)) self.predictor = self.loss(self.graph, self.label_ph)
self.optimizer_class = optimizer self.optimizer_class = optimizer
self.learning_rate = learning_rate self.learning_rate = learning_rate
...@@ -272,7 +272,7 @@ class Trainer(object): ...@@ -272,7 +272,7 @@ class Trainer(object):
""" """
threads = [] threads = []
for n in range(3): for n in range(self.train_data_shuffler.prefetch_threads):
t = threading.Thread(target=self.load_and_enqueue, args=()) t = threading.Thread(target=self.load_and_enqueue, args=())
t.daemon = True # thread will close when parent quits t.daemon = True # thread will close when parent quits
t.start() t.start()
...@@ -290,10 +290,13 @@ class Trainer(object): ...@@ -290,10 +290,13 @@ class Trainer(object):
while not self.thread_pool.should_stop(): while not self.thread_pool.should_stop():
[train_data, train_labels] = self.train_data_shuffler.get_batch() [train_data, train_labels] = self.train_data_shuffler.get_batch()
feed_dict = {self.data_ph: train_data, data_ph = self.train_data_shuffler("data", from_queue=False)
self.label_ph: train_labels} label_ph = self.train_data_shuffler("label", from_queue=False)
self.session.run(self.inputs.enqueue_op, feed_dict=feed_dict) feed_dict = {data_ph: train_data,
label_ph: train_labels}
self.session.run(self.train_data_shuffler.enqueue_op, feed_dict=feed_dict)
def train(self, validation_data_shuffler=None): def train(self, validation_data_shuffler=None):
""" """
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment