Trainer.py 6.16 KB
Newer Older
Tiago de Freitas Pereira's avatar
Tiago de Freitas Pereira committed
1
2
3
4
5
6
7
8
9
#!/usr/bin/env python
# vim: set fileencoding=utf-8 :
# @author: Tiago de Freitas Pereira <tiago.pereira@idiap.ch>
# @date: Tue 09 Aug 2016 15:25:22 CEST

import logging
logger = logging.getLogger("bob.learn.tensorflow")
import tensorflow as tf
from ..network import SequenceNetwork
10
import threading
Tiago de Freitas Pereira's avatar
Tiago de Freitas Pereira committed
11
import numpy
12
13
import os
import bob.io.base
Tiago de Freitas Pereira's avatar
Tiago de Freitas Pereira committed
14
15
16
17
18
19
20
21

class Trainer(object):
    """
    Trains a `SequenceNetwork` with TensorFlow using a prefetching input queue.

    **Parameters**

      architecture: The network to be trained. `train` requires an instance of
                    `bob.learn.tensorflow.network.SequenceNetwork`.
      use_gpu: Stored on the instance; not read by the code in this class.
      loss: Callable invoked as `loss(graph, labels)`; its result is minimized.
      temp_dir: Directory in which `model.hdf5` is written.
      convergence_threshold: Stored on the instance; not read by the code in this class.
      iterations: Number of training steps (and of batches fed to the queue).
      base_lr: Initial learning rate given to `tf.train.exponential_decay`.
      momentum: Stored on the instance; not read by the code in this class
                (the optimizer built in `train` is plain gradient descent).
      weight_decay: Decay rate given to `tf.train.exponential_decay`.
      snapshot: A validation batch is evaluated every `snapshot` steps.
    """

    def __init__(self,
                 architecture=None,
                 use_gpu=False,
                 loss=None,
                 temp_dir="",

                 ###### training options ##########
                 convergence_threshold=0.01,
                 iterations=5000,
                 base_lr=0.001,
                 momentum=0.9,
                 weight_decay=0.95,

                 # The learning rate policy
                 snapshot=100):

        self.loss = loss
        self.loss_instance = None
        self.optimizer = None
        self.temp_dir = temp_dir

        self.architecture = architecture
        self.use_gpu = use_gpu

        self.iterations = iterations
        self.snapshot = snapshot
        self.base_lr = base_lr
        self.momentum = momentum
        self.weight_decay = weight_decay
        self.convergence_threshold = convergence_threshold

    def train(self, data_shuffler):
        """
        Do the loop forward --> backward --|
                      ^--------------------|

        Training batches are pushed into a FIFO queue by a background thread
        while the main thread runs the optimization steps. Every
        `self.snapshot` steps one validation batch is evaluated and its loss
        and accuracy are printed. When the loop is over, the architecture is
        saved to `model.hdf5` inside `self.temp_dir`.

        A step that raises an `Exception` is reported ("ERROR" plus a logged
        traceback) and training moves on to the next step. `KeyboardInterrupt`
        and `SystemExit` are not intercepted; they abort the training and the
        model is not saved.

        **Parameters**

          data_shuffler: Source of the data. This method uses its
                         `get_batch`, `get_placeholders_forprefetch`,
                         `get_placeholders`, `train_batch_size` and
                         `train_data` members.

        **Raises**

          ValueError: if `self.architecture` is not a `SequenceNetwork`.
        """

        # Fail early, before any op is added to the graph
        if not isinstance(self.architecture, SequenceNetwork):
            raise ValueError("The variable `architecture` must be an instance of "
                             "`bob.learn.tensorflow.network.SequenceNetwork`")

        def start_thread():
            """
            Starts the (single) daemon thread that feeds the queue and returns
            it inside a list.
            """
            threads = []
            for _ in range(1):
                t = threading.Thread(target=load_and_enqueue)
                t.daemon = True  # thread will close when parent quits
                t.start()
                threads.append(t)
            return threads

        def load_and_enqueue():
            """
            Injecting data in the place holder queue
            """

            for _ in range(self.iterations):
                # Honour a stop request between two batches, so that
                # `thread_pool.request_stop()` really ends this thread
                if thread_pool.should_stop():
                    break

                train_data, train_labels = data_shuffler.get_batch()

                feed_dict = {train_placeholder_data: train_data,
                             train_placeholder_labels: train_labels}

                session.run(enqueue_op, feed_dict=feed_dict)

        # Defining place holders
        train_placeholder_data, train_placeholder_labels = data_shuffler.get_placeholders_forprefetch(name="train")
        validation_placeholder_data, validation_placeholder_labels = data_shuffler.get_placeholders(name="validation",
                                                                                                    train_dataset=False)

        # Defining a placeholder queue for prefetching
        queue = tf.FIFOQueue(capacity=10,
                             dtypes=[tf.float32, tf.int64],
                             shapes=[train_placeholder_data.get_shape().as_list()[1:], []])

        # Fetching the place holders from the queue
        enqueue_op = queue.enqueue_many([train_placeholder_data, train_placeholder_labels])
        train_feature_batch, train_label_batch = queue.dequeue_many(data_shuffler.train_batch_size)

        # Creating graphs; the train graph reads from the queue, the
        # validation graph is fed directly through its placeholders
        train_graph = self.architecture.compute_graph(train_feature_batch)
        validation_graph = self.architecture.compute_graph(validation_placeholder_data)

        # Defining the loss
        loss_train = self.loss(train_graph, train_label_batch)
        loss_validation = self.loss(validation_graph, validation_placeholder_labels)

        # `batch` is the global step, incremented by the optimizer
        batch = tf.Variable(0)
        learning_rate = tf.train.exponential_decay(
            self.base_lr,  # Initial learning rate
            batch * data_shuffler.train_batch_size,  # Global step (in samples)
            data_shuffler.train_data.shape[0],  # Decay steps
            self.weight_decay  # Decay rate
        )
        optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss_train,
                                                                              global_step=batch)

        train_prediction = tf.nn.softmax(train_graph)
        validation_prediction = tf.nn.softmax(validation_graph)

        print("Initializing !!")
        # Training
        hdf5 = bob.io.base.HDF5File(os.path.join(self.temp_dir, 'model.hdf5'), 'w')

        with tf.Session() as session:

            tf.initialize_all_variables().run()

            # Start a thread to enqueue data asynchronously, and hide I/O latency.
            thread_pool = tf.train.Coordinator()
            tf.train.start_queue_runners(coord=thread_pool)

            threads = start_thread()

            try:
                for step in range(self.iterations):
                    try:
                        _, l, lr, _ = session.run([optimizer, loss_train,
                                                   learning_rate, train_prediction])

                        if step % self.snapshot == 0:
                            validation_data, validation_labels = data_shuffler.get_batch(train_dataset=False)
                            feed_dict = {validation_placeholder_data: validation_data,
                                         validation_placeholder_labels: validation_labels}

                            l, predictions = session.run([loss_validation, validation_prediction],
                                                         feed_dict=feed_dict)
                            accuracy = 100. * numpy.sum(numpy.argmax(predictions, 1) == validation_labels) / predictions.shape[0]

                            print("Step {0}. Loss = {1}, Acc Validation={2}".format(step, l, accuracy))
                    except Exception:
                        # Best effort: report the failing step (with its
                        # traceback) and carry on with the next one
                        print("ERROR")
                        logger.exception("Training step %d failed", step)
            finally:
                # Only once the whole loop is over (or aborted) the feeding
                # thread is asked to stop
                thread_pool.request_stop()

            # now they should definitely stop
            thread_pool.join(threads)
            self.architecture.save(hdf5)
            del hdf5
Tiago de Freitas Pereira's avatar
Tiago de Freitas Pereira committed
164