Trainer.py 5.86 KB
Newer Older
Tiago de Freitas Pereira's avatar
Tiago de Freitas Pereira committed
1
2
3
4
5
6
7
8
9
#!/usr/bin/env python
# vim: set fileencoding=utf-8 :
# @author: Tiago de Freitas Pereira <tiago.pereira@idiap.ch>
# @date: Tue 09 Aug 2016 15:25:22 CEST

import logging
logger = logging.getLogger("bob.learn.tensorflow")
import tensorflow as tf
from ..network import SequenceNetwork
10
import threading
Tiago de Freitas Pereira's avatar
Tiago de Freitas Pereira committed
11
import numpy
12
13
import os
import bob.io.base
Tiago de Freitas Pereira's avatar
Tiago de Freitas Pereira committed
14

15

Tiago de Freitas Pereira's avatar
Tiago de Freitas Pereira committed
16
17
18
19
20
21
22
class Trainer(object):
    """
    Gradient-descent trainer for a `SequenceNetwork`.

    Training batches are pushed into a TensorFlow FIFO queue by a background
    thread, while the main thread runs the optimisation steps on what it
    dequeues.

    **Parameters**

      architecture: The network to be trained. `train` requires an instance of
                    `bob.learn.tensorflow.network.SequenceNetwork`.
      use_gpu: Stored on the instance; not read by the code in this class.
      loss: Callable invoked as `loss(graph, labels)` that returns the tensor
            to be minimised.
      temp_dir: Directory where `model.hdf5` is written.
      convergence_threshold: Stored on the instance; not read by the code in
                             this class.
      iterations: Number of optimisation steps (and of batches enqueued).
      base_lr: Initial learning rate of the exponential decay schedule.
      momentum: Stored on the instance; not read by the code in this class
                (the optimiser used is plain gradient descent).
      weight_decay: Decay *rate* handed to `tf.train.exponential_decay`.
      snapshot: Validation is evaluated every `snapshot` steps.
    """

    def __init__(self,

                 architecture=None,
                 use_gpu=False,
                 loss=None,
                 temp_dir="",

                 ###### training options ##########
                 convergence_threshold=0.01,
                 iterations=5000,
                 base_lr=0.001,
                 momentum=0.9,
                 weight_decay=0.95,

                 # The learning rate policy
                 snapshot=100):

        self.loss = loss
        self.loss_instance = None
        self.optimizer = None
        self.temp_dir = temp_dir

        self.architecture = architecture
        self.use_gpu = use_gpu

        self.iterations = iterations
        self.snapshot = snapshot
        self.base_lr = base_lr
        self.momentum = momentum
        self.weight_decay = weight_decay
        self.convergence_threshold = convergence_threshold

    def train(self, train_data_shuffler, validation_data_shuffler=None):
        """
        Run `self.iterations` steps of the forward --> backward loop and save
        the trained architecture to `<temp_dir>/model.hdf5`.

        **Parameters**

          train_data_shuffler: Source of the training batches. This method
                               uses its `get_batch()`,
                               `get_placeholders_forprefetch(name=...)`,
                               `batch_size` and `n_samples`.
          validation_data_shuffler: Optional source of validation batches
                                    (`get_batch()` and
                                    `get_placeholders(name=...)`). When given,
                                    the validation loss and accuracy are
                                    printed every `self.snapshot` steps.

        **Raises**

          ValueError: If `self.architecture` is not a `SequenceNetwork`.
        """

        # Validate first, so that an invalid architecture is rejected before
        # anything is added to the TensorFlow graph.
        if not isinstance(self.architecture, SequenceNetwork):
            raise ValueError("The variable `architecture` must be an instance of "
                             "`bob.learn.tensorflow.network.SequenceNetwork`")

        def start_thread():
            """
            Start the single background thread that feeds the queue
            """
            worker = threading.Thread(target=load_and_enqueue)
            worker.daemon = True  # thread will close when parent quits
            worker.start()
            return [worker]

        def load_and_enqueue():
            """
            Injecting data in the place holder queue
            """
            # `session` and `enqueue_op` are bound further down in `train`;
            # they exist by the time the thread is started.
            for _ in range(self.iterations):
                train_data, train_labels = train_data_shuffler.get_batch()

                feed_dict = {train_placeholder_data: train_data,
                             train_placeholder_labels: train_labels}

                session.run(enqueue_op, feed_dict=feed_dict)

        # Defining place holders
        train_placeholder_data, train_placeholder_labels = \
            train_data_shuffler.get_placeholders_forprefetch(name="train")
        if validation_data_shuffler is not None:
            validation_placeholder_data, validation_placeholder_labels = \
                validation_data_shuffler.get_placeholders(name="validation")

        # Defining a placeholder queue for prefetching.
        # Each queue element is one sample (the placeholder shape without its
        # leading dimension) plus one scalar label.
        sample_shape = train_placeholder_data.get_shape().as_list()[1:]
        queue = tf.FIFOQueue(capacity=10,
                             dtypes=[tf.float32, tf.int64],
                             shapes=[sample_shape, []])

        # Fetching the place holders from the queue
        enqueue_op = queue.enqueue_many([train_placeholder_data, train_placeholder_labels])
        train_feature_batch, train_label_batch = queue.dequeue_many(train_data_shuffler.batch_size)

        # Creating graphs and defining the loss
        train_graph = self.architecture.compute_graph(train_feature_batch)
        loss_train = self.loss(train_graph, train_label_batch)
        train_prediction = tf.nn.softmax(train_graph)
        if validation_data_shuffler is not None:
            validation_graph = self.architecture.compute_graph(validation_placeholder_data)
            loss_validation = self.loss(validation_graph, validation_placeholder_labels)
            validation_prediction = tf.nn.softmax(validation_graph)

        # `batch` is the global step counter incremented by the optimiser
        batch = tf.Variable(0)
        learning_rate = tf.train.exponential_decay(
            self.base_lr,  # Initial learning rate
            batch * train_data_shuffler.batch_size,  # Samples consumed so far
            train_data_shuffler.n_samples,  # Decay steps
            self.weight_decay  # Decay rate
        )
        optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss_train,
                                                                              global_step=batch)

        print("Initializing !!")
        # Training
        # The output file is opened before the session starts, so an unusable
        # `temp_dir` is reported before any training is done.
        hdf5 = bob.io.base.HDF5File(os.path.join(self.temp_dir, 'model.hdf5'), 'w')

        with tf.Session() as session:

            tf.initialize_all_variables().run()

            # Start a thread to enqueue data asynchronously, and hide I/O latency.
            thread_pool = tf.train.Coordinator()
            tf.train.start_queue_runners(coord=thread_pool)

            threads = start_thread()

            train_writer = tf.train.SummaryWriter('./LOGS/train', session.graph)

            try:
                for step in range(self.iterations):

                    _, train_loss, lr, _ = session.run([optimizer, loss_train,
                                                        learning_rate, train_prediction])

                    if validation_data_shuffler is not None and step % self.snapshot == 0:
                        validation_data, validation_labels = validation_data_shuffler.get_batch()

                        feed_dict = {validation_placeholder_data: validation_data,
                                     validation_placeholder_labels: validation_labels}

                        validation_loss, predictions = session.run(
                            [loss_validation, validation_prediction], feed_dict=feed_dict)
                        accuracy = 100. * numpy.sum(numpy.argmax(predictions, 1) == validation_labels) / predictions.shape[0]

                        # Call form of print: valid in both python 2 and 3
                        print("Step {0}. Loss = {1}, Acc Validation={2}".format(step, validation_loss, accuracy))
            finally:
                # Runs even if a training step raised, so the summary file is
                # closed and the coordinator is told to stop.
                train_writer.close()
                thread_pool.request_stop()

            # now they should definitely stop
            thread_pool.join(threads)
            self.architecture.save(hdf5)
            # Drop the handle; presumably this is what flushes and closes the
            # HDF5 file -- TODO confirm against bob.io.base.
            del hdf5
Tiago de Freitas Pereira's avatar
Tiago de Freitas Pereira committed
157