#!/usr/bin/env python
# vim: set fileencoding=utf-8 :
# @author: Tiago de Freitas Pereira <tiago.pereira@idiap.ch>
# @date: Tue 09 Aug 2016 15:25:22 CEST

import logging
logger = logging.getLogger("bob.learn.tensorflow")
from ..DataShuffler import DataShuffler
import tensorflow as tf
from ..network import SequenceNetwork
import numpy
from bob.learn.tensorflow.layers import InputLayer


class Trainer(object):
    """Gradient-descent trainer for a `SequenceNetwork` architecture.

    The trainer builds a training graph and a validation graph from the same
    architecture, minimises the mean training loss with
    `tf.train.GradientDescentOptimizer`, and prints the validation loss and
    accuracy every `snapshot` iterations.
    """

    def __init__(self,

                 architecture=None,
                 use_gpu=False,
                 loss=None,

                 ###### training options ##########
                 convergence_threshold=0.01,
                 iterations=5000,
                 base_lr=0.00001,
                 momentum=0.9,
                 weight_decay=0.0005,

                 # The learning rate policy
                 snapshot=100):
        """Store the training configuration; no graph is built here.

        Parameters
        ----------
        architecture
            Network to train. `train` requires it to be an instance of
            `bob.learn.tensorflow.network.SequenceNetwork`.
        use_gpu
            Stored on the instance; not read by `train`.
        loss
            Callable invoked as ``loss(graph, labels)``; `train` takes the
            mean of its result with `tf.reduce_mean`.
        convergence_threshold
            Stored on the instance; not read by `train`.
        iterations
            Number of training steps run by `train`.
        base_lr
            Initial learning rate handed to `tf.train.exponential_decay`.
        momentum
            Stored on the instance; not read by `train`.
        weight_decay
            Passed as the fourth positional argument (the decay rate) of
            `tf.train.exponential_decay`.
        snapshot
            Validation is evaluated and printed every `snapshot` steps.
        """

        self.loss = loss
        self.loss_instance = None
        self.optimizer = None

        self.architecture = architecture
        self.use_gpu = use_gpu

        self.iterations = iterations
        self.snapshot = snapshot
        self.base_lr = base_lr
        self.momentum = momentum
        self.weight_decay = weight_decay
        self.convergence_threshold = convergence_threshold

    def train(self, data_shuffler):
        """
        Do the loop forward --> backward --|
                      ^--------------------|

        Parameters
        ----------
        data_shuffler
            Object providing ``get_placeholders``, ``get_batch``,
            ``train_batch_size`` and ``train_data`` (whose ``shape[0]`` is
            used as the decay step count).

        Raises
        ------
        ValueError
            If `architecture` is not a `SequenceNetwork`.
        """

        train_placeholder_data, train_placeholder_labels = data_shuffler.get_placeholders(name="train")
        validation_placeholder_data, validation_placeholder_labels = data_shuffler.get_placeholders(
            name="validation",
            train_dataset=False)

        # Creating the architecture for train and validation
        if not isinstance(self.architecture, SequenceNetwork):
            raise ValueError("The variable `architecture` must be an instance of "
                             "`bob.learn.tensorflow.network.SequenceNetwork`")

        train_graph = self.architecture.compute_graph(train_placeholder_data)
        validation_graph = self.architecture.compute_graph(validation_placeholder_data)

        loss_train = tf.reduce_mean(self.loss(train_graph, train_placeholder_labels))
        loss_validation = tf.reduce_mean(self.loss(validation_graph, validation_placeholder_labels))

        # `batch` counts optimizer steps; it is incremented by `minimize` below.
        batch = tf.Variable(0)
        learning_rate = tf.train.exponential_decay(
            self.base_lr,  # Initial learning rate
            batch * data_shuffler.train_batch_size,  # Global step, in samples seen
            data_shuffler.train_data.shape[0],  # Decay steps
            # NOTE(review): this positional slot is `decay_rate`, so the
            # learning rate is scaled by `weight_decay` per decay period.
            # Confirm that reusing `weight_decay` here is intended.
            self.weight_decay  # Decay rate
        )
        optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss_train,
                                                                              global_step=batch)

        train_prediction = tf.nn.softmax(train_graph)
        validation_prediction = tf.nn.softmax(validation_graph)

        print("Initializing !!")
        # Training
        with tf.Session() as session:

            train_writer = tf.train.SummaryWriter('./LOGS/train',
                                                  session.graph)

            # Close the writer even when a training step raises.
            try:
                tf.initialize_all_variables().run()
                for step in range(self.iterations):

                    train_data, train_labels = data_shuffler.get_batch()

                    feed_dict = {train_placeholder_data: train_data,
                                 train_placeholder_labels: train_labels}

                    # Same fetch list as before; only the optimizer's side
                    # effect is needed, so the fetched values are not bound.
                    session.run([optimizer, loss_train,
                                 learning_rate, train_prediction], feed_dict=feed_dict)

                    if step % self.snapshot == 0:
                        validation_data, validation_labels = data_shuffler.get_batch(train_dataset=False)
                        feed_dict = {validation_placeholder_data: validation_data,
                                     validation_placeholder_labels: validation_labels}

                        l, predictions = session.run([loss_validation, validation_prediction],
                                                     feed_dict=feed_dict)
                        # Percentage of samples whose arg-max prediction equals the label.
                        accuracy = 100. * numpy.sum(numpy.argmax(predictions, 1) == validation_labels) / predictions.shape[0]

                        # Function-call form works on both Python 2 and Python 3.
                        print("Step {0}. Loss = {1}, Acc Validation={2}".format(step, l, accuracy))
            finally:
                train_writer.close()