New session management

parent 0fe1b037
@@ -9,4 +9,5 @@ from bob.learn.tensorflow import layers
from bob.learn.tensorflow import loss
from bob.learn.tensorflow import network
from bob.learn.tensorflow import trainers
+from bob.learn.tensorflow import utils
@@ -2,7 +2,6 @@
from pkgutil import extend_path
__path__ = extend_path(__path__, __name__)
-from .util import *
# gets sphinx autodoc done right - don't remove it
__all__ = [_ for _ in dir() if not _.startswith('_')]
@@ -4,7 +4,6 @@
# @date: Wed 11 May 2016 17:38 CEST
import tensorflow as tf
-from bob.learn.tensorflow.util import *
from .MaxPooling import MaxPooling
......
@@ -4,7 +4,6 @@
# @date: Wed 11 May 2016 17:38 CEST
import tensorflow as tf
-from bob.learn.tensorflow.util import *
from .Layer import Layer
......
@@ -4,7 +4,6 @@
# @date: Wed 11 May 2016 17:38 CEST
import tensorflow as tf
-from bob.learn.tensorflow.util import *
from .Layer import Layer
......
@@ -8,7 +8,7 @@ logger = logging.getLogger("bob.learn.tensorflow")
import tensorflow as tf
from .BaseLoss import BaseLoss
-from bob.learn.tensorflow.util import compute_euclidean_distance
+from bob.learn.tensorflow.utils import compute_euclidean_distance
class ContrastiveLoss(BaseLoss):
......
@@ -8,7 +8,7 @@ logger = logging.getLogger("bob.learn.tensorflow")
import tensorflow as tf
from .BaseLoss import BaseLoss
-from bob.learn.tensorflow.util import compute_euclidean_distance
+from bob.learn.tensorflow.utils import compute_euclidean_distance
class TripletLoss(BaseLoss):
......
@@ -12,6 +12,7 @@ import pickle
from collections import OrderedDict
from bob.learn.tensorflow.layers import Layer, MaxPooling, Dropout, Conv2D, FullyConnected
+from bob.learn.tensorflow.utils.session import Session
class SequenceNetwork(six.with_metaclass(abc.ABCMeta, object)):
@@ -102,7 +103,7 @@ class SequenceNetwork(six.with_metaclass(abc.ABCMeta, object)):
def compute_inference_placeholder(self, data_shape):
self.inference_placeholder = tf.placeholder(tf.float32, shape=data_shape, name="feature")
-def __call__(self, data, session=None, feature_layer=None):
+def __call__(self, data, feature_layer=None):
"""Run a graph and compute the embeddings
**Parameters**
@@ -115,8 +116,7 @@ class SequenceNetwork(six.with_metaclass(abc.ABCMeta, object)):
If `None` will run the graph until the end.
"""
-if session is None:
-session = tf.Session()
+session = Session.instance().session
# Feeding the placeholder
if self.inference_placeholder is None:
@@ -130,8 +130,8 @@ class SequenceNetwork(six.with_metaclass(abc.ABCMeta, object)):
return embedding
-def predict(self, data, session):
-return numpy.argmax(self(data, session=session), 1)
+def predict(self, data):
+return numpy.argmax(self(data), 1)
def dump_variables(self):
"""
@@ -252,10 +252,13 @@ class SequenceNetwork(six.with_metaclass(abc.ABCMeta, object)):
self.sequence_net[k].weights_initialization.use_gpu = state
self.sequence_net[k].bias_initialization.use_gpu = state
-def load_variables_only(self, hdf5, session):
+def load_variables_only(self, hdf5):
"""
Load the variables of the model
"""
+session = Session.instance().session
hdf5.cd('/tensor_flow')
for k in self.sequence_net:
# TODO: IT IS NOT SMART TESTING ALONG THIS PAGE
@@ -271,7 +274,7 @@ class SequenceNetwork(six.with_metaclass(abc.ABCMeta, object)):
hdf5.cd("..")
-def load_hdf5(self, hdf5, shape=None, session=None, batch=1, use_gpu=False):
+def load_hdf5(self, hdf5, shape=None, batch=1, use_gpu=False):
"""
Load the network from scratch.
This will build the graphs
@@ -285,8 +288,7 @@ class SequenceNetwork(six.with_metaclass(abc.ABCMeta, object)):
use_gpu: Load all the variables in the GPU?
"""
-if session is None:
-session = tf.Session()
+session = Session.instance().session
# Loading the normalization parameters
self.input_divide = hdf5.read('input_divide')
@@ -308,11 +310,17 @@ class SequenceNetwork(six.with_metaclass(abc.ABCMeta, object)):
tf.initialize_all_variables().run(session=session)
-self.load_variables_only(hdf5, session)
+self.load_variables_only(hdf5)
-def save(self, session, saver, path):
+def save(self, saver, path):
+session = Session.instance().session
open(path+"_sequence_net.pickle", 'w').write(self.pickle_architecture)
return saver.save(session, path)
-def load(self, session, path, clear_devices=False):
+def load(self, path, clear_devices=False):
+session = Session.instance().session
self.sequence_net = pickle.loads(open(path+"_sequence_net.pickle").read())
#saver = tf.train.import_meta_graph(path + ".meta", clear_devices=clear_devices)
saver = tf.train.import_meta_graph(path + ".meta")
......
@@ -11,7 +11,7 @@ from bob.learn.tensorflow.initialization import Xavier, Constant
from bob.learn.tensorflow.network import SequenceNetwork
from bob.learn.tensorflow.loss import BaseLoss
from bob.learn.tensorflow.trainers import Trainer
-from bob.learn.tensorflow.util import load_mnist
+from bob.learn.tensorflow.utils import load_mnist
from bob.learn.tensorflow.layers import Conv2D, FullyConnected, MaxPooling
import tensorflow as tf
import shutil
@@ -46,22 +46,21 @@ def scratch_network():
return scratch
-def validate_network(validation_data, validation_labels, directory):
+def validate_network(validation_data, validation_labels, network):
# Testing
validation_data_shuffler = Memory(validation_data, validation_labels,
input_shape=[28, 28, 1],
batch_size=validation_batch_size)
-with tf.Session() as session:
-scratch = SequenceNetwork()
-scratch.load(session, os.path.join(directory, "model.ckp"))
-[data, labels] = validation_data_shuffler.get_batch()
-predictions = scratch(data, session=session)
-accuracy = 100. * numpy.sum(numpy.argmax(predictions, 1) == labels) / predictions.shape[0]
+[data, labels] = validation_data_shuffler.get_batch()
+predictions = network.predict(data)
+accuracy = 100. * numpy.sum(predictions == labels) / predictions.shape[0]
return accuracy
def test_cnn_trainer_scratch():
train_data, train_labels, validation_data, validation_labels = load_mnist()
train_data = numpy.reshape(train_data, (train_data.shape[0], 28, 28, 1))
@@ -86,11 +85,10 @@ def test_cnn_trainer_scratch():
analizer=None,
prefetch=False,
temp_dir=directory)
-trainer.train(train_data_shuffler)
-del trainer# JUst to clean the tf.variables
+trainer.train(train_data_shuffler)
-accuracy = validate_network(validation_data, validation_labels, directory)
+accuracy = validate_network(validation_data, validation_labels, scratch)
assert accuracy > 80
shutil.rmtree(directory)
+del trainer
@@ -13,6 +13,7 @@ from ..analyzers import SoftmaxAnalizer
from tensorflow.core.framework import summary_pb2
import time
from bob.learn.tensorflow.datashuffler.OnlineSampling import OnLineSampling
+from bob.learn.tensorflow.utils.session import Session
from .learning_rate import constant
logger = bob.core.log.setup("bob.learn.tensorflow")
@@ -103,6 +104,7 @@ class Trainer(object):
self.global_step = None
self.model_from_file = model_from_file
+self.session = None
bob.core.log.set_verbosity_level(logger, verbosity_level)
@@ -162,7 +164,7 @@
label_placeholder: labels}
return feed_dict
-def fit(self, session, step):
+def fit(self, step):
"""
Run one iteration (`forward` and `backward`)
@@ -173,17 +175,17 @@
"""
if self.prefetch:
-_, l, lr, summary = session.run([self.optimizer, self.training_graph,
+_, l, lr, summary = self.session.run([self.optimizer, self.training_graph,
self.learning_rate, self.summaries_train])
else:
feed_dict = self.get_feed_dict(self.train_data_shuffler)
-_, l, lr, summary = session.run([self.optimizer, self.training_graph,
+_, l, lr, summary = self.session.run([self.optimizer, self.training_graph,
self.learning_rate, self.summaries_train], feed_dict=feed_dict)
logger.info("Loss training set step={0} = {1}".format(step, l))
self.train_summary_writter.add_summary(summary, step)
-def compute_validation(self, session, data_shuffler, step):
+def compute_validation(self, data_shuffler, step):
"""
Computes the loss in the validation set
@@ -195,10 +197,10 @@
"""
# Opening a new session for validation
feed_dict = self.get_feed_dict(data_shuffler)
-l = session.run(self.validation_graph, feed_dict=feed_dict)
+l = self.session.run(self.validation_graph, feed_dict=feed_dict)
if self.validation_summary_writter is None:
-self.validation_summary_writter = tf.train.SummaryWriter(os.path.join(self.temp_dir, 'validation'), session.graph)
+self.validation_summary_writter = tf.train.SummaryWriter(os.path.join(self.temp_dir, 'validation'), self.session.graph)
summaries = [summary_pb2.Summary.Value(tag="loss", simple_value=float(l))]
self.validation_summary_writter.add_summary(summary_pb2.Summary(value=summaries), step)
@@ -213,7 +215,7 @@
tf.scalar_summary('lr', self.learning_rate, name="train")
return tf.merge_all_summaries()
-def start_thread(self, session):
+def start_thread(self):
"""
Start pool of threads for pre-fetching
@@ -223,13 +225,13 @@
threads = []
for n in range(3):
-t = threading.Thread(target=self.load_and_enqueue, args=(session,))
+t = threading.Thread(target=self.load_and_enqueue, args=())
t.daemon = True # thread will close when parent quits
t.start()
threads.append(t)
return threads
-def load_and_enqueue(self, session):
+def load_and_enqueue(self):
"""
Injecting data in the place holder queue
@@ -244,7 +246,7 @@
feed_dict = {train_placeholder_data: train_data,
train_placeholder_labels: train_labels}
-session.run(self.enqueue_op, feed_dict=feed_dict)
+self.session.run(self.enqueue_op, feed_dict=feed_dict)
def bootstrap_graphs(self, train_data_shuffler, validation_data_shuffler):
"""
@@ -293,7 +295,7 @@
tf.add_to_collection("validation_placeholder_data", batch)
tf.add_to_collection("validation_placeholder_label", label)
-def bootstrap_graphs_fromfile(self, session, train_data_shuffler, validation_data_shuffler):
+def bootstrap_graphs_fromfile(self, train_data_shuffler, validation_data_shuffler):
"""
Bootstrap all the necessary data from file
@@ -304,7 +306,7 @@
"""
-saver = self.architecture.load(session, self.model_from_file)
+saver = self.architecture.load(self.model_from_file)
# Loading training graph
self.training_graph = tf.get_collection("training_graph")[0]
@@ -362,78 +364,80 @@
# Pickle the architecture to save
self.architecture.pickle_net(train_data_shuffler.deployment_shape)
-with tf.Session(config=config) as session:
-# Loading a pretrained model
-if self.model_from_file != "":
-logger.info("Loading pretrained model from {0}".format(self.model_from_file))
-saver = self.bootstrap_graphs_fromfile(session, train_data_shuffler, validation_data_shuffler)
-else:
-# Bootstraping all the graphs
-self.bootstrap_graphs(train_data_shuffler, validation_data_shuffler)
-# TODO: find an elegant way to provide this as a parameter of the trainer
-self.global_step = tf.Variable(0, trainable=False)
-# Preparing the optimizer
-self.optimizer_class._learning_rate = self.learning_rate
-self.optimizer = self.optimizer_class.minimize(self.training_graph, global_step=self.global_step)
-tf.add_to_collection("optimizer", self.optimizer)
-tf.add_to_collection("learning_rate", self.learning_rate)
-# Train summary
-self.summaries_train = self.create_general_summary()
-tf.add_to_collection("summaries_train", self.summaries_train)
-tf.initialize_all_variables().run()
-# Original tensorflow saver object
-saver = tf.train.Saver(var_list=tf.all_variables())
-if isinstance(train_data_shuffler, OnLineSampling):
-train_data_shuffler.set_feature_extractor(self.architecture, session=session)
-# Start a thread to enqueue data asynchronously, and hide I/O latency.
-if self.prefetch:
-self.thread_pool = tf.train.Coordinator()
-tf.train.start_queue_runners(coord=self.thread_pool)
-threads = self.start_thread(session)
-# TENSOR BOARD SUMMARY
-self.train_summary_writter = tf.train.SummaryWriter(os.path.join(self.temp_dir, 'train'), session.graph)
-for step in range(self.iterations):
-start = time.time()
-self.fit(session, step)
-end = time.time()
-summary = summary_pb2.Summary.Value(tag="elapsed_time", simple_value=float(end-start))
-self.train_summary_writter.add_summary(summary_pb2.Summary(value=[summary]), step)
-# Running validation
-if validation_data_shuffler is not None and step % self.validation_snapshot == 0:
-self.compute_validation(session, validation_data_shuffler, step)
-if self.analizer is not None:
-self.validation_summary_writter.add_summary(self.analizer(
-validation_data_shuffler, self.architecture, session), step)
-# Taking snapshot
-if step % self.snapshot == 0:
-logger.info("Taking snapshot")
-path = os.path.join(self.temp_dir, 'model_snapshot{0}.ckp'.format(step))
-self.architecture.save(session, saver, path)
-logger.info("Training finally finished")
-self.train_summary_writter.close()
-if validation_data_shuffler is not None:
-self.validation_summary_writter.close()
-# Saving the final network
-path = os.path.join(self.temp_dir, 'model.ckp')
-self.architecture.save(session, saver, path)
-if self.prefetch:
-# now they should definetely stop
-self.thread_pool.request_stop()
-self.thread_pool.join(threads)
+#with tf.Session(config=config) as session:
+self.session = Session.instance().session
+# Loading a pretrained model
+if self.model_from_file != "":
+logger.info("Loading pretrained model from {0}".format(self.model_from_file))
+saver = self.bootstrap_graphs_fromfile(train_data_shuffler, validation_data_shuffler)
+else:
+# Bootstrapping all the graphs
+self.bootstrap_graphs(train_data_shuffler, validation_data_shuffler)
+# TODO: find an elegant way to provide this as a parameter of the trainer
+self.global_step = tf.Variable(0, trainable=False)
+# Preparing the optimizer
+self.optimizer_class._learning_rate = self.learning_rate
+self.optimizer = self.optimizer_class.minimize(self.training_graph, global_step=self.global_step)
+tf.add_to_collection("optimizer", self.optimizer)
+tf.add_to_collection("learning_rate", self.learning_rate)
+# Train summary
+self.summaries_train = self.create_general_summary()
+tf.add_to_collection("summaries_train", self.summaries_train)
+tf.initialize_all_variables().run(session=self.session)
+# Original tensorflow saver object
+saver = tf.train.Saver(var_list=tf.all_variables())
+if isinstance(train_data_shuffler, OnLineSampling):
+train_data_shuffler.set_feature_extractor(self.architecture, session=self.session)
+# Start a thread to enqueue data asynchronously, and hide I/O latency.
+if self.prefetch:
+self.thread_pool = tf.train.Coordinator()
+tf.train.start_queue_runners(coord=self.thread_pool)
+threads = self.start_thread()
+# TensorBoard summary
+self.train_summary_writter = tf.train.SummaryWriter(os.path.join(self.temp_dir, 'train'), self.session.graph)
+for step in range(self.iterations):
+start = time.time()
+self.fit(step)
+end = time.time()
+summary = summary_pb2.Summary.Value(tag="elapsed_time", simple_value=float(end-start))
+self.train_summary_writter.add_summary(summary_pb2.Summary(value=[summary]), step)
+# Running validation
+if validation_data_shuffler is not None and step % self.validation_snapshot == 0:
+self.compute_validation(validation_data_shuffler, step)
+if self.analizer is not None:
+self.validation_summary_writter.add_summary(self.analizer(
+validation_data_shuffler, self.architecture, self.session), step)
+# Taking snapshot
+if step % self.snapshot == 0:
+logger.info("Taking snapshot")
+path = os.path.join(self.temp_dir, 'model_snapshot{0}.ckp'.format(step))
+self.architecture.save(saver, path)
+logger.info("Training finally finished")
+self.train_summary_writter.close()
+if validation_data_shuffler is not None:
+self.validation_summary_writter.close()
+# Saving the final network
+path = os.path.join(self.temp_dir, 'model.ckp')
+self.architecture.save(saver, path)
+if self.prefetch:
+# now they should definitely stop
+self.thread_pool.request_stop()
+self.thread_pool.join(threads)
from .util import *
from .singleton import Singleton
from .session import Session
\ No newline at end of file
#!/usr/bin/env python
# vim: set fileencoding=utf-8 :
# @author: Tiago de Freitas Pereira <tiago.pereira@idiap.ch>
# @date: Wed 11 May 2016 09:39:36 CEST
import tensorflow as tf
from .singleton import Singleton
@Singleton
class Session(object):
def __init__(self):
config = tf.ConfigProto(log_device_placement=True,
gpu_options=tf.GPUOptions(per_process_gpu_memory_fraction=0.333))
config.gpu_options.allow_growth = True
self.session = tf.Session(config=config)  # apply the GPU options configured above
#def __del__(self):
# self.session.close()
\ No newline at end of file
# A singleton class decorator, based on http://stackoverflow.com/a/7346105/3301902
class Singleton(object):
"""
A non-thread-safe helper class to ease implementing singletons.
This should be used as a **decorator** -- not a metaclass -- to the class that should be a singleton.
The decorated class can define one `__init__` function that takes an arbitrary list of parameters.
To get the singleton instance, use the :py:meth:`instance` method. Trying to use `__call__` will result in a `TypeError` being raised.
Limitations:
* The decorated class cannot be inherited from.
* The documentation of the decorated class is replaced with the documentation of this class.
"""
def __init__(self, decorated):
self._decorated = decorated
# see: functools.WRAPPER_ASSIGNMENTS:
self.__doc__ = decorated.__doc__
self.__name__ = decorated.__name__
self.__module__ = decorated.__module__
self.__mro__ = decorated.__mro__
self.__bases__ = []
self._instance = None
def create(self, *args, **kwargs):
"""Creates the singleton instance, by passing the given parameters to the class' constructor."""
self._instance = self._decorated(*args, **kwargs)
def instance(self):
"""Returns the singleton instance.
The function :py:meth:`create` must have been called before."""
if self._instance is None:
self.create()
return self._instance
def __call__(self):
raise TypeError('Singletons must be accessed through the `instance()` method.')
def __instancecheck__(self, inst):
return isinstance(inst, self._decorated)
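
A minimal usage sketch of the new session management (illustrative only, not part of the commit; it assumes the two files above land under bob.learn.tensorflow.utils as the imports in this diff suggest):

    from bob.learn.tensorflow.utils.session import Session

    # Session is wrapped by the Singleton decorator, so the shared
    # tf.Session is reached through instance(), never by calling the
    # class directly.
    session = Session.instance().session

    # Every caller gets the same underlying tf.Session object.
    assert Session.instance().session is session

    # Direct instantiation is blocked by Singleton.__call__:
    # Session()  # raises TypeError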
@@ -63,7 +63,7 @@ Now lets describe each step in detail.
Preparing your input data
--------------------------
+.........................
In this library, datasets are wrapped in **data shufflers**. Data shufflers are elements designed to shuffle
the input data for stochastic training.
@@ -73,24 +73,39 @@ It is possible to either use Memory (:py:class:`bob.learn.tensorflow.datashuffler.Memory`) or
Disk (:py:class:`bob.learn.tensorflow.datashuffler.Disk`) data shufflers.
For the Memory data shufflers, as in the example, it is expected that the dataset is stored in `numpy.array`.
-In the example that we provided the MNIST dataset was loaded and
-reshaped to `[n, w, h, c]` where `n` is the size of the batch, `w` and `h` are the image width and height and `c` is the
+In the example that we provided the MNIST dataset was loaded and reshaped to `[n, w, h, c]` where `n` is the size
+of the batch, `w` and `h` are the image width and height and `c` is the
number of channels.
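
As a concrete sketch (adapted from the test code in this commit; the batch size is illustrative):

.. code-block:: python

   import numpy
   from bob.learn.tensorflow.utils import load_mnist
   from bob.learn.tensorflow.datashuffler import Memory

   train_data, train_labels, _, _ = load_mnist()
   # Reshape to [n, w, h, c]: 28x28 gray-scale images with one channel
   train_data = numpy.reshape(train_data, (train_data.shape[0], 28, 28, 1))

   train_data_shuffler = Memory(train_data, train_labels,
                                input_shape=[28, 28, 1],
                                batch_size=16)
   data, labels = train_data_shuffler.get_batch()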
Creating the architecture
--------------------------
+.........................
Architectures are assembled as a :py:class:`bob.learn.tensorflow.network.SequenceNetwork` object.
-Once the objects are created it necessary to fill it up with :py_api:`Layers`_.
-The library has already some crafted networks `Architectures`_
+Once the objects are created it is necessary to fill it up with `Layers`_.
+The library already has some crafted networks implemented in `Architectures`_.
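
A minimal sketch of assembling a network (the layer names and hyper-parameters below are illustrative assumptions, not taken from this diff):

.. code-block:: python

   import tensorflow as tf
   from bob.learn.tensorflow.network import SequenceNetwork
   from bob.learn.tensorflow.layers import Conv2D, MaxPooling, FullyConnected

   # Stack layers in the order they should be executed
   architecture = SequenceNetwork()
   architecture.add(Conv2D(name="conv1", kernel_size=3, filters=10,
                           activation=tf.nn.tanh))
   architecture.add(MaxPooling(name="pooling1"))
   architecture.add(FullyConnected(name="fc1", output_dim=10,
                                   activation=None))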
Defining a loss and training
-----------------------------
+............................
The loss function can be defined by any set of tensorflow operations.
In our example, we used the `tf.nn.sparse_softmax_cross_entropy_with_logits` loss, but we also have some crafted
loss functions for Siamese (:py:class:`bob.learn.tensorflow.loss.ContrastiveLoss`) and Triplet networks (:py:class:`bob.learn.tensorflow.loss.TripletLoss`).
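
Putting the pieces together (a sketch continuing the snippets above; the ``BaseLoss`` wrapping and the ``Trainer`` keyword names other than ``analizer``, ``prefetch`` and ``temp_dir`` are assumptions not shown in this diff):

.. code-block:: python

   import tensorflow as tf
   from bob.learn.tensorflow.loss import BaseLoss
   from bob.learn.tensorflow.trainers import Trainer

   # Wrap a plain tensorflow loss operation
   loss = BaseLoss(tf.nn.sparse_softmax_cross_entropy_with_logits,
                   tf.reduce_mean)

   trainer = Trainer(architecture=architecture,
                     loss=loss,
                     iterations=100,
                     analizer=None,
                     prefetch=False,
                     temp_dir="./cnn_scratch")
   trainer.train(train_data_shuffler)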
+Predicting and computing the accuracy
+-------------------------------------
The trainer is the real muscle here.
This element takes the inputs and trains the network.
As for the loss, we have specific trainers for Siamese (:py:class:`bob.learn.tensorflow.trainers.SiameseTrainer`)
and Triplet networks (:py:class:`bob.learn.tensorflow.trainers.TripletTrainer`).
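
With the new session management a trained network predicts directly, without an explicit session (adapted from the updated test in this commit; ``validation_data_shuffler`` is built like the Memory shuffler above):

.. code-block:: python

   import numpy

   [data, labels] = validation_data_shuffler.get_batch()
   predictions = architecture.predict(data)
   accuracy = 100. * numpy.sum(predictions == labels) / predictions.shape[0]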
+Sandbox
+-------
+We have a sandbox of examples in a git repository: `https://gitlab.idiap.ch/tiago.pereira/bob.learn.tensorflow_sandbox`_
+The sandbox has some examples of training:
+- MNIST with softmax
+- MNIST with Siamese Network
+- MNIST with Triplet Network
+- Face recognition with MOBIO database
+- Face recognition with CASIA WebFace database