#!/usr/bin/env python
# vim: set fileencoding=utf-8 :
# @author: Tiago de Freitas Pereira <tiago.pereira@idiap.ch>
# @date: Tue 09 Aug 2016 15:25:22 CEST

import logging
logger = logging.getLogger("bob.learn.tensorflow")
import tensorflow as tf
import threading
from ..analyzers import ExperimentAnalizer
from ..network import SequenceNetwork
import bob.io.base
from .Trainer import Trainer
import os
import sys

class SiameseTrainer(Trainer):

    def __init__(self,
                 architecture,
                 optimizer=tf.train.AdamOptimizer(),
                 use_gpu=False,
                 loss=None,
                 temp_dir="cnn",

                 # Learning rate
                 base_learning_rate=0.001,
                 weight_decay=0.9,

                 ###### training options ##########
                 convergence_threshold=0.01,
                 iterations=5000,
                 snapshot=100):
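        """
        **Parameters**

          architecture: An instance of `bob.learn.tensorflow.network.SequenceNetwork` to be trained.
          optimizer: One of the tensorflow optimizers (default: `tf.train.AdamOptimizer()`).
          use_gpu: Run the computations on a GPU?
          loss: Siamese loss; called as `loss(labels, left_graph, right_graph)` and
            expected to return `(loss, between_class, within_class)`.
          temp_dir: Directory where the model (`model.hdf5`) and the TensorBoard logs are written.
          base_learning_rate: Initial learning rate of the exponential-decay schedule.
          weight_decay: Decay rate used by the learning-rate schedule.
          convergence_threshold: Convergence threshold (handled by the base `Trainer`).
          iterations: Number of training iterations.
          snapshot: Interval (in iterations) between validation snapshots.
        """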

        super(SiameseTrainer, self).__init__(
            architecture=architecture,
            optimizer=optimizer,
            use_gpu=use_gpu,
            loss=loss,
            temp_dir=temp_dir,
            base_learning_rate=base_learning_rate,
            weight_decay=weight_decay,
            convergence_threshold=convergence_threshold,
            iterations=iterations,
            snapshot=snapshot
        )

    def train(self, train_data_shuffler, validation_data_shuffler=None):
        """
        Do the loop forward --> backward --|
                      ^--------------------|
        """

        def start_thread():
            """
            Start the daemon thread(s) that feed the prefetching queue.
            """
            threads = []
            for n in range(1):
                t = threading.Thread(target=load_and_enqueue)
                t.daemon = True  # thread will close when parent quits
                t.start()
                threads.append(t)
            return threads

        def load_and_enqueue():
            """
            Inject data into the placeholder queue.
            """
            while not thread_pool.should_stop():
                batch_left, batch_right, labels = train_data_shuffler.get_pair()

                feed_dict = {train_placeholder_left_data: batch_left,
                             train_placeholder_right_data: batch_right,
                             train_placeholder_labels: labels}

                session.run(enqueue_op, feed_dict=feed_dict)

        # TODO: find an elegant way to provide this as a parameter of the trainer
        learning_rate = tf.train.exponential_decay(
            self.base_learning_rate,         # base learning rate
            train_data_shuffler.batch_size,  # global step
            train_data_shuffler.n_samples,   # decay steps
            self.weight_decay                # decay rate
        )
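        # Note: tf.train.exponential_decay computes
        #   base_learning_rate * decay_rate ** (global_step / decay_steps),
        # so the `weight_decay` parameter plays the role of the decay rate here.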

        # Creating directory
        bob.io.base.create_directories_safe(self.temp_dir)

        # Creating the placeholders for the two branches of the siamese network
        train_placeholder_left_data, train_placeholder_labels = train_data_shuffler.\
            get_placeholders_forprefetch(name="train_left")
        train_placeholder_right_data, _ = train_data_shuffler.get_placeholders(name="train_right")

        # Defining a placeholder queue for prefetching
        queue = tf.FIFOQueue(capacity=100,
                             dtypes=[tf.float32, tf.float32, tf.int64],
                             shapes=[train_placeholder_left_data.get_shape().as_list()[1:],
                                     train_placeholder_right_data.get_shape().as_list()[1:],
                                     []])
        # Fetching the placeholders from the queue
        enqueue_op = queue.enqueue_many([train_placeholder_left_data,
                                         train_placeholder_right_data,
                                         train_placeholder_labels])
        train_left_feature_batch, train_right_feature_batch, train_labels_batch = \
            queue.dequeue_many(train_data_shuffler.batch_size)
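        # load_and_enqueue() (run in a daemon thread below) feeds enqueue_op through
        # feed_dict, while dequeue_many() hands ready batches to the training graph,
        # so the optimizer step does not block on Python-side data loading.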

        # The architecture must be a SequenceNetwork; the same network builds
        # both branches of the siamese graph.
        if not isinstance(self.architecture, SequenceNetwork):
            raise ValueError("The variable `architecture` must be an instance of "
                             "`bob.learn.tensorflow.network.SequenceNetwork`")

        # Creating the siamese graph
        train_left_graph = self.architecture.compute_graph(train_left_feature_batch)
        train_right_graph = self.architecture.compute_graph(train_right_feature_batch)

        # The siamese loss returns the overall loss together with the between-class
        # and within-class terms; all three are logged to TensorBoard below.
        loss_train, between_class, within_class = self.loss(train_labels_batch,
                                                            train_left_graph,
                                                            train_right_graph)

        # Preparing the optimizer
        global_step = tf.Variable(0)
        # Override the optimizer's (private) learning rate with the decayed schedule
        self.optimizer._learning_rate = learning_rate
        optimizer = self.optimizer.minimize(loss_train, global_step=global_step)

        print("Initializing !!")
        # Training
        hdf5 = bob.io.base.HDF5File(os.path.join(self.temp_dir, 'model.hdf5'), 'w')

        with tf.Session() as session:
            if validation_data_shuffler is not None:
                analizer = ExperimentAnalizer(validation_data_shuffler, self.architecture, session)

            tf.initialize_all_variables().run()

            # Start a thread to enqueue data asynchronously, and hide I/O latency.
            thread_pool = tf.train.Coordinator()
            tf.train.start_queue_runners(coord=thread_pool)
            threads = start_thread()
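            # The Coordinator gives load_and_enqueue() its stop signal: the thread
            # loops on thread_pool.should_stop() and exits once request_stop() is
            # called at the end of training.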

            # TENSOR BOARD SUMMARY
            train_writer = tf.train.SummaryWriter(os.path.join(self.temp_dir, 'LOGS'), session.graph)

            # Siamese specific summary
            tf.scalar_summary('loss', loss_train)
            tf.scalar_summary('between_class', between_class)
            tf.scalar_summary('within_class', within_class)
            tf.scalar_summary('lr', learning_rate)
            merged = tf.merge_all_summaries()

            # Architecture summary
            self.architecture.generate_summaries()
            merged_validation = tf.merge_all_summaries()
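            # Note: merge_all_summaries() collects every summary registered so far,
            # so merged_validation also includes the scalar summaries defined above
            # in addition to the architecture's own summaries.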

            for step in range(self.iterations):

                _, l, lr, summary = session.run(
                    [optimizer, loss_train, learning_rate, merged])
                train_writer.add_summary(summary, step)
                sys.stdout.flush()

                if validation_data_shuffler is not None and step % self.snapshot == 0:
                    print(step)
                    sys.stdout.flush()

                    summary = session.run(merged_validation)
                    train_writer.add_summary(summary, step)

                    summary = analizer()
                    train_writer.add_summary(summary, step)

            print("#######DONE##########")
            self.architecture.save(hdf5)
            del hdf5
            train_writer.close()

            thread_pool.request_stop()
            thread_pool.join(threads)
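
# A minimal usage sketch. The data-shuffler and loss names below are hypothetical
# placeholders; substitute the classes actually available in bob.learn.tensorflow.
# Only the SiameseTrainer arguments themselves come from the constructor above.
#
#   from bob.learn.tensorflow.network import SequenceNetwork
#
#   architecture = SequenceNetwork()                      # add the layers of your CNN here
#   train_shuffler = SomePairDataShuffler(data, labels)   # hypothetical; must provide get_pair()
#   siamese_loss = SomeContrastiveLoss()                  # hypothetical; must return (loss, between, within)
#
#   trainer = SiameseTrainer(architecture=architecture,
#                            loss=siamese_loss,
#                            temp_dir="./siamese_cnn",
#                            iterations=5000,
#                            snapshot=100)
#   trainer.train(train_shuffler)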