Trainer.py 4.78 KB
Newer Older
Tiago de Freitas Pereira's avatar
Tiago de Freitas Pereira committed
1
2
3
4
5
6
7
8
9
10
#!/usr/bin/env python
# vim: set fileencoding=utf-8 :
# @author: Tiago de Freitas Pereira <tiago.pereira@idiap.ch>
# @date: Tue 09 Aug 2016 15:25:22 CEST

import logging
logger = logging.getLogger("bob.learn.tensorflow")
from ..DataShuffler import DataShuffler
import tensorflow as tf
from ..network import SequenceNetwork
Tiago de Freitas Pereira's avatar
Tiago de Freitas Pereira committed
11
12
import numpy
from bob.learn.tensorflow.layers import InputLayer
Tiago de Freitas Pereira's avatar
Tiago de Freitas Pereira committed
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53


class Trainer(object):
    """
    Trains a `SequenceNetwork` with gradient descent and periodically reports
    the loss and accuracy measured on a validation batch.

    **Parameters**

      architecture: The network to train. `train` requires it to be an
        instance of `bob.learn.tensorflow.network.SequenceNetwork`.
      use_gpu: Stored on the instance; not read by the code in this class.
      loss: Callable invoked as `loss(graph, labels)`; its output is averaged
        with `tf.reduce_mean` to obtain the scalar loss.
      convergence_threshold: Stored on the instance; not read by the code in
        this class.
      iterations: Number of training steps executed by `train`.
      base_lr: Initial learning rate handed to `tf.train.exponential_decay`.
      momentum: Stored on the instance; not read by the code in this class.
      weight_decay: Handed to `tf.train.exponential_decay` as its decay rate.
      snapshot: The validation batch is evaluated every `snapshot` steps.
    """

    def __init__(self,

                 architecture=None,
                 use_gpu=False,
                 loss=None,

                 ###### training options ##########
                 convergence_threshold=0.01,
                 iterations=5000,
                 base_lr=0.00001,
                 momentum=0.9,
                 weight_decay=0.0005,

                 # The learning rate policy
                 snapshot=100):

        self.loss = loss
        self.loss_instance = None
        self.optimizer = None

        self.architecture = architecture
        self.use_gpu = use_gpu

        self.iterations = iterations
        self.snapshot = snapshot
        self.base_lr = base_lr
        self.momentum = momentum
        self.weight_decay = weight_decay
        self.convergence_threshold = convergence_threshold

    def train(self, data_shuffler):
        """
        Do the loop forward --> backward --|
                      ^--------------------|

        Builds the train and validation graphs from `self.architecture`, runs
        `self.iterations` gradient-descent steps and, every `self.snapshot`
        steps, prints the loss and accuracy obtained on one validation batch.

        **Parameters**

          data_shuffler: Object providing `get_placeholders`, `get_batch`,
            `train_batch_size` and `train_data` (only `train_data.shape[0]`
            is read).

        **Raises**

          ValueError: If `self.architecture` is not a `SequenceNetwork`.
        """

        # Fail before any placeholder is requested when the network is invalid
        if not isinstance(self.architecture, SequenceNetwork):
            raise ValueError("The variable `architecture` must be an instance of "
                             "`bob.learn.tensorflow.network.SequenceNetwork`")

        train_placeholder_data, train_placeholder_labels = data_shuffler.get_placeholders(name="train")
        validation_placeholder_data, validation_placeholder_labels = data_shuffler.get_placeholders(
            name="validation",
            train_dataset=False)

        # Creating the architecture for train and validation
        train_graph = self.architecture.compute_graph(train_placeholder_data)
        validation_graph = self.architecture.compute_graph(validation_placeholder_data)

        loss_train = tf.reduce_mean(self.loss(train_graph, train_placeholder_labels))
        loss_validation = tf.reduce_mean(self.loss(validation_graph, validation_placeholder_labels))

        # `batch` is the global step counter, incremented by `minimize`
        batch = tf.Variable(0)
        # NOTE(review): `weight_decay` is used as the decay *rate* of the
        # learning rate schedule, which looks unintended -- confirm. Kept as
        # is to preserve the training behaviour.
        learning_rate = tf.train.exponential_decay(
            self.base_lr,  # Initial learning rate
            batch * data_shuffler.train_batch_size,  # Global step
            data_shuffler.train_data.shape[0],  # Decay steps
            self.weight_decay  # Decay rate
        )
        optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss_train,
                                                                              global_step=batch)

        train_prediction = tf.nn.softmax(train_graph)
        validation_prediction = tf.nn.softmax(validation_graph)

        print("Initializing !!")
        # Training
        with tf.Session() as session:

            train_writer = tf.train.SummaryWriter('./LOGS/train',
                                                  session.graph)

            # The writer must be closed even if a training step fails
            try:
                tf.initialize_all_variables().run()
                for step in range(self.iterations):

                    train_data, train_labels = data_shuffler.get_batch()

                    feed_dict = {train_placeholder_data: train_data,
                                 train_placeholder_labels: train_labels}

                    # Same fetches as before; the fetched values are not used
                    session.run([optimizer, loss_train,
                                 learning_rate, train_prediction], feed_dict=feed_dict)

                    if step % self.snapshot == 0:
                        validation_data, validation_labels = data_shuffler.get_batch(train_dataset=False)
                        feed_dict = {validation_placeholder_data: validation_data,
                                     validation_placeholder_labels: validation_labels}

                        validation_loss, predictions = session.run([loss_validation, validation_prediction],
                                                                   feed_dict=feed_dict)
                        # Percentage of samples whose arg-max prediction equals the label
                        accuracy = 100. * numpy.sum(
                            numpy.argmax(predictions, 1) == validation_labels) / predictions.shape[0]

                        print("Step {0}. Loss = {1}, Acc Validation={2}".format(step, validation_loss, accuracy))
            finally:
                train_writer.close()