Skip to content
Snippets Groups Projects
Commit 615b9059 authored by Tiago de Freitas Pereira's avatar Tiago de Freitas Pereira
Browse files

Implemented a trainer that loads a pretrained network

parent 0cbbe438
No related branches found
No related tags found
No related merge requests found
...@@ -225,17 +225,32 @@ class SequenceNetwork(six.with_metaclass(abc.ABCMeta, object)): ...@@ -225,17 +225,32 @@ class SequenceNetwork(six.with_metaclass(abc.ABCMeta, object)):
self.sequence_net[k].weights_initialization.use_gpu = state self.sequence_net[k].weights_initialization.use_gpu = state
self.sequence_net[k].bias_initialization.use_gpu = state self.sequence_net[k].bias_initialization.use_gpu = state
def load(self, hdf5, shape=None, session=None, batch=1): def load_variables_only(self, hdf5, session):
""" """
Load the network Load the variables of the model
"""
hdf5.cd('/tensor_flow')
for k in self.sequence_net:
# TODO: IT IS NOT SMART TESTING ALONG THIS PAGE
if not isinstance(self.sequence_net[k], MaxPooling):
self.sequence_net[k].W.assign(hdf5.read(self.sequence_net[k].W.name)).eval(session=session)
session.run(self.sequence_net[k].W)
self.sequence_net[k].b.assign(hdf5.read(self.sequence_net[k].b.name)).eval(session=session)
session.run(self.sequence_net[k].b)
hdf5.cd("..")
def load(self, hdf5, shape=None, session=None, batch=1, use_gpu=False):
"""
Load the network from scratch.
This will build the graphs
**Parameters** **Parameters**
hdf5: The saved network in the :py:class:`bob.io.base.HDF5File` format hdf5: The saved network in the :py:class:`bob.io.base.HDF5File` format
shape: Input shape of the network. If `None`, the internal shape will be assumed
shape: Input shape of the network session: An opened tensorflow `session <https://www.tensorflow.org/versions/r0.11/api_docs/python/client.html#Session>`_. If `None`, a new one will be opened
batch: The size of the batch
session: tensorflow `session <https://www.tensorflow.org/versions/r0.11/api_docs/python/client.html#Session>`_ use_gpu: Load all the variables in the GPU?
""" """
if session is None: if session is None:
...@@ -249,7 +264,7 @@ class SequenceNetwork(six.with_metaclass(abc.ABCMeta, object)): ...@@ -249,7 +264,7 @@ class SequenceNetwork(six.with_metaclass(abc.ABCMeta, object)):
self.sequence_net = pickle.loads(hdf5.read('architecture')) self.sequence_net = pickle.loads(hdf5.read('architecture'))
self.deployment_shape = hdf5.read('deployment_shape') self.deployment_shape = hdf5.read('deployment_shape')
self.turn_gpu_onoff(False) self.turn_gpu_onoff(use_gpu)
if shape is None: if shape is None:
shape = self.deployment_shape shape = self.deployment_shape
...@@ -259,17 +274,7 @@ class SequenceNetwork(six.with_metaclass(abc.ABCMeta, object)): ...@@ -259,17 +274,7 @@ class SequenceNetwork(six.with_metaclass(abc.ABCMeta, object)):
place_holder = tf.placeholder(tf.float32, shape=shape, name="load") place_holder = tf.placeholder(tf.float32, shape=shape, name="load")
self.compute_graph(place_holder) self.compute_graph(place_holder)
tf.initialize_all_variables().run(session=session) tf.initialize_all_variables().run(session=session)
self.load_variables_only(hdf5, session)
hdf5.cd('/tensor_flow')
for k in self.sequence_net:
# TODO: IT IS NOT SMART TESTING ALONG THIS PAGE
if not isinstance(self.sequence_net[k], MaxPooling):
#self.sequence_net[k].W.assign(hdf5.read(self.sequence_net[k].W.name))
self.sequence_net[k].W.assign(hdf5.read(self.sequence_net[k].W.name)).eval(session=session)
session.run(self.sequence_net[k].W)
self.sequence_net[k].b.assign(hdf5.read(self.sequence_net[k].b.name)).eval(session=session)
session.run(self.sequence_net[k].b)
""" """
def save(self, session, path, step=None): def save(self, session, path, step=None):
......
...@@ -26,20 +26,7 @@ iterations = 50 ...@@ -26,20 +26,7 @@ iterations = 50
seed = 10 seed = 10
def test_cnn_trainer_scratch(): def scratch_network():
train_data, train_labels, validation_data, validation_labels = load_mnist()
train_data = numpy.reshape(train_data, (train_data.shape[0], 28, 28, 1))
# Creating datashufflers
data_augmentation = ImageAugmentation()
train_data_shuffler = Memory(train_data, train_labels,
input_shape=[28, 28, 1],
batch_size=batch_size,
data_augmentation=data_augmentation)
validation_data = numpy.reshape(validation_data, (validation_data.shape[0], 28, 28, 1))
directory = "./temp/cnn"
# Creating a random network # Creating a random network
scratch = SequenceNetwork() scratch = SequenceNetwork()
scratch.add(Conv2D(name="conv1", kernel_size=3, scratch.add(Conv2D(name="conv1", kernel_size=3,
...@@ -51,20 +38,11 @@ def test_cnn_trainer_scratch(): ...@@ -51,20 +38,11 @@ def test_cnn_trainer_scratch():
activation=None, activation=None,
weights_initialization=Xavier(seed=seed, use_gpu=False), weights_initialization=Xavier(seed=seed, use_gpu=False),
bias_initialization=Constant(use_gpu=False))) bias_initialization=Constant(use_gpu=False)))
# Loss for the softmax
loss = BaseLoss(tf.nn.sparse_softmax_cross_entropy_with_logits, tf.reduce_mean)
# One graph trainer return scratch
trainer = Trainer(architecture=scratch,
loss=loss,
iterations=iterations,
analizer=None,
prefetch=False,
temp_dir=directory)
trainer.train(train_data_shuffler)
del scratch
def validate_network(validation_data, validation_labels, directory):
# Testing # Testing
validation_data_shuffler = Memory(validation_data, validation_labels, validation_data_shuffler = Memory(validation_data, validation_labels,
input_shape=[28, 28, 1], input_shape=[28, 28, 1],
...@@ -79,7 +57,40 @@ def test_cnn_trainer_scratch(): ...@@ -79,7 +57,40 @@ def test_cnn_trainer_scratch():
predictions = scratch(data, session=session) predictions = scratch(data, session=session)
accuracy = 100. * numpy.sum(numpy.argmax(predictions, 1) == labels) / predictions.shape[0] accuracy = 100. * numpy.sum(numpy.argmax(predictions, 1) == labels) / predictions.shape[0]
# At least 80% of accuracy return accuracy
assert accuracy > 80.
shutil.rmtree(directory)
def test_cnn_trainer_scratch():
train_data, train_labels, validation_data, validation_labels = load_mnist()
train_data = numpy.reshape(train_data, (train_data.shape[0], 28, 28, 1))
# Creating datashufflers
data_augmentation = ImageAugmentation()
train_data_shuffler = Memory(train_data, train_labels,
input_shape=[28, 28, 1],
batch_size=batch_size,
data_augmentation=data_augmentation)
validation_data = numpy.reshape(validation_data, (validation_data.shape[0], 28, 28, 1))
directory = "./temp/cnn"
# Create scratch network
scratch = scratch_network()
# Loss for the softmax
loss = BaseLoss(tf.nn.sparse_softmax_cross_entropy_with_logits, tf.reduce_mean)
# One graph trainer
trainer = Trainer(architecture=scratch,
loss=loss,
iterations=iterations,
analizer=None,
prefetch=False,
temp_dir=directory)
trainer.train(train_data_shuffler)
accuracy = validate_network(validation_data, validation_labels, directory)
assert accuracy > 80
del scratch
...@@ -62,6 +62,8 @@ class SiameseTrainer(Trainer): ...@@ -62,6 +62,8 @@ class SiameseTrainer(Trainer):
## Analizer ## Analizer
analizer=ExperimentAnalizer(), analizer=ExperimentAnalizer(),
model_from_file="",
verbosity_level=2): verbosity_level=2):
super(SiameseTrainer, self).__init__( super(SiameseTrainer, self).__init__(
...@@ -86,6 +88,8 @@ class SiameseTrainer(Trainer): ...@@ -86,6 +88,8 @@ class SiameseTrainer(Trainer):
## Analizer ## Analizer
analizer=analizer, analizer=analizer,
model_from_file=model_from_file,
verbosity_level=verbosity_level verbosity_level=verbosity_level
) )
......
...@@ -65,6 +65,9 @@ class Trainer(object): ...@@ -65,6 +65,9 @@ class Trainer(object):
## Analizer ## Analizer
analizer=SoftmaxAnalizer(), analizer=SoftmaxAnalizer(),
### Pretrained model
model_from_file="",
verbosity_level=2): verbosity_level=2):
if not isinstance(architecture, SequenceNetwork): if not isinstance(architecture, SequenceNetwork):
...@@ -107,6 +110,8 @@ class Trainer(object): ...@@ -107,6 +110,8 @@ class Trainer(object):
self.enqueue_op = None self.enqueue_op = None
self.global_step = None self.global_step = None
self.model_from_file = model_from_file
bob.core.log.set_verbosity_level(logger, verbosity_level) bob.core.log.set_verbosity_level(logger, verbosity_level)
def __del__(self): def __del__(self):
...@@ -289,6 +294,12 @@ class Trainer(object): ...@@ -289,6 +294,12 @@ class Trainer(object):
with tf.Session(config=config) as session: with tf.Session(config=config) as session:
tf.initialize_all_variables().run() tf.initialize_all_variables().run()
# Loading a pretrained model
if self.model_from_file != "":
logger.info("Loading pretrained model from {0}".format(self.model_from_file))
hdf5 = bob.io.base.HDF5File(self.model_from_file)
self.architecture.load_variables_only(hdf5, session)
if isinstance(train_data_shuffler, OnLineSampling): if isinstance(train_data_shuffler, OnLineSampling):
train_data_shuffler.set_feature_extractor(self.architecture, session=session) train_data_shuffler.set_feature_extractor(self.architecture, session=session)
......
...@@ -62,6 +62,8 @@ class TripletTrainer(Trainer): ...@@ -62,6 +62,8 @@ class TripletTrainer(Trainer):
## Analizer ## Analizer
analizer=ExperimentAnalizer(), analizer=ExperimentAnalizer(),
model_from_file="",
verbosity_level=2): verbosity_level=2):
super(TripletTrainer, self).__init__( super(TripletTrainer, self).__init__(
...@@ -85,6 +87,7 @@ class TripletTrainer(Trainer): ...@@ -85,6 +87,7 @@ class TripletTrainer(Trainer):
## Analizer ## Analizer
analizer=analizer, analizer=analizer,
model_from_file=model_from_file,
verbosity_level=verbosity_level verbosity_level=verbosity_level
) )
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment