Trainer.py 4.15 KB
Newer Older
Tiago de Freitas Pereira's avatar
Tiago de Freitas Pereira committed
1
2
3
4
5
6
7
8
9
#!/usr/bin/env python
# vim: set fileencoding=utf-8 :
# @author: Tiago de Freitas Pereira <tiago.pereira@idiap.ch>
# @date: Tue 09 Aug 2016 15:25:22 CEST

import logging
logger = logging.getLogger("bob.learn.tensorflow")
import tensorflow as tf
from ..network import SequenceNetwork
Tiago de Freitas Pereira's avatar
Tiago de Freitas Pereira committed
10
import numpy
Tiago de Freitas Pereira's avatar
Tiago de Freitas Pereira committed
11
12
13
14
15
16
17
18
19
20
21
22
23


class Trainer(object):
    """Train a ``SequenceNetwork`` with gradient descent and an exponentially
    decaying learning rate, periodically reporting validation loss/accuracy.

    **Parameters**

      architecture: The network to train; :py:meth:`train` requires an
        instance of ``bob.learn.tensorflow.network.SequenceNetwork``.
      use_gpu: Stored on the instance; not read by :py:meth:`train`.
      loss: Callable invoked as ``loss(graph, labels)``; its result is the
        tensor minimized during training.
      convergence_threshold: Stored on the instance; not read by
        :py:meth:`train`.
      iterations: Number of training steps (one mini-batch per step).
      base_lr: Initial learning rate of the exponential-decay schedule.
      momentum: Stored on the instance. NOTE(review): currently ignored,
        because :py:meth:`train` uses plain ``GradientDescentOptimizer``.
      weight_decay: Used as the *decay rate* of the exponential
        learning-rate schedule (not as an L2 weight penalty, despite the name).
      snapshot: Validation is run and progress printed every ``snapshot``
        steps (including step 0).
    """

    def __init__(self,

                 architecture=None,
                 use_gpu=False,
                 loss=None,

                 ###### training options ##########
                 convergence_threshold=0.01,
                 iterations=5000,
                 base_lr=0.001,
                 momentum=0.9,
                 weight_decay=0.95,

                 # Validation / reporting frequency, in training steps
                 snapshot=100):

        self.loss = loss
        self.loss_instance = None
        self.optimizer = None

        self.architecture = architecture
        self.use_gpu = use_gpu

        self.iterations = iterations
        self.snapshot = snapshot
        self.base_lr = base_lr
        self.momentum = momentum
        self.weight_decay = weight_decay
        self.convergence_threshold = convergence_threshold

    def train(self, data_shuffler):
        """
        Do the loop forward --> backward --|
                      ^--------------------|

        Builds a training and a validation graph from ``self.architecture``,
        minimizes ``self.loss`` on the training graph for ``self.iterations``
        steps, and every ``self.snapshot`` steps prints the loss and the
        accuracy (in percent) on one validation batch. A TensorFlow summary
        writer is opened on ``./LOGS/train`` for the session graph.

        **Parameters**

          data_shuffler: Object providing ``get_placeholders(name=..., [train_dataset=...])``,
            ``get_batch([train_dataset=...])``, ``train_batch_size`` and
            ``train_data`` (an array whose ``shape[0]`` is used as the decay
            steps of the learning-rate schedule).

        **Raises**

          ValueError: if ``self.architecture`` is not a ``SequenceNetwork``.
        """
        # Validate before touching the TensorFlow graph, so a bad configuration
        # fails fast without leaving half-built placeholders in the graph.
        if not isinstance(self.architecture, SequenceNetwork):
            raise ValueError("The variable `architecture` must be an instance of "
                             "`bob.learn.tensorflow.network.SequenceNetwork`")

        train_placeholder_data, train_placeholder_labels = \
            data_shuffler.get_placeholders(name="train")
        validation_placeholder_data, validation_placeholder_labels = \
            data_shuffler.get_placeholders(name="validation", train_dataset=False)

        # Creating the architecture for train and validation
        train_graph = self.architecture.compute_graph(train_placeholder_data)
        validation_graph = self.architecture.compute_graph(validation_placeholder_data)

        loss_train = self.loss(train_graph, train_placeholder_labels)
        loss_validation = self.loss(validation_graph, validation_placeholder_labels)

        # Global step counter, incremented by ``minimize`` once per update.
        # ``trainable=False`` keeps it out of the optimizer's variable list.
        batch = tf.Variable(0, trainable=False)
        learning_rate = tf.train.exponential_decay(
            self.base_lr,                            # Initial learning rate
            batch * data_shuffler.train_batch_size,  # Global step: samples processed so far
            data_shuffler.train_data.shape[0],       # Decay steps (presumably #training samples, i.e. one epoch)
            self.weight_decay                        # Decay rate
        )
        # NOTE(review): ``self.momentum`` is not applied here; switching to
        # tf.train.MomentumOptimizer would change the training dynamics.
        optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(
            loss_train, global_step=batch)

        validation_prediction = tf.nn.softmax(validation_graph)

        print("Initializing !!")
        # Training
        with tf.Session() as session:
            train_writer = tf.train.SummaryWriter('./LOGS/train', session.graph)
            try:
                tf.initialize_all_variables().run()
                for step in range(self.iterations):
                    train_data, train_labels = data_shuffler.get_batch()
                    feed_dict = {train_placeholder_data: train_data,
                                 train_placeholder_labels: train_labels}
                    # Only the update op is needed; nothing else from the
                    # training step is consumed.
                    session.run(optimizer, feed_dict=feed_dict)

                    if step % self.snapshot == 0:
                        validation_data, validation_labels = \
                            data_shuffler.get_batch(train_dataset=False)
                        feed_dict = {validation_placeholder_data: validation_data,
                                     validation_placeholder_labels: validation_labels}

                        validation_loss, predictions = session.run(
                            [loss_validation, validation_prediction], feed_dict=feed_dict)
                        # Percentage of samples whose arg-max class equals the label;
                        # assumes validation_labels is a 1-D vector of integer
                        # class ids — TODO confirm against the data shuffler.
                        hits = numpy.sum(numpy.argmax(predictions, 1) == validation_labels)
                        accuracy = 100. * hits / predictions.shape[0]

                        print("Step {0}. Loss = {1}, Acc Validation={2}".format(
                            step, validation_loss, accuracy))
            finally:
                # Flush and release the event file even if training fails.
                train_writer.close()