Implemented a trainer that loads a pretrained network

parent 0cbbe438
......@@ -225,17 +225,32 @@ class SequenceNetwork(six.with_metaclass(abc.ABCMeta, object)):
self.sequence_net[k].weights_initialization.use_gpu = state
self.sequence_net[k].bias_initialization.use_gpu = state
def load(self, hdf5, shape=None, session=None, batch=1):
def load_variables_only(self, hdf5, session):
    """
    Load the variables of the model.

    Only the values of the weights (``W``) and biases (``b``) of each layer
    are restored. The graph is not built here, so every layer in
    ``self.sequence_net`` must already expose its ``W`` and ``b`` variables.

    **Parameters**
      hdf5: The saved network in the :py:class:`bob.io.base.HDF5File` format.
            Values are read from the ``/tensor_flow`` group, keyed by the
            name of each variable.
      session: An opened tensorflow session in which the assignments are run
    """
    hdf5.cd('/tensor_flow')
    try:
        for k in self.sequence_net:
            layer = self.sequence_net[k]

            # TODO: checking the layer type here is fragile; pooling layers
            # are skipped because they carry no trainable variables.
            if isinstance(layer, MaxPooling):
                continue

            # Running the assign op is enough to update the variable;
            # there is no need to fetch its value afterwards.
            for variable in (layer.W, layer.b):
                session.run(variable.assign(hdf5.read(variable.name)))
    finally:
        # Always step back out of the group, even if a read or an
        # assignment fails, so the caller's file handle is not left
        # pointing inside '/tensor_flow'.
        hdf5.cd("..")
