New session management

parent 0fe1b037
@@ -9,4 +9,5 @@ from bob.learn.tensorflow import layers
 from bob.learn.tensorflow import loss
 from bob.learn.tensorflow import network
 from bob.learn.tensorflow import trainers
+from bob.learn.tensorflow import utils
@@ -2,7 +2,6 @@
 from pkgutil import extend_path
 __path__ = extend_path(__path__, __name__)
-from .util import *
 # gets sphinx autodoc done right - don't remove it
 __all__ = [_ for _ in dir() if not _.startswith('_')]
@@ -4,7 +4,6 @@
 # @date: Wed 11 May 2016 17:38 CEST
 import tensorflow as tf
-from bob.learn.tensorflow.util import *
 from .MaxPooling import MaxPooling
...
@@ -4,7 +4,6 @@
 # @date: Wed 11 May 2016 17:38 CEST
 import tensorflow as tf
-from bob.learn.tensorflow.util import *
 from .Layer import Layer
...
@@ -4,7 +4,6 @@
 # @date: Wed 11 May 2016 17:38 CEST
 import tensorflow as tf
-from bob.learn.tensorflow.util import *
 from .Layer import Layer
...
@@ -8,7 +8,7 @@ logger = logging.getLogger("bob.learn.tensorflow")
 import tensorflow as tf
 from .BaseLoss import BaseLoss
-from bob.learn.tensorflow.util import compute_euclidean_distance
+from bob.learn.tensorflow.utils import compute_euclidean_distance
 class ContrastiveLoss(BaseLoss):
...
@@ -8,7 +8,7 @@ logger = logging.getLogger("bob.learn.tensorflow")
 import tensorflow as tf
 from .BaseLoss import BaseLoss
-from bob.learn.tensorflow.util import compute_euclidean_distance
+from bob.learn.tensorflow.utils import compute_euclidean_distance
 class TripletLoss(BaseLoss):
...
@@ -12,6 +12,7 @@ import pickle
 from collections import OrderedDict
 from bob.learn.tensorflow.layers import Layer, MaxPooling, Dropout, Conv2D, FullyConnected
+from bob.learn.tensorflow.utils.session import Session
 class SequenceNetwork(six.with_metaclass(abc.ABCMeta, object)):
@@ -102,7 +103,7 @@ class SequenceNetwork(six.with_metaclass(abc.ABCMeta, object)):
     def compute_inference_placeholder(self, data_shape):
         self.inference_placeholder = tf.placeholder(tf.float32, shape=data_shape, name="feature")
-    def __call__(self, data, session=None, feature_layer=None):
+    def __call__(self, data, feature_layer=None):
         """Run a graph and compute the embeddings
         **Parameters**
@@ -115,8 +116,7 @@ class SequenceNetwork(six.with_metaclass(abc.ABCMeta, object)):
         If `None` will run the graph until the end.
         """
-        if session is None:
-            session = tf.Session()
+        session = Session.instance().session
         # Feeding the placeholder
         if self.inference_placeholder is None:
@@ -130,8 +130,8 @@ class SequenceNetwork(six.with_metaclass(abc.ABCMeta, object)):
         return embedding
-    def predict(self, data, session):
-        return numpy.argmax(self(data, session=session), 1)
+    def predict(self, data):
+        return numpy.argmax(self(data), 1)
     def dump_variables(self):
         """
@@ -252,10 +252,13 @@ class SequenceNetwork(six.with_metaclass(abc.ABCMeta, object)):
             self.sequence_net[k].weights_initialization.use_gpu = state
             self.sequence_net[k].bias_initialization.use_gpu = state
-    def load_variables_only(self, hdf5, session):
+    def load_variables_only(self, hdf5):
         """
         Load the variables of the model
         """
+        session = Session.instance().session
         hdf5.cd('/tensor_flow')
         for k in self.sequence_net:
             # TODO: IT IS NOT SMART TESTING ALONG THIS PAGE
@@ -271,7 +274,7 @@ class SequenceNetwork(six.with_metaclass(abc.ABCMeta, object)):
         hdf5.cd("..")
-    def load_hdf5(self, hdf5, shape=None, session=None, batch=1, use_gpu=False):
+    def load_hdf5(self, hdf5, shape=None, batch=1, use_gpu=False):
         """
         Load the network from scratch.
         This will build the graphs
@@ -285,8 +288,7 @@ class SequenceNetwork(six.with_metaclass(abc.ABCMeta, object)):
         use_gpu: Load all the variables in the GPU?
         """
-        if session is None:
-            session = tf.Session()
+        session = Session.instance().session
         # Loading the normalization parameters
         self.input_divide = hdf5.read('input_divide')
@@ -308,11 +310,17 @@ class SequenceNetwork(six.with_metaclass(abc.ABCMeta, object)):
         tf.initialize_all_variables().run(session=session)
         self.load_variables_only(hdf5, session)
-    def save(self, session, saver, path):
+    def save(self, saver, path):
+        session = Session.instance().session
         open(path+"_sequence_net.pickle", 'w').write(self.pickle_architecture)
         return saver.save(session, path)
-    def load(self, session, path, clear_devices=False):
+    def load(self, path, clear_devices=False):
+        session = Session.instance().session
         self.sequence_net = pickle.loads(open(path+"_sequence_net.pickle").read())
         #saver = tf.train.import_meta_graph(path + ".meta", clear_devices=clear_devices)
         saver = tf.train.import_meta_graph(path + ".meta")
...
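To illustrate the new calling convention, here is a minimal sketch of how callers use the network after this change. The checkpoint path and the random input batch are hypothetical, and the session is now resolved internally through the `Session` singleton instead of being passed in::

    import numpy
    from bob.learn.tensorflow.network import SequenceNetwork

    # Hypothetical checkpoint produced earlier by save()
    network = SequenceNetwork()
    network.load("/tmp/model.ckp")

    # No session argument anymore: the shared session is fetched internally
    data = numpy.random.rand(16, 28, 28, 1).astype("float32")
    embeddings = network(data)           # run the graph up to the last layer
    predictions = network.predict(data)  # numpy.argmax over the class scores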
@@ -11,7 +11,7 @@ from bob.learn.tensorflow.initialization import Xavier, Constant
 from bob.learn.tensorflow.network import SequenceNetwork
 from bob.learn.tensorflow.loss import BaseLoss
 from bob.learn.tensorflow.trainers import Trainer
-from bob.learn.tensorflow.util import load_mnist
+from bob.learn.tensorflow.utils import load_mnist
 from bob.learn.tensorflow.layers import Conv2D, FullyConnected, MaxPooling
 import tensorflow as tf
 import shutil
@@ -46,22 +46,21 @@ def scratch_network():
     return scratch
-def validate_network(validation_data, validation_labels, directory):
+def validate_network(validation_data, validation_labels, network):
     # Testing
     validation_data_shuffler = Memory(validation_data, validation_labels,
                                       input_shape=[28, 28, 1],
                                       batch_size=validation_batch_size)
-    with tf.Session() as session:
-        scratch = SequenceNetwork()
-        scratch.load(session, os.path.join(directory, "model.ckp"))
-        [data, labels] = validation_data_shuffler.get_batch()
-        predictions = scratch(data, session=session)
-        accuracy = 100. * numpy.sum(numpy.argmax(predictions, 1) == labels) / predictions.shape[0]
+    [data, labels] = validation_data_shuffler.get_batch()
+    predictions = network.predict(data)
+    accuracy = 100. * numpy.sum(predictions == labels) / predictions.shape[0]
     return accuracy
 def test_cnn_trainer_scratch():
     train_data, train_labels, validation_data, validation_labels = load_mnist()
     train_data = numpy.reshape(train_data, (train_data.shape[0], 28, 28, 1))
@@ -86,11 +85,10 @@ def test_cnn_trainer_scratch():
                       analizer=None,
                       prefetch=False,
                       temp_dir=directory)
-    trainer.train(train_data_shuffler)
-    del trainer# JUst to clean the tf.variables
-    accuracy = validate_network(validation_data, validation_labels, directory)
+    trainer.train(train_data_shuffler)
+    accuracy = validate_network(validation_data, validation_labels, scratch)
     assert accuracy > 80
     shutil.rmtree(directory)
+    del trainer
This diff is collapsed.
from .util import *
from .singleton import Singleton
from .session import Session
\ No newline at end of file
#!/usr/bin/env python
# vim: set fileencoding=utf-8 :
# @author: Tiago de Freitas Pereira <tiago.pereira@idiap.ch>
# @date: Wed 11 May 2016 09:39:36 CEST

import tensorflow as tf
from .singleton import Singleton


@Singleton
class Session(object):

    def __init__(self):
        config = tf.ConfigProto(log_device_placement=True,
                                gpu_options=tf.GPUOptions(per_process_gpu_memory_fraction=0.333))
        config.gpu_options.allow_growth = True
        # The config must be passed to the session, otherwise the options above have no effect
        self.session = tf.Session(config=config)

    #def __del__(self):
    #    self.session.close()
\ No newline at end of file
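A minimal usage sketch of the shared session: the import path follows the `utils/__init__.py` above, and the constant op is only an illustration::

    import tensorflow as tf
    from bob.learn.tensorflow.utils import Session

    # Every component now draws the same tf.Session from the singleton
    session = Session.instance().session
    print(session.run(tf.constant(42)))  # -> 42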
# A singleton class decorator, based on http://stackoverflow.com/a/7346105/3301902


class Singleton(object):
    """
    A non-thread-safe helper class to ease implementing singletons.
    This should be used as a **decorator** -- not a metaclass -- to the class that should be a singleton.

    The decorated class can define one `__init__` function that takes an arbitrary list of parameters.

    To get the singleton instance, use the :py:meth:`instance` method. Trying to use `__call__` will result in a `TypeError` being raised.

    Limitations:

    * The decorated class cannot be inherited from.
    * The documentation of the decorated class is replaced with the documentation of this class.
    """

    def __init__(self, decorated):
        self._decorated = decorated
        # see: functools.WRAPPER_ASSIGNMENTS:
        self.__doc__ = decorated.__doc__
        self.__name__ = decorated.__name__
        self.__module__ = decorated.__module__
        self.__mro__ = decorated.__mro__
        self.__bases__ = []

        self._instance = None

    def create(self, *args, **kwargs):
        """Creates the singleton instance, by passing the given parameters to the class' constructor."""
        self._instance = self._decorated(*args, **kwargs)

    def instance(self):
        """Returns the singleton instance.
        If :py:meth:`create` was not called beforehand, the instance is created with default constructor arguments."""
        if self._instance is None:
            self.create()
        return self._instance

    def __call__(self):
        raise TypeError('Singletons must be accessed through the `instance()` method.')

    def __instancecheck__(self, inst):
        return isinstance(inst, self._decorated)
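A short sketch of how the decorator behaves; the `Counter` class here is hypothetical::

    from bob.learn.tensorflow.utils import Singleton

    @Singleton
    class Counter(object):
        """Hypothetical example class."""
        def __init__(self, start=0):
            self.value = start

    Counter.create(start=10)                         # forwards arguments to __init__
    assert Counter.instance() is Counter.instance()  # always the same object
    assert Counter.instance().value == 10

    try:
        Counter()                                    # direct instantiation is blocked
    except TypeError as error:
        print(error)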
@@ -63,7 +63,7 @@ Now lets describe each step in detail.
 Preparing your input data
--------------------------
+.........................

 In this library datasets are wrapped in **data shufflers**. Data shufflers are elements designed to shuffle
 the input data for stochastic training.

@@ -73,24 +73,39 @@ It is possible to either use Memory (:py:class:`bob.learn.tensorflow.datashuffler.Memory`) or
 Disk (:py:class:`bob.learn.tensorflow.datashuffler.Disk`) data shufflers.

 For the Memory data shufflers, as in the example, it is expected that the dataset is stored in `numpy.array`.
-In the example that we provided the MNIST dataset was loaded and
-reshaped to `[n, w, h, c]` where `n` is the size of the batch, `w` and `h` are the image width and height and `c` is the
+In the example that we provided the MNIST dataset was loaded and reshaped to `[n, w, h, c]` where `n` is the size
+of the batch, `w` and `h` are the image width and height and `c` is the
 number of channels.

 Creating the architecture
--------------------------
+.........................

 Architectures are assembled as a :py:class:`bob.learn.tensorflow.network.SequenceNetwork` object.
-Once the objects are created it necessary to fill it up with :py_api:`Layers`_.
-The library has already some crafted networks `Architectures`_
+Once the objects are created it is necessary to fill them up with `Layers`_.
+The library already has some crafted networks implemented in `Architectures`_.

 Defining a loss and training
-----------------------------
+............................

+The loss function can be defined by any set of tensorflow operations.
+In our example, we used the `tf.nn.sparse_softmax_cross_entropy_with_logits` loss, but we also have some crafted
+loss functions for Siamese (:py:class:`bob.learn.tensorflow.loss.ContrastiveLoss`) and Triplet (:py:class:`bob.learn.tensorflow.loss.TripletLoss`) networks.

-Predicting and computing the accuracy
--------------------------------------
+The trainer is the real muscle here.
+This element takes the inputs and trains the network.
+As for the loss, we have specific trainers for Siamese (:py:class:`bob.learn.tensorflow.trainers.SiameseTrainer`)
+and Triplet networks (:py:class:`bob.learn.tensorflow.trainers.TripletTrainer`).

+Sandbox
+-------
+
+We have a sandbox of examples in a git repository `https://gitlab.idiap.ch/tiago.pereira/bob.learn.tensorflow_sandbox`_
+The sandbox has some examples of training:
+
+- MNIST with softmax
+- MNIST with Siamese Network
+- MNIST with Triplet Network
+- Face recognition with MOBIO database
+- Face recognition with CASIA WebFace database
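To make the data-shuffler step above concrete, here is a minimal sketch assembled from the test touched by this commit; the batch size is arbitrary::

    import numpy
    from bob.learn.tensorflow.datashuffler import Memory
    from bob.learn.tensorflow.utils import load_mnist

    # Load MNIST and reshape it to [n, w, h, c]
    train_data, train_labels, _, _ = load_mnist()
    train_data = numpy.reshape(train_data, (train_data.shape[0], 28, 28, 1))

    train_data_shuffler = Memory(train_data, train_labels,
                                 input_shape=[28, 28, 1],
                                 batch_size=16)
    [data, labels] = train_data_shuffler.get_batch()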
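For the architecture step, a sketch of assembling a small network. The `add` method and the exact layer constructor arguments are assumptions inferred from the imports in this commit; the real signatures may differ::

    import tensorflow as tf
    from bob.learn.tensorflow.network import SequenceNetwork
    from bob.learn.tensorflow.layers import Conv2D, MaxPooling, FullyConnected

    # Assumed API: layers are appended to the sequence one by one
    scratch = SequenceNetwork()
    scratch.add(Conv2D(name="conv1", kernel_size=3, filters=10, activation=tf.nn.tanh))
    scratch.add(MaxPooling(name="pooling1"))
    scratch.add(FullyConnected(name="fc1", output_dim=10, activation=None))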
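And for the loss/trainer step, continuing the two sketches above. The `BaseLoss` arguments and the `architecture`/`loss`/`iterations` keywords are assumptions; `analizer`, `prefetch`, `temp_dir` and `train()` come from the test in this commit::

    import tensorflow as tf
    from bob.learn.tensorflow.loss import BaseLoss
    from bob.learn.tensorflow.trainers import Trainer

    # Assumed: BaseLoss wraps a tensorflow loss op plus a reduction
    loss = BaseLoss(tf.nn.sparse_softmax_cross_entropy_with_logits, tf.reduce_mean)
    trainer = Trainer(architecture=scratch,    # the SequenceNetwork built above
                      loss=loss,
                      iterations=100,
                      analizer=None,
                      prefetch=False,
                      temp_dir="./cnn_scratch")
    trainer.train(train_data_shuffler)         # the Memory shuffler built above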