tf-slim prototype

parent 421b1ded
Pipeline #8073 failed in 3 minutes and 8 seconds
@@ -4,6 +4,7 @@
 # @date: Tue 09 Aug 2016 16:38 CEST
 import logging
+import tensorflow as tf
 logger = logging.getLogger("bob.learn.tensorflow")
@@ -9,6 +9,7 @@ from .FaceNetSimple import FaceNetSimple
 from .VGG16 import VGG16
 from .VGG16_mod import VGG16_mod
 from .SimpleAudio import SimpleAudio
+from .Embedding import Embedding
 # gets sphinx autodoc done right - don't remove it
 def __appropriate__(*args):
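For context: the Embedding class registered above is not shown in this diff. A minimal sketch of what the test below assumes it does, i.e. pairing an input placeholder with an output tensor so the graph can be called like a function on numpy batches (the real implementation lives in bob.learn.tensorflow.network.Embedding; the Session import path is an assumption):

import tensorflow as tf
from bob.learn.tensorflow.utils.session import Session  # assumed import path

class Embedding(object):
    """Wraps an input placeholder and an output tensor so a trained graph
    can be evaluated like a plain function on numpy batches."""

    def __init__(self, input_tensor, graph):
        self.input = input_tensor
        self.graph = graph

    def __call__(self, data):
        # run the forward pass on the shared singleton session
        session = Session.instance().session
        return session.run(self.graph, feed_dict={self.input: data})
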
@@ -6,10 +6,10 @@
 import numpy
 import bob.io.base
 import os
-from bob.learn.tensorflow.datashuffler import Memory, ImageAugmentation
-from bob.learn.tensorflow.network import SequenceNetwork
+from bob.learn.tensorflow.datashuffler import Memory, ImageAugmentation, ScaleFactor
+from bob.learn.tensorflow.network import Embedding
 from bob.learn.tensorflow.loss import BaseLoss
-from bob.learn.tensorflow.trainers import Trainer
+from bob.learn.tensorflow.trainers import Trainer, learning_rate
 from bob.learn.tensorflow.utils import load_mnist
 from bob.learn.tensorflow.layers import Conv2D, FullyConnected
 import tensorflow as tf
@@ -21,35 +21,40 @@ Some unit tests that create networks on the fly
 batch_size = 16
 validation_batch_size = 400
-iterations = 50
+iterations = 300
 seed = 10
 directory = "./temp/cnn_scratch"
+slim = tf.contrib.slim

 def scratch_network():
-    # Creating a random network
-    scratch = SequenceNetwork(default_feature_layer="fc1")
-    scratch.add(Conv2D(name="conv1", kernel_size=3,
-                       filters=10,
-                       activation=tf.nn.tanh,
-                       batch_norm=False))
-    scratch.add(FullyConnected(name="fc1", output_dim=10,
-                               activation=None,
-                               batch_norm=False
-                               ))
-    return scratch
+    inputs = {}
+    inputs['data'] = tf.placeholder(tf.float32, shape=[None, 28, 28, 1], name="train_data")
+    inputs['label'] = tf.placeholder(tf.int64, shape=[None], name="train_label")
+    initializer = tf.contrib.layers.xavier_initializer(seed=seed)
+    scratch = slim.conv2d(inputs['data'], 10, [3, 3], activation_fn=tf.nn.relu, stride=1, scope='conv1',
+                          weights_initializer=initializer)
+    scratch = slim.max_pool2d(scratch, [4, 4], scope='pool1')
+    scratch = slim.flatten(scratch, scope='flatten1')
+    scratch = slim.fully_connected(scratch, 10, activation_fn=None, scope='fc1',
+                                   weights_initializer=initializer)
+    return inputs, scratch

-def validate_network(validation_data, validation_labels, network):
+def validate_network(embedding, validation_data, validation_labels):
     # Testing
     validation_data_shuffler = Memory(validation_data, validation_labels,
                                       input_shape=[28, 28, 1],
-                                      batch_size=validation_batch_size)
+                                      batch_size=validation_batch_size,
+                                      normalizer=ScaleFactor())
     [data, labels] = validation_data_shuffler.get_batch()
-    predictions = network.predict(data)
-    accuracy = 100. * numpy.sum(predictions == labels) / predictions.shape[0]
+    predictions = embedding(data)
+    accuracy = 100. * numpy.sum(numpy.argmax(predictions, axis=1) == labels) / predictions.shape[0]
     return accuracy
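The core of the prototype is the switch from the SequenceNetwork/layers API to raw tf.contrib.slim calls. A self-contained sketch of the same pattern, runnable under TensorFlow 1.x with random data (shapes and scopes mirror scratch_network above; the data here is synthetic):

import numpy
import tensorflow as tf

slim = tf.contrib.slim

data = tf.placeholder(tf.float32, shape=[None, 28, 28, 1], name="data")
initializer = tf.contrib.layers.xavier_initializer(seed=10)

# conv -> pool -> flatten -> linear logits, same stack as scratch_network()
net = slim.conv2d(data, 10, [3, 3], activation_fn=tf.nn.relu, stride=1,
                  scope='conv1', weights_initializer=initializer)
net = slim.max_pool2d(net, [4, 4], scope='pool1')
net = slim.flatten(net, scope='flatten1')
logits = slim.fully_connected(net, 10, activation_fn=None, scope='fc1',
                              weights_initializer=initializer)

with tf.Session() as session:
    session.run(tf.global_variables_initializer())
    batch = numpy.random.rand(4, 28, 28, 1).astype("float32")
    print(session.run(logits, feed_dict={data: batch}).shape)  # (4, 10)
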
@@ -64,27 +69,39 @@ def test_cnn_trainer_scratch():
     train_data_shuffler = Memory(train_data, train_labels,
                                  input_shape=[28, 28, 1],
                                  batch_size=batch_size,
-                                 data_augmentation=data_augmentation)
+                                 data_augmentation=data_augmentation,
+                                 normalizer=ScaleFactor())
+    validation_data_shuffler = Memory(train_data, train_labels,
+                                      input_shape=[28, 28, 1],
+                                      batch_size=batch_size,
+                                      data_augmentation=data_augmentation,
+                                      normalizer=ScaleFactor())
     validation_data = numpy.reshape(validation_data, (validation_data.shape[0], 28, 28, 1))

     # Create scratch network
-    scratch = scratch_network()
+    inputs, scratch = scratch_network()
+    embedding = Embedding(inputs['data'], scratch)

     # Loss for the softmax
     loss = BaseLoss(tf.nn.sparse_softmax_cross_entropy_with_logits, tf.reduce_mean)

     # One graph trainer
-    trainer = Trainer(architecture=scratch,
-                      loss=loss,
+    trainer = Trainer(inputs=inputs,
+                      graph=scratch,
                       iterations=iterations,
+                      loss=loss,
                       analizer=None,
                       prefetch=False,
-                      temp_dir=directory
+                      temp_dir=directory,
+                      optimizer=tf.train.GradientDescentOptimizer(0.01),
+                      learning_rate=learning_rate.constant(base_learning_rate=0.01, name="constant_learning_rate"),
+                      validation_snapshot=20
                       )
-    trainer.train(train_data_shuffler)
+    trainer.train(train_data_shuffler, validation_data_shuffler)

-    accuracy = validate_network(validation_data, validation_labels, scratch)
+    accuracy = validate_network(embedding, validation_data, validation_labels)
     assert accuracy > 80
-    shutil.rmtree(directory)
+    #shutil.rmtree(directory)
     del trainer
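The test now hands the trainer a learning_rate.constant(...) schedule. That helper is not part of this diff; a hypothetical sketch of what a constant schedule amounts to (the real one lives in bob.learn.tensorflow.trainers.learning_rate, and its signature here is assumed from the call above):

import tensorflow as tf

def constant(base_learning_rate=0.05, name="constant_learning_rate"):
    # A fixed schedule is just a named constant tensor that the optimizer
    # reads at every step; a decaying schedule would return something like
    # tf.train.exponential_decay(...) instead.
    return tf.constant(base_learning_rate, name=name)
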
@@ -69,6 +69,7 @@ class Trainer(object):
     """

     def __init__(self,
+                 inputs,
                  graph,
                  optimizer=tf.train.AdamOptimizer(),
                  use_gpu=False,
@@ -96,10 +97,13 @@ class Trainer(object):
         #if not isinstance(graph, SequenceNetwork):
         #    raise ValueError("`architecture` should be instance of `SequenceNetwork`")

+        self.inputs = inputs
         self.graph = graph
+        self.loss = loss
+        self.predictor = self.loss(self.graph, inputs['label'])
         self.optimizer_class = optimizer
         self.use_gpu = use_gpu
-        self.loss = loss
         self.temp_dir = temp_dir

         if learning_rate is None and model_from_file == "":
@@ -187,10 +191,9 @@ class Trainer(object):
         """
         [data, labels] = data_shuffler.get_batch()
-        [data_placeholder, label_placeholder] = data_shuffler.get_placeholders()
-        feed_dict = {data_placeholder: data,
-                     label_placeholder: labels}
+        feed_dict = {self.inputs['data']: data,
+                     self.inputs['label']: labels}
         return feed_dict

     def fit(self, step):
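The get_feed_dict change above is the crux of the refactoring: the placeholders now live on the trainer (self.inputs, built by scratch_network) instead of being fetched from each data shuffler. A runnable sketch of the resulting contract, with random data standing in for a shuffler batch:

import numpy
import tensorflow as tf

# the same dict shape scratch_network() returns and the trainer stores
inputs = {'data': tf.placeholder(tf.float32, shape=[None, 28, 28, 1], name="train_data"),
          'label': tf.placeholder(tf.int64, shape=[None], name="train_label")}

mean_label = tf.reduce_mean(tf.cast(inputs['label'], tf.float32))

with tf.Session() as session:
    batch = numpy.random.rand(8, 28, 28, 1).astype("float32")
    labels = numpy.random.randint(0, 10, size=8)
    # exactly the mapping the new get_feed_dict builds from self.inputs
    feed_dict = {inputs['data']: batch, inputs['label']: labels}
    print(session.run(mean_label, feed_dict=feed_dict))
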
@@ -204,27 +207,27 @@ class Trainer(object):
         """
         if self.prefetch:
-            _, l, lr, summary = self.session.run([self.optimizer, self.training_graph,
+            _, l, lr, summary = self.session.run([self.optimizer, self.predictor,
                                                   self.learning_rate, self.summaries_train])
         else:
             feed_dict = self.get_feed_dict(self.train_data_shuffler)
-            _, l, lr, summary = self.session.run([self.optimizer, self.training_graph,
+            _, l, lr, summary = self.session.run([self.optimizer, self.predictor,
                                                   self.learning_rate, self.summaries_train], feed_dict=feed_dict)

         logger.info("Loss training set step={0} = {1}".format(step, l))
         self.train_summary_writter.add_summary(summary, step)

     def create_general_summary(self):
         """
         Creates a simple tensorboard summary with the value of the loss and learning rate
         """
         # Train summary
-        tf.summary.scalar('loss', self.training_graph)
+        tf.summary.scalar('loss', self.predictor)
         tf.summary.scalar('lr', self.learning_rate)
         return tf.summary.merge_all()

     def start_thread(self):
         """
         Start pool of threads for pre-fetching
@@ -289,6 +292,24 @@ class Trainer(object):
         return saver

+    def compute_validation(self, data_shuffler, step):
+        """
+        Computes the loss in the validation set
+
+        ** Parameters **
+            session: Tensorflow session
+            data_shuffler: The data shuffler to be used
+            step: Iteration number
+        """
+        # Opening a new session for validation
+        feed_dict = self.get_feed_dict(data_shuffler)
+        l = self.session.run(self.predictor, feed_dict=feed_dict)
+        summaries = [summary_pb2.Summary.Value(tag="loss", simple_value=float(l))]
+        self.validation_summary_writter.add_summary(summary_pb2.Summary(value=summaries), step)
+        logger.info("Loss VALIDATION set step={0} = {1}".format(step, l))
+
     def train(self, train_data_shuffler, validation_data_shuffler=None):
         """
         Train the network:
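compute_validation above writes its scalar by hand through summary_pb2 rather than evaluating a summary op, because no graph node exists for the validation loss. The pattern in isolation (runnable under TF 1.x; the log directory is illustrative):

import tensorflow as tf
from tensorflow.core.framework import summary_pb2

writer = tf.summary.FileWriter("/tmp/validation_demo")  # illustrative path
# build the protobuf directly instead of running a tf.summary.scalar op
value = summary_pb2.Summary.Value(tag="loss", simple_value=0.42)
writer.add_summary(summary_pb2.Summary(value=[value]), global_step=1)
writer.close()
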
@@ -305,10 +326,7 @@ class Trainer(object):
         logger.info("Initializing !!")

-        # Pickle the architecture to save
-        #self.architecture.pickle_net(train_data_shuffler.deployment_shape)

-        if not isinstance(tf.Tensor, self.graph):
+        if not isinstance(self.graph, tf.Tensor):
             raise NotImplemented("Not tensor still not implemented")

         self.session = Session.instance(new=True).session
@@ -329,10 +347,13 @@ class Trainer(object):
         # Preparing the optimizer
         self.optimizer_class._learning_rate = self.learning_rate
-        self.optimizer = self.optimizer_class.minimize(self.training_graph, global_step=self.global_step)
+        self.optimizer = self.optimizer_class.minimize(self.predictor, global_step=self.global_step)
+        tf.add_to_collection("optimizer", self.optimizer)
+        tf.add_to_collection("learning_rate", self.learning_rate)

         self.summaries_train = self.create_general_summary()
+        tf.add_to_collection("summaries_train", self.summaries_train)

         # Train summary
         tf.global_variables_initializer().run(session=self.session)
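The new tf.add_to_collection calls make the optimizer, learning rate, and training summaries recoverable by name once a graph is reloaded from a checkpoint, which is what lets the plain saver.save(...) below replace the model-specific architecture.save(...). A minimal round-trip sketch of the mechanism:

import tensorflow as tf

tf.reset_default_graph()
lr = tf.constant(0.01, name="constant_learning_rate")
tf.add_to_collection("learning_rate", lr)

# ... after tf.train.import_meta_graph(...) in a fresh process, the same
# tensor can be fetched back by collection name:
restored = tf.get_collection("learning_rate")[0]
print(restored)  # Tensor("constant_learning_rate:0", ...)
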
@@ -350,6 +371,10 @@ class Trainer(object):
         # TENSOR BOARD SUMMARY
         self.train_summary_writter = tf.summary.FileWriter(os.path.join(self.temp_dir, 'train'), self.session.graph)
+        if validation_data_shuffler is not None:
+            self.validation_summary_writter = tf.summary.FileWriter(os.path.join(self.temp_dir, 'validation'),
+                                                                    self.session.graph)

         for step in range(start_step, self.iterations):
             start = time.time()
             self.fit(step)
@@ -358,18 +383,19 @@ class Trainer(object):
             self.train_summary_writter.add_summary(summary_pb2.Summary(value=[summary]), step)

             # Running validation
-            #if validation_data_shuffler is not None and step % self.validation_snapshot == 0:
-            #    self.compute_validation(validation_data_shuffler, step)
+            if validation_data_shuffler is not None and step % self.validation_snapshot == 0:
+                self.compute_validation(validation_data_shuffler, step)

-            # if self.analizer is not None:
-            #     self.validation_summary_writter.add_summary(self.analizer(
-            #         validation_data_shuffler, self.architecture, self.session), step)
+            #if self.analizer is not None:
+            #    self.validation_summary_writter.add_summary(self.analizer(
+            #        validation_data_shuffler, self.architecture, self.session), step)

             # Taking snapshot
             if step % self.snapshot == 0:
                 logger.info("Taking snapshot")
                 path = os.path.join(self.temp_dir, 'model_snapshot{0}.ckp'.format(step))
-                self.architecture.save(saver, path)
+                saver.save(self.session, path)
+                #self.architecture.save(saver, path)

         logger.info("Training finally finished")
@@ -379,9 +405,9 @@ class Trainer(object):
         # Saving the final network
         path = os.path.join(self.temp_dir, 'model.ckp')
-        self.architecture.save(saver, path)
+        saver.save(self.session, path)

         if self.prefetch:
             # now they should definetely stop
             self.thread_pool.request_stop()
-            self.thread_pool.join(threads)
+            #self.thread_pool.join(threads)
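With the model-specific architecture.save(...) gone, checkpointing reduces to the stock tf.train.Saver flow. For reference, a minimal save/restore sketch (TF 1.x; the variable and paths are illustrative):

import tensorflow as tf

v = tf.Variable(tf.zeros([2]), name="v")
saver = tf.train.Saver()

with tf.Session() as session:
    session.run(tf.global_variables_initializer())
    path = saver.save(session, "/tmp/model.ckp")  # writes model.ckp.* files
    saver.restore(session, path)                  # reloads the same weights
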