First attempt with tfrecord

parent 65e93dc0
Pipeline #11882 failed in 14 minutes and 37 seconds

bob/learn/tensorflow/datashuffler/TFRecord.py (new file):
#!/usr/bin/env python
# vim: set fileencoding=utf-8 :
# @author: Tiago de Freitas Pereira <tiago.pereira@idiap.ch>
# @date: Wed 11 May 2016 09:39:36 CEST
import numpy
import tensorflow as tf
import bob.ip.base
from bob.learn.tensorflow.datashuffler.Normalizer import Linear
class TFRecord(object):
    """
    This datashuffler generates batches from a tfrecords file.

    **Parameters**

    filename_queue:
        Filename queue (e.g. from :py:func:`tf.train.string_input_producer`) holding the tfrecords file(s)

    input_shape:
        The shape of the inputs

    input_dtype:
        The data type of the inputs

    batch_size:
        Batch size

    seed:
        The seed of the random number generator

    prefetch_capacity:
        Capacity of the prefetch queue

    prefetch_threads:
        Number of threads used for prefetching
    """
    def __init__(self, filename_queue,
                 input_shape=[None, 28, 28, 1],
                 input_dtype="float32",
                 batch_size=32,
                 seed=10,
                 prefetch_capacity=50,
                 prefetch_threads=5):

        # Setting the seed for the pseudo random number generator
        self.seed = seed
        numpy.random.seed(seed)

        self.input_dtype = input_dtype

        # TODO: Check if the batch size is higher than the input data
        self.batch_size = batch_size

        # Preparing the inputs
        self.filename_queue = filename_queue
        self.input_shape = tuple(input_shape)

        # Prefetch variables
        self.prefetch_capacity = prefetch_capacity
        self.prefetch_threads = prefetch_threads

        # Preparing placeholders
        self.data_ph = None
        self.label_ph = None
        self.data_ph_from_queue = None
        self.label_ph_from_queue = None

        self.prefetch = False
    def create_placeholders(self):
        """
        Create the placeholder instances by reading and decoding the tfrecords file
        """
        feature = {'train/data': tf.FixedLenFeature([], tf.string),
                   'train/label': tf.FixedLenFeature([], tf.int64)}

        # Define a reader and read the next record
        reader = tf.TFRecordReader()
        _, serialized_example = reader.read(self.filename_queue)

        # Decode the record read by the reader
        features = tf.parse_single_example(serialized_example, features=feature)

        # Convert the image data from string back to numbers
        image = tf.decode_raw(features['train/data'], tf.float32)

        # Cast label data into int64
        label = tf.cast(features['train/label'], tf.int64)

        # Reshape image data into the original shape
        image = tf.reshape(image, self.input_shape[1:])

        # Batch in the graph, using the configured batch size and prefetch settings
        images, labels = tf.train.shuffle_batch([image, label],
                                                batch_size=self.batch_size,
                                                capacity=self.prefetch_capacity,
                                                num_threads=self.prefetch_threads,
                                                min_after_dequeue=1)

        self.data_ph = images
        self.label_ph = labels
        self.data_ph_from_queue = self.data_ph
        self.label_ph_from_queue = self.label_ph
    def __call__(self, element, from_queue=False):
        """
        Return the necessary placeholder, building the graph on the first call
        """
        if element not in ["data", "label"]:
            raise ValueError("Value '{0}' invalid. Options available are 'data' and 'label'".format(element))

        # If None, create the placeholders from scratch
        if self.data_ph is None:
            self.create_placeholders()

        if element == "data":
            if from_queue:
                return self.data_ph_from_queue
            else:
                return self.data_ph
        else:
            if from_queue:
                return self.label_ph_from_queue
            else:
                return self.label_ph
    def get_batch(self):
        """
        Get a random batch from the tfrecords file.

        **Returns**

        data:
            Selected samples

        labels:
            Corresponding labels
        """
        # Batching is done inside the graph by tf.train.shuffle_batch,
        # so there is nothing to do on the Python side.
        pass
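
For context, create_placeholders() above expects each serialized example to store the image as raw float32 bytes under 'train/data' and the label as a scalar int64 under 'train/label'. A matching writer might look like the sketch below; the helper name and output path are illustrative assumptions, not part of this commit:

# Hypothetical writer-side sketch (not part of this commit): serializes
# float32 images and integer labels under the keys TFRecord expects.
import numpy
import tensorflow as tf

def write_tfrecords(data, labels, path="mnist_train.tfrecords"):
    writer = tf.python_io.TFRecordWriter(path)
    for img, label in zip(data, labels):
        feature = {
            # Raw float32 bytes, matching tf.decode_raw(..., tf.float32)
            'train/data': tf.train.Feature(bytes_list=tf.train.BytesList(
                value=[numpy.asarray(img, dtype="float32").tostring()])),
            # Scalar int64, matching tf.FixedLenFeature([], tf.int64)
            'train/label': tf.train.Feature(int64_list=tf.train.Int64List(
                value=[int(label)]))}
        example = tf.train.Example(features=tf.train.Features(feature=feature))
        writer.write(example.SerializeToString())
    writer.close()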

bob/learn/tensorflow/datashuffler/__init__.py:
@@ -21,6 +21,7 @@ from .ImageAugmentation import ImageAugmentation
from .Normalizer import ScaleFactor, MeanOffset, Linear
from .DiskAudio import DiskAudio
+from .TFRecord import TFRecord
# gets sphinx autodoc done right - don't remove it
def __appropriate__(*args):

bob/learn/tensorflow/network/__init__.py:
from .Chopra import Chopra
-#from .LightCNN9 import LightCNN9
+from .LightCNN9 import LightCNN9
from .Dummy import Dummy
from .MLP import MLP
from .Embedding import Embedding
@@ -20,6 +20,7 @@ def __appropriate__(*args):
__appropriate__(
    Chopra,
+    LightCNN9,
    Dummy,
    MLP,
)

bob/learn/tensorflow/test/test_cnn.py:
@@ -9,7 +9,7 @@ from bob.learn.tensorflow.network import Chopra
from bob.learn.tensorflow.loss import BaseLoss, ContrastiveLoss, TripletLoss
from bob.learn.tensorflow.trainers import Trainer, SiameseTrainer, TripletTrainer, constant
from .test_cnn_scratch import validate_network
-from bob.learn.tensorflow.network import Embedding
+from bob.learn.tensorflow.network import Embedding, LightCNN9
from bob.learn.tensorflow.utils import load_mnist
import tensorflow as tf

@@ -123,7 +123,7 @@ def test_cnn_trainer():
    del trainer
    del graph
"""
def test_lightcnn_trainer():
    # generating fake data

@@ -171,7 +171,6 @@ def test_lightcnn_trainer():
    #trainer.train(validation_data_shuffler)

    # Using embedding to compute the accuracy
-    import ipdb; ipdb.set_trace();
    accuracy = validate_network(embedding, validation_data, validation_labels, input_shape=[None, 128, 128, 1], normalizer=Linear())

    # At least 80% of accuracy
    assert accuracy > 80.
@@ -276,4 +275,4 @@ def test_tripletcnn_trainer():
    del architecture
    del trainer  # Just to clean tf.variables
"""

bob/learn/tensorflow/test/test_cnn_scratch.py:
@@ -4,7 +4,7 @@
# @date: Thu 13 Oct 2016 13:35 CEST

import numpy
-from bob.learn.tensorflow.datashuffler import Memory, ImageAugmentation, ScaleFactor, Linear
+from bob.learn.tensorflow.datashuffler import Memory, ImageAugmentation, ScaleFactor, Linear, TFRecord
from bob.learn.tensorflow.network import Embedding
from bob.learn.tensorflow.loss import BaseLoss
from bob.learn.tensorflow.trainers import Trainer, constant

@@ -96,3 +96,43 @@ def test_cnn_trainer_scratch():
    assert accuracy > 70

    shutil.rmtree(directory)

    del trainer
def test_cnn_trainer_scratch_tfrecord():
    tf.reset_default_graph()

    #train_data, train_labels, validation_data, validation_labels = load_mnist()
    #train_data = numpy.reshape(train_data, (train_data.shape[0], 28, 28, 1))

    tfrecords_filename = "/idiap/user/tpereira/gitlab/workspace_HTFace/mnist_train.tfrecords"
    filename_queue = tf.train.string_input_producer([tfrecords_filename], num_epochs=1)

    # batch_size, iterations and directory are module-level settings of this test file
    train_data_shuffler = TFRecord(filename_queue=filename_queue,
                                   batch_size=batch_size)

    # Creating datashufflers
    # Create scratch network
    graph = scratch_network(train_data_shuffler)

    # Setting the placeholders
    # Loss for the softmax
    loss = BaseLoss(tf.nn.sparse_softmax_cross_entropy_with_logits, tf.reduce_mean)

    # One graph trainer
    trainer = Trainer(train_data_shuffler,
                      iterations=iterations,
                      analizer=None,
                      temp_dir=directory)
    trainer.create_network_from_scratch(graph=graph,
                                        loss=loss,
                                        learning_rate=constant(0.01, name="regular_lr"),
                                        optimizer=tf.train.GradientDescentOptimizer(0.01))
    trainer.train()

    #accuracy = validate_network(embedding, validation_data, validation_labels)
    #assert accuracy > 70
    #shutil.rmtree(directory)
    #del trainer
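
The test above drives the shuffler through the Trainer. Driving it by hand needs the standard TF 1.x queue plumbing; a minimal consumption sketch, assuming a local mnist_train.tfrecords file written as shown earlier (not part of this commit):

# Hedged usage sketch: pull one shuffled batch straight from the queue.
filename_queue = tf.train.string_input_producer(["mnist_train.tfrecords"], num_epochs=1)
shuffler = TFRecord(filename_queue=filename_queue,
                    input_shape=[None, 28, 28, 1],
                    batch_size=32)
data = shuffler("data", from_queue=True)
label = shuffler("label", from_queue=True)

with tf.Session() as session:
    session.run(tf.global_variables_initializer())
    session.run(tf.local_variables_initializer())  # num_epochs is backed by a local variable
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(coord=coord, sess=session)
    images, labels = session.run([data, label])    # one shuffled batch
    coord.request_stop()
    coord.join(threads)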

bob/learn/tensorflow/trainers/Trainer.py:
@@ -11,7 +11,7 @@ import bob.core
from ..analyzers import SoftmaxAnalizer
from tensorflow.core.framework import summary_pb2
import time
-from bob.learn.tensorflow.datashuffler import OnlineSampling
+from bob.learn.tensorflow.datashuffler import OnlineSampling, TFRecord
from bob.learn.tensorflow.utils.session import Session
from .learning_rate import constant
import time
@@ -221,7 +221,7 @@ class Trainer(object):
        """
-        if self.train_data_shuffler.prefetch:
+        if self.train_data_shuffler.prefetch or isinstance(self.train_data_shuffler, TFRecord):
            _, l, lr, summary = self.session.run([self.optimizer, self.predictor,
                                                  self.learning_rate, self.summaries_train])
        else:
@@ -325,7 +325,17 @@ class Trainer(object):
            self.thread_pool = tf.train.Coordinator()
            tf.train.start_queue_runners(coord=self.thread_pool, sess=self.session)
            threads = self.start_thread()
-            time.sleep(20) # As suggested in https://stackoverflow.com/questions/39840323/benchmark-of-howto-reading-data/39842628#39842628
+            #time.sleep(20) # As suggested in https://stackoverflow.com/questions/39840323/benchmark-of-howto-reading-data/39842628#39842628
+
+        # TODO: JUST FOR TESTING THE INTEGRATION
+        import ipdb; ipdb.set_trace();
+        if isinstance(self.train_data_shuffler, TFRecord):
+            self.session.run(tf.local_variables_initializer())
+            self.thread_pool = tf.train.Coordinator()
+            threads = tf.train.start_queue_runners(coord=self.thread_pool, sess=self.session)

        # TENSOR BOARD SUMMARY
        self.train_summary_writter = tf.summary.FileWriter(os.path.join(self.temp_dir, 'train'), self.session.graph)
@@ -366,7 +376,7 @@ class Trainer(object):
            path = os.path.join(self.temp_dir, 'model.ckp')
            self.saver.save(self.session, path)

-        if self.train_data_shuffler.prefetch:
+        if self.train_data_shuffler.prefetch or isinstance(self.train_data_shuffler, TFRecord):
            # now they should definitely stop
            self.thread_pool.request_stop()
-            #self.thread_pool.join(threads)
+            self.thread_pool.join(threads)
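
The shutdown ordering restored above (request_stop() before join()) matters because the queue-runner threads keep filling the queues until the coordinator tells them to stop. A self-contained toy illustration of the same lifecycle, assuming plain TF 1.x with no bob dependencies:

# Toy illustration of the coordinator lifecycle the trainer now applies
# to TFRecord shufflers (illustrative, not part of this commit).
import tensorflow as tf

queue = tf.train.range_input_producer(10, shuffle=False)
value = queue.dequeue()

with tf.Session() as session:
    session.run(tf.global_variables_initializer())
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(coord=coord, sess=session)
    for _ in range(5):
        print(session.run(value))  # values are produced by the runner threads
    coord.request_stop()  # signal the runner threads to stop...
    coord.join(threads)   # ...then block until they have exited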