#!/usr/bin/env python
# vim: set fileencoding=utf-8 :
# @author: Tiago de Freitas Pereira <tiago.pereira@idiap.ch>
# @date: Tue 09 Aug 2016 15:25:22 CEST

import logging
logger = logging.getLogger("bob.learn.tensorflow")
from ..DataShuffler import DataShuffler
import tensorflow as tf
from ..network import SequenceNetwork
import numpy
from bob.learn.tensorflow.layers import InputLayer


class Trainer(object):
    """Train a ``SequenceNetwork`` with plain gradient descent in TensorFlow.

    Keyword arguments:

    * ``architecture`` -- network to train; ``train`` requires an instance of
      ``bob.learn.tensorflow.network.SequenceNetwork``.
    * ``use_gpu`` -- stored on the instance; not read by ``train``.
    * ``loss`` -- callable ``loss(graph, labels)`` returning a tensor; ``train``
      averages its output with ``tf.reduce_mean``.
    * ``convergence_threshold`` -- stored on the instance; not read by ``train``.
    * ``iterations`` -- number of training steps (one batch per step).
    * ``base_lr`` -- initial learning rate of the exponential decay schedule.
    * ``momentum`` -- stored on the instance; not read by ``train`` (the
      optimizer used there is plain gradient descent).
    * ``weight_decay`` -- passed to ``tf.train.exponential_decay`` as its
      decay rate.
    * ``snapshot`` -- validation is run every ``snapshot`` steps.
    """

    def __init__(self,
                 architecture=None,
                 use_gpu=False,
                 loss=None,

                 # Training options
                 convergence_threshold=0.01,
                 iterations=5000,
                 base_lr=0.00001,
                 momentum=0.9,
                 weight_decay=0.0005,

                 # Validation frequency (in iterations)
                 snapshot=100):

        self.loss = loss
        # Not assigned anywhere else in this class.
        self.loss_instance = None
        self.optimizer = None

        self.architecture = architecture
        self.use_gpu = use_gpu

        self.iterations = iterations
        self.snapshot = snapshot
        self.base_lr = base_lr
        self.momentum = momentum
        self.weight_decay = weight_decay
        self.convergence_threshold = convergence_threshold

    def train(self, data_shuffler):
        """
        Do the loop forward --> backward --|
                      ^--------------------|

        Builds train and validation graphs from ``self.architecture`` and
        minimizes the mean training loss with gradient descent for
        ``self.iterations`` steps. Every ``self.snapshot`` steps it prints the
        loss and accuracy on one validation batch.

        :param data_shuffler: object providing ``get_placeholders(name=...,
            train_dataset=...)``, ``get_batch(train_dataset=...)``,
            ``train_batch_size`` and ``train_data``.
        :raises ValueError: if ``self.architecture`` is not a ``SequenceNetwork``.
        """
        # Validate before touching the data shuffler, so a misconfigured
        # trainer fails without creating any placeholders.
        if not isinstance(self.architecture, SequenceNetwork):
            raise ValueError("The variable `architecture` must be an instance of "
                             "`bob.learn.tensorflow.network.SequenceNetwork`")

        train_placeholder_data, train_placeholder_labels = data_shuffler.get_placeholders(name="train")
        validation_placeholder_data, validation_placeholder_labels = data_shuffler.get_placeholders(
            name="validation", train_dataset=False)

        # Creating the architecture for train and validation
        train_graph = self.architecture.compute_graph(train_placeholder_data)
        validation_graph = self.architecture.compute_graph(validation_placeholder_data)

        loss_train = tf.reduce_mean(self.loss(train_graph, train_placeholder_labels))
        loss_validation = tf.reduce_mean(self.loss(validation_graph, validation_placeholder_labels))

        # `batch` is the global step; `minimize(..., global_step=batch)` below
        # increments it once per update.
        batch = tf.Variable(0)
        learning_rate = tf.train.exponential_decay(
            self.base_lr,  # Initial learning rate
            batch * data_shuffler.train_batch_size,  # Samples seen so far
            data_shuffler.train_data.shape[0],  # Decay steps (presumably the training-set size)
            self.weight_decay  # Decay rate
            # NOTE(review): `weight_decay` is used as the learning-rate decay
            # *rate* here, not as an L2 penalty -- confirm this is intended.
        )
        optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss_train,
                                                                              global_step=batch)
        validation_prediction = tf.nn.softmax(validation_graph)

        print("Initializing !!")
        # Training
        with tf.Session() as session:
            tf.initialize_all_variables().run()
            for step in range(self.iterations):
                train_data, train_labels = data_shuffler.get_batch()
                feed_dict = {train_placeholder_data: train_data,
                             train_placeholder_labels: train_labels}

                # Only the update op is needed here; nothing else from the
                # training step is reported.
                session.run(optimizer, feed_dict=feed_dict)

                if step % self.snapshot == 0:
                    validation_data, validation_labels = data_shuffler.get_batch(train_dataset=False)
                    feed_dict = {validation_placeholder_data: validation_data,
                                 validation_placeholder_labels: validation_labels}

                    validation_loss, predictions = session.run([loss_validation, validation_prediction],
                                                               feed_dict=feed_dict)
                    # Percentage of rows whose arg-max column equals the label
                    # (assumes labels are integer class indices -- confirm).
                    hits = numpy.sum(numpy.argmax(predictions, 1) == validation_labels)
                    accuracy = 100. * hits / predictions.shape[0]

                    print("Step {0}. Loss = {1}, Acc Validation={2}".format(step, validation_loss, accuracy))