#!/usr/bin/env python
# vim: set fileencoding=utf-8 :
# @author: Tiago de Freitas Pereira <tiago.pereira@idiap.ch>
# @date: Tue 09 Aug 2016 15:25:22 CEST

import logging
logger = logging.getLogger("bob.learn.tensorflow")
import tensorflow as tf
from ..network import SequenceNetwork
import threading
import numpy
import os
import bob.io.base

class Trainer(object):

    def __init__(self,
                 architecture=None,
                 use_gpu=False,
                 loss=None,
                 temp_dir="",

                 ###### training options ##########
                 convergence_threshold=0.01,
                 iterations=5000,
                 base_lr=0.001,
                 momentum=0.9,
                 weight_decay=0.95,

                 # Validation/snapshot interval, in iterations
                 snapshot=100):
        """
        Train a :py:class:`bob.learn.tensorflow.network.SequenceNetwork`.

        **Parameters**

        architecture: the :py:class:`SequenceNetwork` to be trained
        use_gpu: run the graph on the GPU?
        loss: callable computing the loss from a graph and its labels
        temp_dir: directory where the trained model (``model.hdf5``) is saved
        convergence_threshold: threshold used to decide convergence
        iterations: maximum number of training iterations
        base_lr: initial learning rate
        momentum: momentum term (not used by the plain gradient-descent optimizer below)
        weight_decay: decay rate of the exponential learning-rate schedule
        snapshot: number of iterations between validation snapshots
        """

        self.loss = loss
        self.loss_instance = None
        self.optimizer = None
        self.temp_dir = temp_dir

        self.architecture = architecture
        self.use_gpu = use_gpu

        self.iterations = iterations
        self.snapshot = snapshot
        self.base_lr = base_lr
        self.momentum = momentum
        self.weight_decay = weight_decay
        self.convergence_threshold = convergence_threshold

    def train(self, data_shuffler):
        """
        Run the training loop (forward pass --> backward pass) for
        `self.iterations` steps, prefetching data in a background thread.
        """

        def start_thread():
            """
            Start the data-prefetching thread(s)
            """
            threads = []
            for n in range(1):  # only one prefetching thread is started here
                t = threading.Thread(target=load_and_enqueue)
                t.daemon = True  # thread will close when parent quits
                t.start()
                threads.append(t)
            return threads

        def load_and_enqueue():
            """
            Inject data into the prefetching queue through the placeholders
            """

            #while not thread_pool.should_stop():
            for i in range(self.iterations):
                train_data, train_labels = data_shuffler.get_batch()

                feed_dict = {train_placeholder_data: train_data,
                             train_placeholder_labels: train_labels}

                session.run(enqueue_op, feed_dict=feed_dict)

        # Defining place holders
        train_placeholder_data, train_placeholder_labels = data_shuffler.get_placeholders_forprefetch(name="train")
        validation_placeholder_data, validation_placeholder_labels = data_shuffler.get_placeholders(name="validation",
                                                                                                    train_dataset=False)
        # Defining a placeholder queue for prefetching
        queue = tf.FIFOQueue(capacity=10,
                             dtypes=[tf.float32, tf.int64],
                             shapes=[train_placeholder_data.get_shape().as_list()[1:], []])
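        # Note: `shapes` are per-sample (the leading batch dimension is stripped by
        # `[1:]`), so the `enqueue_many` below splits a fed batch into individual
        # elements and `capacity` counts samples, not mini-batches.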

        # Enqueueing whole placeholder batches and dequeueing training mini-batches
        enqueue_op = queue.enqueue_many([train_placeholder_data, train_placeholder_labels])
        train_feature_batch, train_label_batch = queue.dequeue_many(data_shuffler.train_batch_size)

        # Creating the architecture for train and validation
        if not isinstance(self.architecture, SequenceNetwork):
            raise ValueError("The variable `architecture` must be an instance of "
                             "`bob.learn.tensorflow.network.SequenceNetwork`")

        # Creating graphs
        #train_graph = self.architecture.compute_graph(train_placeholder_data)
        train_graph = self.architecture.compute_graph(train_feature_batch)
        validation_graph = self.architecture.compute_graph(validation_placeholder_data)

        # Defining the loss
        #loss_train = self.loss(train_graph, train_placeholder_labels)
        loss_train = self.loss(train_graph, train_label_batch)
        loss_validation = self.loss(validation_graph, validation_placeholder_labels)

        batch = tf.Variable(0)  # global step, incremented by the optimizer below
        learning_rate = tf.train.exponential_decay(
            self.base_lr,  # Base learning rate
            batch * data_shuffler.train_batch_size,  # Current index into the dataset
            data_shuffler.train_data.shape[0],  # Decay steps: one pass over the training data
            self.weight_decay  # Decay rate
        )
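        # Without staircasing, tf.train.exponential_decay evaluates to
        #   base_lr * weight_decay ** ((batch * train_batch_size) / n_train_samples),
        # i.e. the learning rate shrinks by a factor `weight_decay` per epoch.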
        optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss_train,
                                                                              global_step=batch)

        train_prediction = tf.nn.softmax(train_graph)
        validation_prediction = tf.nn.softmax(validation_graph)

        print("Initializing !!")
        # Training
        hdf5 = bob.io.base.HDF5File(os.path.join(self.temp_dir, 'model.hdf5'), 'w')

        with tf.Session() as session:

            tf.initialize_all_variables().run()  # pre-1.0 name of tf.global_variables_initializer()

            # Start a thread to enqueue data asynchronously, and hide I/O latency.
            thread_pool = tf.train.Coordinator()
            tf.train.start_queue_runners(coord=thread_pool)

            threads = start_thread()

            #train_writer = tf.train.SummaryWriter('./LOGS/train',
            #                                      session.graph)

            for step in range(self.iterations):

                try:
                    _, l, lr, _ = session.run([optimizer, loss_train,
                                               learning_rate, train_prediction])

                    if step % self.snapshot == 0:
                        validation_data, validation_labels = data_shuffler.get_batch(train_dataset=False)
                        feed_dict = {validation_placeholder_data: validation_data,
                                     validation_placeholder_labels: validation_labels}

                        l, predictions = session.run([loss_validation, validation_prediction], feed_dict=feed_dict)
                        accuracy = 100. * numpy.sum(numpy.argmax(predictions, 1) == validation_labels) / predictions.shape[0]

                        print("Step {0}. Loss = {1}, Acc Validation={2}".format(step, l, accuracy))
                except Exception as e:
                    logger.error("Error in training step {0}: {1}".format(step, e))
                    thread_pool.request_stop()
                    break

            #train_writer.close()

            # now they should definitely stop
            thread_pool.request_stop()
            thread_pool.join(threads)
            self.architecture.save(hdf5)
            del hdf5
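

# Minimal usage sketch (illustrative only): `my_network`, `my_loss` and
# `my_data_shuffler` are hypothetical stand-ins -- any SequenceNetwork subclass,
# a callable `loss(graph, labels)` and a data shuffler providing `get_batch()`,
# `get_placeholders()`, `get_placeholders_forprefetch()`, `train_batch_size`
# and `train_data` should fit the Trainer above.
#
#     trainer = Trainer(architecture=my_network,   # a SequenceNetwork instance
#                       loss=my_loss,              # e.g. a cross-entropy wrapper
#                       temp_dir="./snapshots",    # where model.hdf5 is written
#                       iterations=5000,
#                       base_lr=0.001,
#                       weight_decay=0.95,
#                       snapshot=100)
#     trainer.train(my_data_shuffler)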