#!/usr/bin/env python
# vim: set fileencoding=utf-8 :
# @author: Tiago de Freitas Pereira <tiago.pereira@idiap.ch>
# @date: Tue 09 Aug 2016 15:25:22 CEST

import os
import threading
import time

import tensorflow as tf
from tensorflow.core.framework import summary_pb2

import bob.core
import bob.io.base

# Project imports kept in their original relative order.
from ..network import SequenceNetwork
from ..analyzers import SoftmaxAnalizer
from bob.learn.tensorflow.datashuffler.OnlineSampling import OnLineSampling
from .learning_rate import constant

logger = bob.core.log.setup("bob.learn.tensorflow")


class Trainer(object):
    """
    One graph trainer.
    Use this trainer when your CNN is composed by one graph

    **Parameters**
      architecture: The architecture that you want to run. Should be a :py:class:`bob.learn.tensorflow.network.SequenceNetwork`
      optimizer: One of the tensorflow optimizers https://www.tensorflow.org/versions/r0.10/api_docs/python/train.html
        (a new ``tf.train.AdamOptimizer()`` is created for each trainer when omitted)
      use_gpu: Use GPUs in the training (stored, but not read by this class)
      loss: Callable ``loss(network_output, labels)`` returning the loss tensor to minimize
      temp_dir: The output directory (snapshots, final model and tensorboard logs)

      learning_rate: Learning rate tensor, e.g. from the ``learning_rate`` module
        (a new ``constant()`` is created for each trainer when omitted)
      convergence_threshold: (stored, but not read by this class)

      iterations: Maximum number of iterations
      snapshot: Will take a snapshot of the network at every `n` iterations
      validation_snapshot: Will compute the validation loss at every `n` iterations
      prefetch: Use extra Threads to deal with the I/O
      analizer: Neural network analizer :py:mod:`bob.learn.tensorflow.analyzers`
        (a new ``SoftmaxAnalizer()`` when omitted; ``None`` disables it)
      model_from_file: HDF5 file of a pretrained model to load before training ("" for none)
      verbosity_level: Verbosity level of the ``bob.learn.tensorflow`` logger

    """

    # Sentinel for ``analizer``: distinguishes "argument omitted" (build a
    # fresh SoftmaxAnalizer) from an explicit ``None`` (no analyzer at all).
    _DEFAULT_ANALIZER = object()

    def __init__(self,
                 architecture,
                 optimizer=None,
                 use_gpu=False,
                 loss=None,
                 temp_dir="cnn",

                 # Learning rate
                 learning_rate=None,

                 ###### training options ##########
                 convergence_threshold=0.01,
                 iterations=5000,
                 snapshot=500,
                 validation_snapshot=100,
                 prefetch=False,

                 ## Analizer
                 analizer=_DEFAULT_ANALIZER,

                 ### Pretrained model
                 model_from_file="",

                 verbosity_level=2):

        if not isinstance(architecture, SequenceNetwork):
            raise ValueError("`architecture` should be instance of `SequenceNetwork`")

        # Defaults are built per instance. A default object written in the
        # signature is created once and shared by every Trainer (and `train`
        # mutates the optimizer). TensorFlow objects created at import time
        # would also belong to a graph that `__del__` may reset.
        if optimizer is None:
            optimizer = tf.train.AdamOptimizer()
        if learning_rate is None:
            learning_rate = constant()
        if analizer is Trainer._DEFAULT_ANALIZER:
            analizer = SoftmaxAnalizer()

        self.architecture = architecture
        self.optimizer_class = optimizer
        self.use_gpu = use_gpu
        self.loss = loss
        self.temp_dir = temp_dir

        self.learning_rate = learning_rate

        self.iterations = iterations
        self.snapshot = snapshot
        self.validation_snapshot = validation_snapshot
        self.convergence_threshold = convergence_threshold
        self.prefetch = prefetch

        # Training variables used in the fit
        self.optimizer = None
        self.training_graph = None
        self.train_data_shuffler = None
        self.summaries_train = None
        self.train_summary_writter = None

        # Validation data
        self.validation_graph = None
        self.validation_summary_writter = None
        # Shuffler `validation_graph` was built from (see compute_validation)
        self._validation_data_shuffler = None

        # Analizer
        self.analizer = analizer

        # Prefetching machinery (only set up when `prefetch` is True)
        self.thread_pool = None
        self.enqueue_op = None
        self.queue_close_op = None

        self.global_step = None

        self.model_from_file = model_from_file

        bob.core.log.set_verbosity_level(logger, verbosity_level)

    def __del__(self):
        # NOTE: resets the process-wide default TensorFlow graph, not only the
        # ops created by this trainer.
        tf.reset_default_graph()

    def compute_graph(self, data_shuffler, prefetch=False, name=""):
        """
        Computes the graph for the trainer.

        ** Parameters **

            data_shuffler: Data shuffler
            prefetch: If True, inputs are read from a FIFO queue fed by background
                      threads (see `load_and_enqueue`). The enqueue and close ops
                      are stored in `self.enqueue_op` and `self.queue_close_op`.
            name: Name of the graph

        ** Returns **

            The loss tensor ``self.loss(network_output, labels)``
        """

        if prefetch:
            [placeholder_data, placeholder_labels] = data_shuffler.get_placeholders_forprefetch(name=name)

            # Queue of single samples: feeder threads push whole batches with
            # `enqueue_many` and the graph pops `batch_size` samples at a time.
            queue = tf.FIFOQueue(capacity=10,
                                 dtypes=[tf.float32, tf.int64],
                                 shapes=[placeholder_data.get_shape().as_list()[1:], []])

            self.enqueue_op = queue.enqueue_many([placeholder_data, placeholder_labels])
            # Running this op unblocks feeder threads stuck on a full queue
            # (their pending enqueues fail with CancelledError).
            self.queue_close_op = queue.close(cancel_pending_enqueues=True)
            feature_batch, label_batch = queue.dequeue_many(data_shuffler.batch_size)
        else:
            [feature_batch, label_batch] = data_shuffler.get_placeholders(name=name)

        # Creating graphs and defining the loss
        network_graph = self.architecture.compute_graph(feature_batch)
        return self.loss(network_graph, label_batch)

    def get_feed_dict(self, data_shuffler):
        """
        Given a data shuffler, prepares the dictionary to be injected in the graph

        ** Parameters **
            data_shuffler: Data shuffler providing `get_batch` and `get_placeholders`

        ** Returns **
            A dict mapping the shuffler's data/label placeholders to one batch
        """
        [data, labels] = data_shuffler.get_batch()
        [data_placeholder, label_placeholder] = data_shuffler.get_placeholders()

        return {data_placeholder: data,
                label_placeholder: labels}

    def fit(self, session, step):
        """
        Run one iteration (`forward` and `backward`), log the training loss and
        write the training summary for `step`.

        ** Parameters **
            session: Tensorflow session
            step: Iteration number

        """

        if self.prefetch:
            # Inputs come from the prefetch queue: nothing to feed.
            feed_dict = None
        else:
            feed_dict = self.get_feed_dict(self.train_data_shuffler)

        _, l, summary = session.run([self.optimizer, self.training_graph, self.summaries_train],
                                    feed_dict=feed_dict)

        logger.info("Loss training set step={0} = {1}".format(step, l))
        self.train_summary_writter.add_summary(summary, step)

    def compute_validation(self, session, data_shuffler, step):
        """
        Computes the loss in the validation set and writes it to the
        ``validation`` tensorboard log.

        ** Parameters **
            session: Tensorflow session
            data_shuffler: The data shuffler to be used
            step: Iteration number

        """
        # Build the validation graph once per shuffler. Rebuilding it on every
        # call would add a new copy of the network ops at each validation step.
        if self.validation_graph is None or self._validation_data_shuffler is not data_shuffler:
            self.validation_graph = self.compute_graph(data_shuffler, name="validation")
            self._validation_data_shuffler = data_shuffler

        feed_dict = self.get_feed_dict(data_shuffler)
        l = session.run(self.validation_graph, feed_dict=feed_dict)

        if self.validation_summary_writter is None:
            self.validation_summary_writter = tf.train.SummaryWriter(os.path.join(self.temp_dir, 'validation'), session.graph)

        summaries = [summary_pb2.Summary.Value(tag="loss", simple_value=float(l))]
        self.validation_summary_writter.add_summary(summary_pb2.Summary(value=summaries), step)
        logger.info("Loss VALIDATION set step={0} = {1}".format(step, l))

    def create_general_summary(self):
        """
        Creates a simple tensorboard summary with the value of the loss and learning rate
        """

        # Train summary
        tf.scalar_summary('loss', self.training_graph, name="train")
        tf.scalar_summary('lr', self.learning_rate, name="train")
        return tf.merge_all_summaries()

    def start_thread(self, session, n_threads=3):
        """
        Start pool of threads for pre-fetching

        **Parameters**
          session: Tensorflow session
          n_threads: Number of feeder threads (default 3)

        **Returns**
          The list of started (daemon) threads
        """

        threads = []
        for _ in range(n_threads):
            t = threading.Thread(target=self.load_and_enqueue, args=(session,))
            t.daemon = True  # thread will close when parent quits
            t.start()
            threads.append(t)
        return threads

    def load_and_enqueue(self, session):
        """
        Injecting data in the place holder queue until the coordinator asks to
        stop or the queue is closed.

        **Parameters**
          session: Tensorflow session
        """

        while not self.thread_pool.should_stop():
            [train_data, train_labels] = self.train_data_shuffler.get_batch()
            [train_placeholder_data, train_placeholder_labels] = self.train_data_shuffler.get_placeholders()

            feed_dict = {train_placeholder_data: train_data,
                         train_placeholder_labels: train_labels}

            try:
                session.run(self.enqueue_op, feed_dict=feed_dict)
            except tf.errors.CancelledError:
                # The queue (or the session) was closed: training is over.
                break

    def _save_model(self, path):
        """Writes the network variables to the HDF5 file `path`."""
        hdf5 = bob.io.base.HDF5File(path, 'w')
        self.architecture.save(hdf5)
        del hdf5

    def _stop_threads(self, session, threads):
        """Stops and joins the pre-fetching threads started by `start_thread`."""
        self.thread_pool.request_stop()
        # Feeder threads are usually blocked on a full queue at this point and
        # would never see `should_stop()`. Closing the queue makes their pending
        # enqueue fail with CancelledError, which ends `load_and_enqueue`.
        if getattr(self, "queue_close_op", None) is not None:
            session.run(self.queue_close_op)
        self.thread_pool.join(threads)

    def train(self, train_data_shuffler, validation_data_shuffler=None):
        """
        Train the network

        ** Parameters **
            train_data_shuffler: Data shuffler with the training data
            validation_data_shuffler: Optional data shuffler. When given, the
                validation loss (and the analizer output, if any) is computed
                every `validation_snapshot` iterations.
        """

        # Creating directory
        bob.io.base.create_directories_safe(self.temp_dir)
        self.train_data_shuffler = train_data_shuffler

        # Pickle the architecture to save
        self.architecture.pickle_net(train_data_shuffler.deployment_shape)

        # TODO: find an elegant way to provide this as a parameter of the trainer
        self.global_step = tf.Variable(0, trainable=False)
        self.training_graph = self.compute_graph(train_data_shuffler, prefetch=self.prefetch, name="train")

        # Preparing the optimizer. NOTE: this overwrites the optimizer's private
        # `_learning_rate` attribute with `self.learning_rate`.
        self.optimizer_class._learning_rate = self.learning_rate
        self.optimizer = self.optimizer_class.minimize(self.training_graph, global_step=self.global_step)

        # Train summary
        self.summaries_train = self.create_general_summary()

        # Forget validation state left over from a previous call to `train`
        # (that writer has been closed).
        self.validation_graph = None
        self._validation_data_shuffler = None
        self.validation_summary_writter = None

        logger.info("Initializing !!")

        config = tf.ConfigProto(log_device_placement=True)
        config.gpu_options.allow_growth = True

        with tf.Session(config=config) as session:
            tf.initialize_all_variables().run()

            # Loading a pretrained model
            if self.model_from_file:
                logger.info("Loading pretrained model from {0}".format(self.model_from_file))
                hdf5 = bob.io.base.HDF5File(self.model_from_file)
                self.architecture.load_variables_only(hdf5, session)
                del hdf5

            if isinstance(train_data_shuffler, OnLineSampling):
                train_data_shuffler.set_feature_extractor(self.architecture, session=session)

            # Start threads to enqueue data asynchronously, and hide I/O latency.
            threads = []
            if self.prefetch:
                self.thread_pool = tf.train.Coordinator()
                tf.train.start_queue_runners(coord=self.thread_pool)
                threads = self.start_thread(session)

            try:
                # TENSOR BOARD SUMMARY
                self.train_summary_writter = tf.train.SummaryWriter(os.path.join(self.temp_dir, 'train'), session.graph)
                for step in range(self.iterations):

                    start = time.time()
                    self.fit(session, step)
                    end = time.time()
                    summary = summary_pb2.Summary.Value(tag="elapsed_time", simple_value=float(end - start))
                    self.train_summary_writter.add_summary(summary_pb2.Summary(value=[summary]), step)

                    # Running validation
                    if validation_data_shuffler is not None and step % self.validation_snapshot == 0:
                        self.compute_validation(session, validation_data_shuffler, step)

                        if self.analizer is not None:
                            self.validation_summary_writter.add_summary(self.analizer(
                                validation_data_shuffler, self.architecture, session), step)

                    # Taking snapshot
                    if step % self.snapshot == 0:
                        logger.info("Taking snapshot")
                        self._save_model(os.path.join(self.temp_dir, 'model_snapshot{0}.hdf5'.format(step)))

                logger.info("Training finally finished")

                self.train_summary_writter.close()
                # The writer only exists if validation ran at least once
                if self.validation_summary_writter is not None:
                    self.validation_summary_writter.close()

                # Saving the final network
                self._save_model(os.path.join(self.temp_dir, 'model.hdf5'))
            finally:
                if self.prefetch:
                    # now they should definitely stop, even if training failed
                    self._stop_threads(session, threads)