#!/usr/bin/env python
# vim: set fileencoding=utf-8 :
# @author: Tiago de Freitas Pereira <tiago.pereira@idiap.ch>
# @date: Tue 09 Aug 2016 15:25:22 CEST

import logging
logger = logging.getLogger("bob.learn.tensorflow")
import tensorflow as tf
from ..analyzers import Analizer
from ..network import SequenceNetwork
import bob.io.base
import os


class SiameseTrainer(object):
    """Trains a siamese network with a pairwise loss.

    The same ``architecture`` is evaluated on a left and a right input
    placeholder; the provided ``loss`` compares the two embeddings given the
    pair labels.  Training state (model snapshots and TensorBoard summaries)
    is written under ``temp_dir``.
    """

    def __init__(self,

                 architecture=None,
                 use_gpu=False,
                 loss=None,
                 temp_dir="",
                 save_intermediate=False,

                 ###### training options ##########
                 convergence_threshold = 0.01,
                 iterations=5000,
                 base_lr=0.001,
                 momentum=0.9,
                 weight_decay=0.95,

                 # The learning rate policy
                 snapshot=100):
        """Builds the trainer.

        Parameters:

          architecture: A ``bob.learn.tensorflow.network.SequenceNetwork``
            shared by both branches of the siamese pair.
          use_gpu: Flag stored for GPU usage (not read inside this class).
          loss: Callable ``loss(labels, left_graph, right_graph)`` returning
            ``(total_loss, within_class, between_class)``.
          temp_dir: Directory where the model file, logs and intermediate
            output are written.
          save_intermediate: If ``True``, also save the model at every
            snapshot interval (under an incrementing group in the HDF5 file).
          convergence_threshold: Stored for convergence checks
            (not read inside this class).
          iterations: Total number of training steps.
          base_lr: Initial learning rate for the exponential decay schedule.
          momentum: Momentum used by the ``MomentumOptimizer``.
          weight_decay: Decay rate passed to ``tf.train.exponential_decay``.
          snapshot: Interval (in steps) between analysis/snapshot points.
        """

        self.loss = loss
        self.loss_instance = None
        self.optimizer = None

        self.temp_dir = temp_dir
        self.save_intermediate = save_intermediate

        self.architecture = architecture
        self.use_gpu = use_gpu

        self.iterations = iterations
        self.snapshot = snapshot
        self.base_lr = base_lr
        self.momentum = momentum
        self.weight_decay = weight_decay
        self.convergence_threshold = convergence_threshold

    def train(self, data_shuffler):
        """
        Do the loop forward --> backward --|
                      ^--------------------|

        Parameters:

          data_shuffler: Provider of training pairs; must implement
            ``get_placeholders(name=...)`` and ``get_pair()`` and expose
            ``train_batch_size`` and ``train_data``.
        """

        bob.io.base.create_directories_safe(os.path.join(self.temp_dir, 'OUTPUT'))

        # One data placeholder per siamese branch; labels are shared.
        train_placeholder_left_data, train_placeholder_labels = data_shuffler.get_placeholders(name="train_left")
        train_placeholder_right_data, _ = data_shuffler.get_placeholders(name="train_right")

        # Creating the architecture for train and validation
        if not isinstance(self.architecture, SequenceNetwork):
            raise ValueError("The variable `architecture` must be an instance of "
                             "`bob.learn.tensorflow.network.SequenceNetwork`")

        # Both branches share the same architecture (and thus the same weights).
        train_left_graph = self.architecture.compute_graph(train_placeholder_left_data)
        train_right_graph = self.architecture.compute_graph(train_placeholder_right_data)

        loss_train, within_class, between_class = self.loss(train_placeholder_labels,
                                                            train_left_graph,
                                                            train_right_graph)

        # `batch` doubles as the global step; the schedule decays the base
        # learning rate as more samples are consumed.
        batch = tf.Variable(0)
        learning_rate = tf.train.exponential_decay(
            self.base_lr,  # Learning rate
            batch * data_shuffler.train_batch_size,
            data_shuffler.train_data.shape[0],
            self.weight_decay  # Decay step
        )

        # BUGFIX: previously hard-coded momentum=0.99, silently ignoring the
        # `momentum` constructor argument.
        optimizer = tf.train.MomentumOptimizer(learning_rate, momentum=self.momentum, use_locking=False,
                                               name='Momentum').minimize(loss_train, global_step=batch)

        print("Initializing !!")
        # Training
        hdf5 = bob.io.base.HDF5File(os.path.join(self.temp_dir, 'model.hdf5'), 'w')

        with tf.Session() as session:
            analizer = Analizer(data_shuffler, self.architecture, session)

            train_writer = tf.train.SummaryWriter(os.path.join(self.temp_dir, 'LOGS'), session.graph)

            # Tensorboard data
            tf.scalar_summary('loss', loss_train)
            tf.scalar_summary('between_class', between_class)
            tf.scalar_summary('within_class', within_class)
            tf.scalar_summary('lr', learning_rate)
            merged = tf.merge_all_summaries()

            tf.initialize_all_variables().run()
            for step in range(self.iterations):

                batch_left, batch_right, labels = data_shuffler.get_pair()

                feed_dict = {train_placeholder_left_data: batch_left,
                             train_placeholder_right_data: batch_right,
                             train_placeholder_labels: labels}

                _, l, lr, summary = session.run([optimizer, loss_train, learning_rate, merged], feed_dict=feed_dict)
                train_writer.add_summary(summary, step)

                # Periodically run the analyzer (EER computation) and
                # optionally checkpoint the model.
                if step % self.snapshot == 0:
                    analizer()
                    if self.save_intermediate:
                        self.architecture.save(hdf5, step)
                    # BUGFIX: was a Python 2 `print` statement, inconsistent
                    # with the print() call above and invalid under Python 3.
                    print(str(step) + " - " + str(analizer.eer[-1]))

            # Final model save; deleting the HDF5File handle closes the file.
            self.architecture.save(hdf5)
            del hdf5

            train_writer.close()