#!/usr/bin/env python
# vim: set fileencoding=utf-8 :
# @author: Tiago de Freitas Pereira <tiago.pereira@idiap.ch>
# @date: Tue 09 Aug 2016 15:25:22 CEST

import logging
logger = logging.getLogger("bob.learn.tensorflow")
import tensorflow as tf
from ..network import SequenceNetwork
import threading
import numpy
import os
import bob.io.base


class Trainer(object):
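    """
    Trains a `bob.learn.tensorflow.network.SequenceNetwork` with gradient
    descent, prefetching training batches through a ``tf.FIFOQueue`` that is
    filled by a background thread.

    **Parameters**

      architecture: a ``SequenceNetwork`` instance defining the model
      use_gpu: if ``True``, run the graph on the GPU
      loss: callable applied as ``loss(graph, labels)``
      temp_dir: directory where ``model.hdf5`` is written after training
      convergence_threshold: stored, but not used by ``train`` yet
      iterations: number of training steps
      base_lr: initial learning rate of the exponential-decay schedule
      momentum: stored, but not used by ``train`` yet
      weight_decay: decay rate of the learning-rate schedule
      snapshot: run validation every ``snapshot`` iterations
    """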

    def __init__(self,

                 architecture=None,
                 use_gpu=False,
                 loss=None,
                 temp_dir="",

                 ###### training options ##########
                 convergence_threshold=0.01,
                 iterations=5000,
                 base_lr=0.001,
                 momentum=0.9,
                 weight_decay=0.95,

                 # Snapshot/validation interval (in iterations)
                 snapshot=100):

        self.loss = loss
        self.loss_instance = None
        self.optimizer = None
        self.temp_dir = temp_dir


        self.architecture = architecture
        self.use_gpu = use_gpu

        self.iterations = iterations
        self.snapshot = snapshot
        self.base_lr = base_lr
        self.momentum = momentum
        self.weight_decay = weight_decay
        self.convergence_threshold = convergence_threshold

    def train(self, train_data_shuffler, validation_data_shuffler=None):
        """
        Do the loop forward --> backward --|
                      ^--------------------|
        """

        def start_thread():
            threads = []
            for n in range(1):
                t = threading.Thread(target=load_and_enqueue)
                t.daemon = True  # thread will close when parent quits
                t.start()
                threads.append(t)
            return threads

        def load_and_enqueue():
            """
            Inject data into the placeholder queue.
            """

            #while not thread_pool.should_stop():
            for i in range(self.iterations):
                train_data, train_labels = train_data_shuffler.get_batch()

                feed_dict = {train_placeholder_data: train_data,
                             train_placeholder_labels: train_labels}

                session.run(enqueue_op, feed_dict=feed_dict)
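
        # `load_and_enqueue` runs in a daemon thread started by `start_thread`:
        # it blocks on `enqueue_op` whenever the FIFO queue below is full, so
        # batch preparation overlaps with the session.run() calls in the main loop.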

        # Defining the placeholders
        train_placeholder_data, train_placeholder_labels = train_data_shuffler.get_placeholders_forprefetch(name="train")
        if validation_data_shuffler is not None:
            validation_placeholder_data, validation_placeholder_labels = \
                validation_data_shuffler.get_placeholders(name="validation")
        # Defining a placeholder queue for prefetching
        queue = tf.FIFOQueue(capacity=10,
                             dtypes=[tf.float32, tf.int64],
                             shapes=[train_placeholder_data.get_shape().as_list()[1:], []])

        # Enqueue operation that pushes the placeholder contents into the queue
        enqueue_op = queue.enqueue_many([train_placeholder_data, train_placeholder_labels])
        train_feature_batch, train_label_batch = queue.dequeue_many(train_data_shuffler.batch_size)
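        # Note: `enqueue_many` above pushes each sample of the fed batch into the
        # queue as an individual element; `dequeue_many(batch_size)` re-assembles
        # `batch_size` of those elements into the tensors consumed by the graph.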

        # Creating the architecture for train and validation
        if not isinstance(self.architecture, SequenceNetwork):
            raise ValueError("The variable `architecture` must be an instance of "
                             "`bob.learn.tensorflow.network.SequenceNetwork`")

        # Creating graphs and defining the loss
        train_graph = self.architecture.compute_graph(train_feature_batch)
        loss_train = self.loss(train_graph, train_label_batch)
        train_prediction = tf.nn.softmax(train_graph)
        if validation_data_shuffler is not None:
            validation_graph = self.architecture.compute_graph(validation_placeholder_data)
            loss_validation = self.loss(validation_graph, validation_placeholder_labels)
            validation_prediction = tf.nn.softmax(validation_graph)
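            # The validation graph is built from the same `self.architecture`
            # object, so it is expected to share the weights being trained above.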

        batch = tf.Variable(0)
        learning_rate = tf.train.exponential_decay(
            self.base_lr,  # Learning rate
            batch * train_data_shuffler.batch_size,  # Current index into the dataset (samples seen so far)
            train_data_shuffler.n_samples,  # Decay step: one pass over the data
            self.weight_decay  # Decay rate
        )
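        # With these arguments tf.train.exponential_decay yields
        #   lr = base_lr * weight_decay ** ((batch * batch_size) / n_samples)
        # i.e. the learning rate is multiplied by `weight_decay` roughly once per
        # epoch, since `batch * batch_size` counts the samples processed so far.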
        optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss_train,
                                                                              global_step=batch)
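        # Passing `global_step=batch` makes `minimize()` increment `batch` by one
        # at every update, which is what advances the decay schedule above.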

        print("Initializing !!")
        # Training
        hdf5 = bob.io.base.HDF5File(os.path.join(self.temp_dir, 'model.hdf5'), 'w')

        with tf.Session() as session:

            tf.initialize_all_variables().run()

            # Start a thread to enqueue data asynchronously, and hide I/O latency.
            thread_pool = tf.train.Coordinator()
            tf.train.start_queue_runners(coord=thread_pool)

            threads = start_thread()
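            # From here on the producer thread keeps the FIFO queue filled while
            # the loop below consumes batches via `dequeue_many`.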

            train_writer = tf.train.SummaryWriter('./LOGS/train', session.graph)

            for step in range(self.iterations):

                _, l, lr, _ = session.run([optimizer, loss_train,
                                           learning_rate, train_prediction])
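                # Each call above implicitly dequeues one prefetched batch (the
                # graph behind `loss_train` starts at `queue.dequeue_many`) and
                # applies a single gradient-descent update.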

                if validation_data_shuffler is not None and step % self.snapshot == 0:
                    validation_data, validation_labels = validation_data_shuffler.get_batch()

                    feed_dict = {validation_placeholder_data: validation_data,
                                 validation_placeholder_labels: validation_labels}

                    l, predictions = session.run([loss_validation, validation_prediction], feed_dict=feed_dict)
                    accuracy = 100. * numpy.sum(numpy.argmax(predictions, 1) == validation_labels) / predictions.shape[0]
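                    # Accuracy = percentage of samples whose arg-max over the
                    # softmax output matches the integer label.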

                    print "Step {0}. Loss = {1}, Acc Validation={2}".format(step, l, accuracy)

            train_writer.close()

            # now they should definitely stop
            thread_pool.request_stop()
            thread_pool.join(threads)
            self.architecture.save(hdf5)
            del hdf5
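
# Minimal usage sketch (illustrative only): the network and data-shuffler class
# names are assumptions about the surrounding bob.learn.tensorflow package and
# may not match its real API; only the `Trainer` calls mirror this file.
#
#     from bob.learn.tensorflow.network import Lenet            # assumed name
#     from bob.learn.tensorflow.data import MemoryDataShuffler  # assumed name
#
#     train_shuffler = MemoryDataShuffler(images, labels, batch_size=16)
#     trainer = Trainer(architecture=Lenet(),
#                       loss=my_loss,   # callable taking (graph, labels)
#                       temp_dir="./snapshots",
#                       iterations=5000)
#     trainer.train(train_shuffler)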