def load(self, hdf5, shape=None, session=None, batch=1, use_gpu=False):
"""
Load the network from scratch.
This will build the graphs
**Parameters**
hdf5: The saved network in the :py:class:`bob.io.base.HDF5File` format
shape: Input shape of the network
session: tensorflow `session <https://www.tensorflow.org/versions/r0.11/api_docs/python/client.html#Session>`_
shape: Input shape of the network. If `None`, the internal shape will be assumed
session: An opened tensorflow `session <https://www.tensorflow.org/versions/r0.11/api_docs/python/client.html#Session>`_. If `None`, a new one will be opened
batch: The size of the batch
use_gpu: Load all the variables in the GPU?
"""
if session is None:
......@@ -249,7 +264,7 @@ class SequenceNetwork(six.with_metaclass(abc.ABCMeta, object)):
self.sequence_net = pickle.loads(hdf5.read('architecture'))
self.deployment_shape = hdf5.read('deployment_shape')
self.turn_gpu_onoff(False)
self.turn_gpu_onoff(use_gpu)
if shape is None:
shape = self.deployment_shape
......@@ -259,17 +274,7 @@ class SequenceNetwork(six.with_metaclass(abc.ABCMeta, object)):
place_holder = tf.placeholder(tf.float32, shape=shape, name="load")
self.compute_graph(place_holder)
tf.initialize_all_variables().run(session=session)
hdf5.cd('/tensor_flow')
for k in self.sequence_net:
# TODO: IT IS NOT SMART TESTING ALONG THIS PAGE
if not isinstance(self.sequence_net[k], MaxPooling):
#self.sequence_net[k].W.assign(hdf5.read(self.sequence_net[k].W.name))
self.sequence_net[k].W.assign(hdf5.read(self.sequence_net[k].W.name)).eval(session=session)
session.run(self.sequence_net[k].W)
self.sequence_net[k].b.assign(hdf5.read(self.sequence_net[k].b.name)).eval(session=session)
session.run(self.sequence_net[k].b)
self.load_variables_only(hdf5, session)
"""
def save(self, session, path, step=None):
......
......@@ -26,20 +26,7 @@ iterations = 50
seed = 10
def test_cnn_trainer_scratch():
train_data, train_labels, validation_data, validation_labels = load_mnist()
train_data = numpy.reshape(train_data, (train_data.shape[0], 28, 28, 1))
# Creating datashufflers
data_augmentation = ImageAugmentation()
train_data_shuffler = Memory(train_data, train_labels,
input_shape=[28, 28, 1],
batch_size=batch_size,
data_augmentation=data_augmentation)
validation_data = numpy.reshape(validation_data, (validation_data.shape[0], 28, 28, 1))
directory = "./temp/cnn"
def scratch_network():
# Creating a random network
scratch = SequenceNetwork()
scratch.add(Conv2D(name="conv1", kernel_size=3,
......@@ -51,20 +38,11 @@ def test_cnn_trainer_scratch():
activation=None,
weights_initialization=Xavier(seed=seed, use_gpu=False),
bias_initialization=Constant(use_gpu=False)))
# Loss for the softmax
loss = BaseLoss(tf.nn.sparse_softmax_cross_entropy_with_logits, tf.reduce_mean)
# One graph trainer
trainer = Trainer(architecture=scratch,
loss=loss,
iterations=iterations,
analizer=None,
prefetch=False,
temp_dir=directory)
trainer.train(train_data_shuffler)
return scratch
del scratch
def validate_network(validation_data, validation_labels, directory):
# Testing
validation_data_shuffler = Memory(validation_data, validation_labels,
input_shape=[28, 28, 1],
......@@ -79,7 +57,40 @@ def test_cnn_trainer_scratch():
predictions = scratch(data, session=session)
accuracy = 100. * numpy.sum(numpy.argmax(predictions, 1) == labels) / predictions.shape[0]
# At least 80% of accuracy
assert accuracy > 80.
shutil.rmtree(directory)
return accuracy
def test_cnn_trainer_scratch():
    """
    Train the scratch network on the MNIST data and check that the
    accuracy reported by `validate_network` is above 80.
    """
    directory = "./temp/cnn"

    train_data, train_labels, validation_data, validation_labels = load_mnist()

    # Both sets are reshaped to [samples, 28, 28, 1]
    n_train = train_data.shape[0]
    train_data = numpy.reshape(train_data, (n_train, 28, 28, 1))

    # Data shuffler used for training
    augmenter = ImageAugmentation()
    train_data_shuffler = Memory(train_data, train_labels,
                                 input_shape=[28, 28, 1],
                                 batch_size=batch_size,
                                 data_augmentation=augmenter)

    n_validation = validation_data.shape[0]
    validation_data = numpy.reshape(validation_data, (n_validation, 28, 28, 1))

    # Network to be trained
    scratch = scratch_network()

    # Softmax loss
    softmax_loss = BaseLoss(tf.nn.sparse_softmax_cross_entropy_with_logits, tf.reduce_mean)

    # Single graph trainer; kept in a local so that it lives until the
    # end of the test
    trainer = Trainer(architecture=scratch,
                      loss=softmax_loss,
                      iterations=iterations,
                      analizer=None,
                      prefetch=False,
                      temp_dir=directory)
    trainer.train(train_data_shuffler)

    # Validation
    accuracy = validate_network(validation_data, validation_labels, directory)
    assert accuracy > 80

    del scratch
......@@ -62,6 +62,8 @@ class SiameseTrainer(Trainer):
## Analizer
analizer=ExperimentAnalizer(),
model_from_file="",
verbosity_level=2):
super(SiameseTrainer, self).__init__(
......@@ -86,6 +88,8 @@ class SiameseTrainer(Trainer):
## Analizer
analizer=analizer,
model_from_file=model_from_file,
verbosity_level=verbosity_level
)
......
......@@ -65,6 +65,9 @@ class Trainer(object):
## Analizer
analizer=SoftmaxAnalizer(),
### Pretrained model
model_from_file="",
verbosity_level=2):
if not isinstance(architecture, SequenceNetwork):
......@@ -107,6 +110,8 @@ class Trainer(object):
self.enqueue_op = None
self.global_step = None
self.model_from_file = model_from_file
bob.core.log.set_verbosity_level(logger, verbosity_level)
def __del__(self):
......@@ -289,6 +294,12 @@ class Trainer(object):
with tf.Session(config=config) as session:
tf.initialize_all_variables().run()
# Loading a pretrained model
if self.model_from_file != "":
logger.info("Loading pretrained model from {0}".format(self.model_from_file))
hdf5 = bob.io.base.HDF5File(self.model_from_file)
self.architecture.load_variables_only(hdf5, session)
if isinstance(train_data_shuffler, OnLineSampling):
train_data_shuffler.set_feature_extractor(self.architecture, session=session)
......
......@@ -62,6 +62,8 @@ class TripletTrainer(Trainer):
## Analizer
analizer=ExperimentAnalizer(),
model_from_file="",
verbosity_level=2):
super(TripletTrainer, self).__init__(
......@@ -85,6 +87,7 @@ class TripletTrainer(Trainer):
## Analizer
analizer=analizer,
model_from_file=model_from_file,
verbosity_level=verbosity_level
)
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment