diff --git a/bob/learn/tensorflow/network/SequenceNetwork.py b/bob/learn/tensorflow/network/SequenceNetwork.py index 8c7ecf954da355ddf66f4a3bc00d773256acbff2..595a9283e9dadc74dac79b6818950755e9edf52b 100644 --- a/bob/learn/tensorflow/network/SequenceNetwork.py +++ b/bob/learn/tensorflow/network/SequenceNetwork.py @@ -225,17 +225,32 @@ class SequenceNetwork(six.with_metaclass(abc.ABCMeta, object)): self.sequence_net[k].weights_initialization.use_gpu = state self.sequence_net[k].bias_initialization.use_gpu = state - def load(self, hdf5, shape=None, session=None, batch=1): + def load_variables_only(self, hdf5, session): """ - Load the network + Load the variables of the model + """ + hdf5.cd('/tensor_flow') + for k in self.sequence_net: + # TODO: IT IS NOT SMART TESTING ALONG THIS PAGE + if not isinstance(self.sequence_net[k], MaxPooling): + self.sequence_net[k].W.assign(hdf5.read(self.sequence_net[k].W.name)).eval(session=session) + session.run(self.sequence_net[k].W) + self.sequence_net[k].b.assign(hdf5.read(self.sequence_net[k].b.name)).eval(session=session) + session.run(self.sequence_net[k].b) + hdf5.cd("..") + + def load(self, hdf5, shape=None, session=None, batch=1, use_gpu=False): + """ + Load the network from scratch. + This will build the graphs **Parameters** hdf5: The saved network in the :py:class:`bob.io.base.HDF5File` format - - shape: Input shape of the network - - session: tensorflow `session <https://www.tensorflow.org/versions/r0.11/api_docs/python/client.html#Session>`_ + shape: Input shape of the network. If `None`, the internal shape will be assumed + session: An opened tensorflow `session <https://www.tensorflow.org/versions/r0.11/api_docs/python/client.html#Session>`_. If `None`, a new one will be opened + batch: The size of the batch + use_gpu: Load all the variables in the GPU? """ if session is None: @@ -249,7 +264,7 @@ class SequenceNetwork(six.with_metaclass(abc.ABCMeta, object)): self.sequence_net = pickle.loads(hdf5.read('architecture')) self.deployment_shape = hdf5.read('deployment_shape') - self.turn_gpu_onoff(False) + self.turn_gpu_onoff(use_gpu) if shape is None: shape = self.deployment_shape @@ -259,17 +274,7 @@ class SequenceNetwork(six.with_metaclass(abc.ABCMeta, object)): place_holder = tf.placeholder(tf.float32, shape=shape, name="load") self.compute_graph(place_holder) tf.initialize_all_variables().run(session=session) - - hdf5.cd('/tensor_flow') - for k in self.sequence_net: - # TODO: IT IS NOT SMART TESTING ALONG THIS PAGE - if not isinstance(self.sequence_net[k], MaxPooling): - #self.sequence_net[k].W.assign(hdf5.read(self.sequence_net[k].W.name)) - self.sequence_net[k].W.assign(hdf5.read(self.sequence_net[k].W.name)).eval(session=session) - session.run(self.sequence_net[k].W) - self.sequence_net[k].b.assign(hdf5.read(self.sequence_net[k].b.name)).eval(session=session) - session.run(self.sequence_net[k].b) - + self.load_variables_only(hdf5, session) """ def save(self, session, path, step=None): diff --git a/bob/learn/tensorflow/test/test_cnn_scratch.py b/bob/learn/tensorflow/test/test_cnn_scratch.py index 469c3def57a5e5fbbd544f7ceea223af66af31fd..66b813f846157133eb0cfe0bd34dd3e684a74460 100644 --- a/bob/learn/tensorflow/test/test_cnn_scratch.py +++ b/bob/learn/tensorflow/test/test_cnn_scratch.py @@ -26,20 +26,7 @@ iterations = 50 seed = 10 -def test_cnn_trainer_scratch(): - train_data, train_labels, validation_data, validation_labels = load_mnist() - train_data = numpy.reshape(train_data, (train_data.shape[0], 28, 28, 1)) - - # Creating datashufflers - data_augmentation = ImageAugmentation() - train_data_shuffler = Memory(train_data, train_labels, - input_shape=[28, 28, 1], - batch_size=batch_size, - data_augmentation=data_augmentation) - validation_data = numpy.reshape(validation_data, (validation_data.shape[0], 28, 28, 1)) - - directory = "./temp/cnn" - +def scratch_network(): # Creating a random network scratch = SequenceNetwork() scratch.add(Conv2D(name="conv1", kernel_size=3, @@ -51,20 +38,11 @@ def test_cnn_trainer_scratch(): activation=None, weights_initialization=Xavier(seed=seed, use_gpu=False), bias_initialization=Constant(use_gpu=False))) - # Loss for the softmax - loss = BaseLoss(tf.nn.sparse_softmax_cross_entropy_with_logits, tf.reduce_mean) - # One graph trainer - trainer = Trainer(architecture=scratch, - loss=loss, - iterations=iterations, - analizer=None, - prefetch=False, - temp_dir=directory) - trainer.train(train_data_shuffler) + return scratch - del scratch +def validate_network(validation_data, validation_labels, directory): # Testing validation_data_shuffler = Memory(validation_data, validation_labels, input_shape=[28, 28, 1], @@ -79,7 +57,40 @@ def test_cnn_trainer_scratch(): predictions = scratch(data, session=session) accuracy = 100. * numpy.sum(numpy.argmax(predictions, 1) == labels) / predictions.shape[0] - # At least 80% of accuracy - assert accuracy > 80. - shutil.rmtree(directory) + return accuracy + + +def test_cnn_trainer_scratch(): + train_data, train_labels, validation_data, validation_labels = load_mnist() + train_data = numpy.reshape(train_data, (train_data.shape[0], 28, 28, 1)) + + # Creating datashufflers + data_augmentation = ImageAugmentation() + train_data_shuffler = Memory(train_data, train_labels, + input_shape=[28, 28, 1], + batch_size=batch_size, + data_augmentation=data_augmentation) + validation_data = numpy.reshape(validation_data, (validation_data.shape[0], 28, 28, 1)) + + directory = "./temp/cnn" + + # Create scratch network + scratch = scratch_network() + + # Loss for the softmax + loss = BaseLoss(tf.nn.sparse_softmax_cross_entropy_with_logits, tf.reduce_mean) + + # One graph trainer + trainer = Trainer(architecture=scratch, + loss=loss, + iterations=iterations, + analizer=None, + prefetch=False, + temp_dir=directory) + trainer.train(train_data_shuffler) + + accuracy = validate_network(validation_data, validation_labels, directory) + assert accuracy > 80 + del scratch + diff --git a/bob/learn/tensorflow/trainers/SiameseTrainer.py b/bob/learn/tensorflow/trainers/SiameseTrainer.py index c7cd48d9a4408f61a6ddf2f31d57e04a25bc6512..7cf1f20628d0cc9d0c38ed45698b8b62d6763a81 100644 --- a/bob/learn/tensorflow/trainers/SiameseTrainer.py +++ b/bob/learn/tensorflow/trainers/SiameseTrainer.py @@ -62,6 +62,8 @@ class SiameseTrainer(Trainer): ## Analizer analizer=ExperimentAnalizer(), + model_from_file="", + verbosity_level=2): super(SiameseTrainer, self).__init__( @@ -86,6 +88,8 @@ class SiameseTrainer(Trainer): ## Analizer analizer=analizer, + model_from_file=model_from_file, + verbosity_level=verbosity_level ) diff --git a/bob/learn/tensorflow/trainers/Trainer.py b/bob/learn/tensorflow/trainers/Trainer.py index 6cce163cdb55a10bceef05ce0e9f5cb05f948f16..68538236b3d069ce7c53c19aba525a04ea83ddc3 100644 --- a/bob/learn/tensorflow/trainers/Trainer.py +++ b/bob/learn/tensorflow/trainers/Trainer.py @@ -65,6 +65,9 @@ class Trainer(object): ## Analizer analizer=SoftmaxAnalizer(), + ### Pretrained model + model_from_file="", + verbosity_level=2): if not isinstance(architecture, SequenceNetwork): @@ -107,6 +110,8 @@ class Trainer(object): self.enqueue_op = None self.global_step = None + self.model_from_file = model_from_file + bob.core.log.set_verbosity_level(logger, verbosity_level) def __del__(self): @@ -289,6 +294,12 @@ class Trainer(object): with tf.Session(config=config) as session: tf.initialize_all_variables().run() + # Loading a pretrained model + if self.model_from_file != "": + logger.info("Loading pretrained model from {0}".format(self.model_from_file)) + hdf5 = bob.io.base.HDF5File(self.model_from_file) + self.architecture.load_variables_only(hdf5, session) + if isinstance(train_data_shuffler, OnLineSampling): train_data_shuffler.set_feature_extractor(self.architecture, session=session) diff --git a/bob/learn/tensorflow/trainers/TripletTrainer.py b/bob/learn/tensorflow/trainers/TripletTrainer.py index 2379510f46e6a957bdb62d5220d1d9d17ffe6c2e..9452d5c8091c7b5c23ac4f5a1e08bff12d218cfa 100644 --- a/bob/learn/tensorflow/trainers/TripletTrainer.py +++ b/bob/learn/tensorflow/trainers/TripletTrainer.py @@ -62,6 +62,8 @@ class TripletTrainer(Trainer): ## Analizer analizer=ExperimentAnalizer(), + model_from_file="", + verbosity_level=2): super(TripletTrainer, self).__init__( @@ -85,6 +87,7 @@ class TripletTrainer(Trainer): ## Analizer analizer=analizer, + model_from_file=model_from_file, verbosity_level=verbosity_level )