#!/usr/bin/env python
# vim: set fileencoding=utf-8 :
# @author: Tiago de Freitas Pereira <tiago.pereira@idiap.ch>
# @date: Tue 09 Aug 2016 15:25:22 CEST

import logging
import os
import threading

import numpy
import tensorflow as tf
from tensorflow.core.framework import summary_pb2

import bob.core
import bob.io.base

from ..analyzers import SoftmaxAnalizer
from ..network import SequenceNetwork

logger = bob.core.log.setup("bob.learn.tensorflow")
class Trainer(object):
    """
    Trains a :py:class:`bob.learn.tensorflow.network.SequenceNetwork` with a
    simple forward/backward loop, optionally prefetching batches through a
    :py:class:`tf.FIFOQueue` fed by a background thread, and logging
    train/validation losses to TensorBoard summaries.
    """

    def __init__(self,
                 architecture,
                 optimizer=None,
                 use_gpu=False,
                 loss=None,
                 temp_dir="cnn",

                 # Learning rate
                 base_learning_rate=0.001,
                 weight_decay=0.9,

                 ###### training options ##########
                 convergence_threshold=0.01,
                 iterations=5000,
                 snapshot=100,
                 prefetch=False,

                 ## Analizer
                 analizer=None,

                 verbosity_level=2):
        """
        **Parameters**
          architecture: The architecture that you want to run. Should be a :py:class`bob.learn.tensorflow.network.SequenceNetwork`
          optimizer: One of the tensorflow optimizers https://www.tensorflow.org/versions/r0.10/api_docs/python/train.html
                     (defaults to a fresh ``tf.train.AdamOptimizer``)
          use_gpu: Use GPUs in the training
          loss: Loss functor; called as ``loss(network_output, labels)`` and must return the loss graph
          temp_dir: Directory where summaries and the final ``model.hdf5`` are written
          base_learning_rate: Initial learning rate for the exponential decay schedule
          weight_decay: Decay rate passed to ``tf.train.exponential_decay``
          convergence_threshold: Stored for convergence checks (not used by this loop)
          iterations: Number of training steps
          snapshot: Validation is run every ``snapshot`` steps
          prefetch: If True, feed the graph through a FIFO queue filled by a background thread
          analizer: Callable run on the validation set (defaults to a fresh ``SoftmaxAnalizer``)
          verbosity_level: Verbosity passed to the bob logger
        """
        if not isinstance(architecture, SequenceNetwork):
            raise ValueError("`architecture` should be instance of `SequenceNetwork`")

        self.architecture = architecture
        # NOTE: the optimizer/analizer defaults are instantiated here rather
        # than in the signature so each Trainer owns its own instance instead
        # of every Trainer sharing one object created at class-definition time.
        self.optimizer_class = tf.train.AdamOptimizer() if optimizer is None else optimizer
        self.use_gpu = use_gpu
        self.loss = loss
        self.temp_dir = temp_dir

        self.base_learning_rate = base_learning_rate
        self.weight_decay = weight_decay

        self.iterations = iterations
        self.snapshot = snapshot
        self.convergence_threshold = convergence_threshold
        self.prefetch = prefetch

        # Training variables used in the fit
        self.optimizer = None
        self.training_graph = None
        self.learning_rate = None
        self.train_data_shuffler = None
        self.summaries_train = None
        self.train_summary_writter = None

        # Validation data
        self.validation_graph = None
        self.validation_summary_writter = None

        # Analizer
        self.analizer = SoftmaxAnalizer() if analizer is None else analizer

        # Prefetch machinery (set lazily by train()/compute_graph())
        self.thread_pool = None
        self.enqueue_op = None

        bob.core.log.set_verbosity_level(logger, verbosity_level)

    def compute_graph(self, data_shuffler, name=""):
        """
        Computes the loss graph for the trainer.

        ** Parameters **

            data_shuffler: Data shuffler providing placeholders (and batch size)
            name: Name of the graph

        Returns the loss tensor produced by ``self.loss``.
        """

        # Defining place holders
        if self.prefetch:
            placeholder_data, placeholder_labels = data_shuffler.get_placeholders_forprefetch(name=name)

            # Defining a placeholder queue for prefetching; shapes drop the
            # batch dimension since enqueue_many/dequeue_many handle batching.
            queue = tf.FIFOQueue(capacity=10,
                                 dtypes=[tf.float32, tf.int64],
                                 shapes=[placeholder_data.get_shape().as_list()[1:], []])

            # Fetching the place holders from the queue
            self.enqueue_op = queue.enqueue_many([placeholder_data, placeholder_labels])
            feature_batch, label_batch = queue.dequeue_many(data_shuffler.batch_size)

            if not isinstance(self.architecture, SequenceNetwork):
                raise ValueError("The variable `architecture` must be an instance of "
                                 "`bob.learn.tensorflow.network.SequenceNetwork`")
        else:
            feature_batch, label_batch = data_shuffler.get_placeholders(name=name)

        # Creating graphs and defining the loss
        network_graph = self.architecture.compute_graph(feature_batch)
        graph = self.loss(network_graph, label_batch)

        return graph

    def get_feed_dict(self, data_shuffler):
        """
        Computes the feed_dict for the graph.

        ** Parameters **

            data_shuffler: Data shuffler providing ``get_batch`` and ``get_placeholders``
        """
        data, labels = data_shuffler.get_batch()
        data_placeholder, label_placeholder = data_shuffler.get_placeholders()

        feed_dict = {data_placeholder: data,
                     label_placeholder: labels}
        return feed_dict

    def __fit(self, session, step):
        # One optimization step; with prefetch the queue feeds the graph, so
        # no feed_dict is needed.
        if self.prefetch:
            _, l, lr, summary = session.run([self.optimizer, self.training_graph,
                                             self.learning_rate, self.summaries_train])
        else:
            feed_dict = self.get_feed_dict(self.train_data_shuffler)
            _, l, lr, summary = session.run([self.optimizer, self.training_graph,
                                             self.learning_rate, self.summaries_train], feed_dict=feed_dict)

        logger.info("Loss training set step={0} = {1}".format(step, l))
        self.train_summary_writter.add_summary(summary, step)

    def __compute_validation(self, session, data_shuffler, step):
        # Runs the validation loss and logs it as a manually-built summary.
        if self.validation_summary_writter is None:
            self.validation_summary_writter = tf.train.SummaryWriter(os.path.join(self.temp_dir, 'validation'), session.graph)

        # Build the validation graph only once; rebuilding it on every
        # snapshot would keep adding duplicate ops to the TF graph.
        if self.validation_graph is None:
            self.validation_graph = self.compute_graph(data_shuffler, name="validation")
        feed_dict = self.get_feed_dict(data_shuffler)
        l = session.run(self.validation_graph, feed_dict=feed_dict)

        summaries = [summary_pb2.Summary.Value(tag="loss", simple_value=float(l))]
        self.validation_summary_writter.add_summary(summary_pb2.Summary(value=summaries), step)
        logger.info("Loss VALIDATION set step={0} = {1}".format(step, l))

    def __create_general_summary(self):
        # Train summary: scalar loss and learning rate, merged for one fetch.
        tf.scalar_summary('loss', self.training_graph, name="train")
        tf.scalar_summary('lr', self.learning_rate, name="train")
        return tf.merge_all_summaries()

    def start_thread(self, session):
        """
        Starts the daemon thread(s) that pump batches into the prefetch queue.

        ** Parameters **

            session: TF session the enqueue op runs in
        """
        threads = []
        for _ in range(1):
            t = threading.Thread(target=self.load_and_enqueue, args=(session, ))
            t.daemon = True  # thread will close when parent quits
            t.start()
            threads.append(t)
        return threads

    def load_and_enqueue(self, session):
        """
        Injecting data in the place holder queue; loops until the coordinator
        (``self.thread_pool``) requests a stop.
        """
        while not self.thread_pool.should_stop():
            train_data, train_labels = self.train_data_shuffler.get_batch()
            train_placeholder_data, train_placeholder_labels = self.train_data_shuffler.get_placeholders()

            feed_dict = {train_placeholder_data: train_data,
                         train_placeholder_labels: train_labels}

            session.run(self.enqueue_op, feed_dict=feed_dict)

    def train(self, train_data_shuffler, validation_data_shuffler=None):
        """
        Do the loop forward --> backward --|
                      ^--------------------|

        ** Parameters **

            train_data_shuffler: Data shuffler for the training set
            validation_data_shuffler: Optional shuffler; validation runs every ``snapshot`` steps
        """

        # Creating directory
        bob.io.base.create_directories_safe(self.temp_dir)
        self.train_data_shuffler = train_data_shuffler

        # TODO: find an elegant way to provide this as a parameter of the trainer
        self.learning_rate = tf.train.exponential_decay(
            self.base_learning_rate,  # Learning rate
            train_data_shuffler.batch_size,
            train_data_shuffler.n_samples,
            self.weight_decay  # Decay step
        )

        self.training_graph = self.compute_graph(train_data_shuffler, name="train")

        # Preparing the optimizer
        # NOTE(review): pokes a private attribute; TF 0.x optimizers expose no
        # public way to swap the learning rate after construction.
        self.optimizer_class._learning_rate = self.learning_rate
        self.optimizer = self.optimizer_class.minimize(self.training_graph, global_step=tf.Variable(0))

        # Train summary
        self.summaries_train = self.__create_general_summary()

        logger.info("Initializing !!")
        # Training
        hdf5 = bob.io.base.HDF5File(os.path.join(self.temp_dir, 'model.hdf5'), 'w')

        with tf.Session() as session:

            tf.initialize_all_variables().run()

            # Start a thread to enqueue data asynchronously, and hide I/O latency.
            if self.prefetch:
                self.thread_pool = tf.train.Coordinator()
                tf.train.start_queue_runners(coord=self.thread_pool)
                threads = self.start_thread(session)

            # TENSOR BOARD SUMMARY
            self.train_summary_writter = tf.train.SummaryWriter(os.path.join(self.temp_dir, 'train'), session.graph)

            for step in range(self.iterations):
                self.__fit(session, step)
                if validation_data_shuffler is not None and step % self.snapshot == 0:
                    self.__compute_validation(session, validation_data_shuffler, step)

                    if self.analizer is not None:
                        self.validation_summary_writter.add_summary(self.analizer(
                            validation_data_shuffler, self.architecture, session), step)

            logger.info("Training finally finished")

            self.train_summary_writter.close()
            if validation_data_shuffler is not None:
                self.validation_summary_writter.close()

            self.architecture.save(hdf5)
            del hdf5

            if self.prefetch:
                # now they should definetely stop
                self.thread_pool.request_stop()
                self.thread_pool.join(threads)