#!/usr/bin/env python
# vim: set fileencoding=utf-8 :
# @author: Tiago de Freitas Pereira <tiago.pereira@idiap.ch>
# @date: Tue 09 Aug 2016 15:25:22 CEST

import logging
logger = logging.getLogger("bob.learn.tensorflow")
import tensorflow as tf
from ..network import SequenceNetwork
import numpy


class Trainer(object):
    """
    Train a `SequenceNetwork` with gradient descent and periodically report
    the loss and accuracy on a validation batch.

    **Parameters**

      architecture:
        The network to train. `train` raises `ValueError` unless it is an
        instance of `bob.learn.tensorflow.network.SequenceNetwork`.

      use_gpu:
        Stored on the instance; not read by `train`.

      loss:
        Callable invoked as `loss(graph, labels)`; its result is the tensor
        that is minimized (train) or evaluated (validation).

      convergence_threshold:
        Stored on the instance; not read by `train`.

      iterations:
        Number of training steps to run.

      base_lr:
        Initial learning rate given to `tf.train.exponential_decay`.

      momentum:
        Stored on the instance; not read by `train` (the optimizer used is
        `tf.train.GradientDescentOptimizer`).

      weight_decay:
        Passed as the `decay_rate` of `tf.train.exponential_decay`, i.e. it
        scales the learning rate, it is not a weight regularizer.

      snapshot:
        A validation batch is evaluated every `snapshot` steps.
    """

    def __init__(self,

                 architecture=None,
                 use_gpu=False,
                 loss=None,

                 ###### training options ##########
                 convergence_threshold = 0.01,
                 iterations=5000,
                 base_lr=0.001,
                 momentum=0.9,
                 weight_decay=0.95,

                 # The learning rate policy
                 snapshot=100):

        self.loss = loss
        self.loss_instance = None
        self.optimizer = None

        self.architecture = architecture
        self.use_gpu = use_gpu

        self.iterations = iterations
        self.snapshot = snapshot
        self.base_lr = base_lr
        self.momentum = momentum
        self.weight_decay = weight_decay
        self.convergence_threshold = convergence_threshold

    @staticmethod
    def _accuracy(predictions, labels):
        """
        Return the percentage of rows of `predictions` whose arg-max (along
        axis 1) equals the corresponding entry of `labels`.

        `labels` is assumed to hold one integer class index per row of
        `predictions` -- TODO confirm against the data shuffler.
        """
        hits = numpy.sum(numpy.argmax(predictions, 1) == labels)
        return 100. * hits / predictions.shape[0]

    def train(self, data_shuffler):
        """
        Do the loop forward --> backward --|
                      ^--------------------|

        **Parameters**

          data_shuffler:
            Object providing `get_placeholders(name=..., train_dataset=...)`,
            `get_batch(train_dataset=...)`, `train_batch_size` and
            `train_data` (whose `shape[0]` is used as the decay steps).

        **Raises**

          ValueError: if `self.architecture` is not a `SequenceNetwork`.

        The session graph is written to `./LOGS/train` and the validation
        loss/accuracy is printed every `self.snapshot` steps.
        """

        # Fail fast, before any placeholder or graph node is created
        if not isinstance(self.architecture, SequenceNetwork):
            raise ValueError("The variable `architecture` must be an instance of "
                             "`bob.learn.tensorflow.network.SequenceNetwork`")

        train_placeholder_data, train_placeholder_labels = data_shuffler.get_placeholders(name="train")
        validation_placeholder_data, validation_placeholder_labels = data_shuffler.get_placeholders(name="validation",
                                                                                                    train_dataset=False)

        # Creating the architecture for train and validation
        train_graph = self.architecture.compute_graph(train_placeholder_data)
        validation_graph = self.architecture.compute_graph(validation_placeholder_data)

        loss_train = self.loss(train_graph, train_placeholder_labels)
        loss_validation = self.loss(validation_graph, validation_placeholder_labels)

        # Step counter, incremented by the optimizer through `global_step`
        batch = tf.Variable(0)
        learning_rate = tf.train.exponential_decay(
            self.base_lr,  # Initial learning rate
            batch * data_shuffler.train_batch_size,  # Global step
            data_shuffler.train_data.shape[0],  # Decay steps
            self.weight_decay  # Decay rate
        )
        optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss_train,
                                                                              global_step=batch)

        train_prediction = tf.nn.softmax(train_graph)
        validation_prediction = tf.nn.softmax(validation_graph)

        print("Initializing !!")
        # Training
        with tf.Session() as session:

            train_writer = tf.train.SummaryWriter('./LOGS/train',
                                                  session.graph)

            # Make sure the writer is closed even if a training step fails
            try:
                tf.initialize_all_variables().run()
                for step in range(self.iterations):

                    train_data, train_labels = data_shuffler.get_batch()

                    feed_dict = {train_placeholder_data: train_data,
                                 train_placeholder_labels: train_labels}

                    # Only the side effect (the optimizer update) matters here;
                    # the fetched values are not used
                    session.run([optimizer, loss_train,
                                 learning_rate, train_prediction], feed_dict=feed_dict)

                    if step % self.snapshot == 0:
                        validation_data, validation_labels = data_shuffler.get_batch(train_dataset=False)
                        feed_dict = {validation_placeholder_data: validation_data,
                                     validation_placeholder_labels: validation_labels}

                        l, predictions = session.run([loss_validation, validation_prediction], feed_dict=feed_dict)
                        accuracy = self._accuracy(predictions, validation_labels)

                        # Function-call form works on both Python 2 and 3
                        print("Step {0}. Loss = {1}, Acc Validation={2}".format(step, l, accuracy))
            finally:
                train_writer.close()