#!/usr/bin/env python
# vim: set fileencoding=utf-8 :
# @author: Tiago de Freitas Pereira <tiago.pereira@idiap.ch>
# @date: Tue 09 Aug 2016 15:25:22 CEST

import tensorflow as tf
from ..network import SequenceNetwork

import threading
import os
import bob.io.base
import bob.core
from ..analyzers import SoftmaxAnalizer
from tensorflow.core.framework import summary_pb2
import time
from bob.learn.tensorflow.datashuffler.OnlineSampling import OnLineSampling
from bob.learn.tensorflow.utils.session import Session
from .learning_rate import constant

logger = bob.core.log.setup("bob.learn.tensorflow")


class Trainer(object):
    """
    One graph trainer.
    Use this trainer when your CNN is composed of a single graph.

    **Parameters**
      architecture: The architecture to train. Must be an instance of :py:class:`bob.learn.tensorflow.network.SequenceNetwork`
      optimizer: One of the tensorflow optimizers (https://www.tensorflow.org/versions/r0.10/api_docs/python/train.html)
      use_gpu: Use GPUs during training
      loss: Loss function applied to the output of the network
      temp_dir: The output directory for snapshots and summaries

      learning_rate: Learning rate policy (defaults to a constant learning rate)
      convergence_threshold: Threshold used to decide convergence

      iterations: Maximum number of iterations
      snapshot: Take a snapshot of the network every `n` iterations
      validation_snapshot: Compute the validation loss every `n` iterations
      prefetch: Use extra threads to deal with the I/O
      analizer: Neural network analyzer from :py:mod:`bob.learn.tensorflow.analyzers`
      model_from_file: Path of a pretrained model to resume training from
      verbosity_level: Verbosity level of the logger

    """
    def __init__(self,
                 architecture,
                 optimizer=tf.train.AdamOptimizer(),
                 use_gpu=False,
                 loss=None,
                 temp_dir="cnn",

                 # Learning rate
                 learning_rate=constant(),

                 ###### training options ##########
                 convergence_threshold=0.01,
                 iterations=5000,
                 snapshot=500,
                 validation_snapshot=100,
                 prefetch=False,

                 ## Analizer
                 analizer=SoftmaxAnalizer(),

                 ### Pretrained model
                 model_from_file="",

                 verbosity_level=2):

        if not isinstance(architecture, SequenceNetwork):
            raise ValueError("`architecture` should be an instance of `SequenceNetwork`")

        self.architecture = architecture
        self.optimizer_class = optimizer
        self.use_gpu = use_gpu
        self.loss = loss
        self.temp_dir = temp_dir

        self.learning_rate = learning_rate

        self.iterations = iterations
        self.snapshot = snapshot
        self.validation_snapshot = validation_snapshot
        self.convergence_threshold = convergence_threshold
        self.prefetch = prefetch

        # Training variables used in the fit
        self.optimizer = None
        self.training_graph = None
        self.train_data_shuffler = None
        self.summaries_train = None
        self.train_summary_writter = None
        self.thread_pool = None

        # Validation data
        self.validation_graph = None
        self.validation_summary_writter = None

        # Analizer
        self.analizer = analizer

        self.enqueue_op = None
        self.global_step = None

        self.model_from_file = model_from_file
        self.session = None

        bob.core.log.set_verbosity_level(logger, verbosity_level)

    def __del__(self):
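        # Clear the default TensorFlow graph when the trainer is destroyed,
        # so consecutive Trainer instances do not accumulate ops.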
        tf.reset_default_graph()

    def compute_graph(self, data_shuffler, prefetch=False, name="", training=True):
        """
        Computes the graph for the trainer.

        ** Parameters **

            data_shuffler: Data shuffler
            prefetch: Use the prefetch queue
            name: Name of the graph
            training: Is it a training graph?
        """

        # Defining placeholders
        if prefetch:
            [placeholder_data, placeholder_labels] = data_shuffler.get_placeholders_forprefetch(name=name)

            # Defining a placeholder queue for prefetching
            queue = tf.FIFOQueue(capacity=10,
                                 dtypes=[tf.float32, tf.int64],
                                 shapes=[placeholder_data.get_shape().as_list()[1:], []])

            # Enqueueing the placeholders and dequeuing mini-batches from the queue
            self.enqueue_op = queue.enqueue_many([placeholder_data, placeholder_labels])
            feature_batch, label_batch = queue.dequeue_many(data_shuffler.batch_size)

            # Creating the architecture for train and validation
            if not isinstance(self.architecture, SequenceNetwork):
                raise ValueError("The variable `architecture` must be an instance of "
                                 "`bob.learn.tensorflow.network.SequenceNetwork`")
        else:
            [feature_batch, label_batch] = data_shuffler.get_placeholders(name=name)

        # Creating graphs and defining the loss
        network_graph = self.architecture.compute_graph(feature_batch, training=training)
        graph = self.loss(network_graph, label_batch)

        return graph

    def get_feed_dict(self, data_shuffler):
        """
        Given a data shuffler, prepares the feed dictionary to be injected into the graph

        ** Parameters **
            data_shuffler: Data shuffler used to fetch the batch
        """
        [data, labels] = data_shuffler.get_batch()
        [data_placeholder, label_placeholder] = data_shuffler.get_placeholders()

        feed_dict = {data_placeholder: data,
                     label_placeholder: labels}
        return feed_dict

    def fit(self, step):
        """
        Run one iteration (`forward` and `backward`)

        ** Parameters **
            step: Iteration number

        """

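        # With prefetching enabled the batch is dequeued inside the graph itself,
        # so no feed_dict is needed; otherwise the batch is fed in explicitly.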
        if self.prefetch:
            _, l, lr, summary = self.session.run([self.optimizer, self.training_graph,
                                                  self.learning_rate, self.summaries_train])
        else:
            feed_dict = self.get_feed_dict(self.train_data_shuffler)
            _, l, lr, summary = self.session.run([self.optimizer, self.training_graph,
                                                  self.learning_rate, self.summaries_train], feed_dict=feed_dict)

        logger.info("Loss training set step={0} = {1}".format(step, l))
        self.train_summary_writter.add_summary(summary, step)

    def compute_validation(self, data_shuffler, step):
        """
        Computes the loss on the validation set

        ** Parameters **
            data_shuffler: The data shuffler to be used
            step: Iteration number

        """
        # Reusing the training session to run the validation graph
        feed_dict = self.get_feed_dict(data_shuffler)
        l = self.session.run(self.validation_graph, feed_dict=feed_dict)

        if self.validation_summary_writter is None:
            self.validation_summary_writter = tf.train.SummaryWriter(os.path.join(self.temp_dir, 'validation'), self.session.graph)

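        # The validation loss is written as a manually-built Summary protobuf,
        # since the validation graph has no merged summary op attached to it.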
        summaries = [summary_pb2.Summary.Value(tag="loss", simple_value=float(l))]
        self.validation_summary_writter.add_summary(summary_pb2.Summary(value=summaries), step)
        logger.info("Loss VALIDATION set step={0} = {1}".format(step, l))

    def create_general_summary(self):
        """
        Creates a simple tensorboard summary with the value of the loss and learning rate
        """
        # Train summary
        tf.scalar_summary('loss', self.training_graph, name="train")
        tf.scalar_summary('lr', self.learning_rate, name="train")
        return tf.merge_all_summaries()

    def start_thread(self):
        """
        Start a pool of threads for pre-fetching
        """

        threads = []
        for n in range(3):
            t = threading.Thread(target=self.load_and_enqueue, args=())
            t.daemon = True  # thread will close when parent quits
            t.start()
            threads.append(t)
        return threads

    def load_and_enqueue(self):
        """
        Injects data from the training data shuffler into the prefetch queue
        """

        while not self.thread_pool.should_stop():
            [train_data, train_labels] = self.train_data_shuffler.get_batch()
            [train_placeholder_data, train_placeholder_labels] = self.train_data_shuffler.get_placeholders()

            feed_dict = {train_placeholder_data: train_data,
                         train_placeholder_labels: train_labels}

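            # enqueue_many blocks when the queue is at capacity, which naturally
            # throttles the producer threads relative to the training loop.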
            self.session.run(self.enqueue_op, feed_dict=feed_dict)

    def bootstrap_graphs(self, train_data_shuffler, validation_data_shuffler):
        """
        Create all the necessary graphs for training, validation and inference
        """

        # Creating train graph
        self.training_graph = self.compute_graph(train_data_shuffler, prefetch=self.prefetch, name="train")
        tf.add_to_collection("training_graph", self.training_graph)

        # Creating inference graph
        self.architecture.compute_inference_placeholder(train_data_shuffler.deployment_shape)
        self.architecture.compute_inference_graph()
        tf.add_to_collection("inference_placeholder", self.architecture.inference_placeholder)
        tf.add_to_collection("inference_graph", self.architecture.inference_graph)

        # Creating validation graph
        if validation_data_shuffler is not None:
            self.validation_graph = self.compute_graph(validation_data_shuffler, name="validation", training=False)
            tf.add_to_collection("validation_graph", self.validation_graph)

        self.bootstrap_placeholders(train_data_shuffler, validation_data_shuffler)

    def bootstrap_placeholders(self, train_data_shuffler, validation_data_shuffler):
        """
        Persist the placeholders

         ** Parameters **
           train_data_shuffler: Data shuffler for training
           validation_data_shuffler: Data shuffler for validation

        """

        # Persisting the placeholders
        if self.prefetch:
            batch, label = train_data_shuffler.get_placeholders_forprefetch()
        else:
            batch, label = train_data_shuffler.get_placeholders()

        tf.add_to_collection("train_placeholder_data", batch)
        tf.add_to_collection("train_placeholder_label", label)

        # Persisting the validation placeholders
        if validation_data_shuffler is not None:
            batch, label = validation_data_shuffler.get_placeholders()
            tf.add_to_collection("validation_placeholder_data", batch)
            tf.add_to_collection("validation_placeholder_label", label)

    def bootstrap_graphs_fromfile(self, train_data_shuffler, validation_data_shuffler):
        """
        Bootstrap all the necessary data from file

         ** Parameters **
           train_data_shuffler: Data shuffler for training
           validation_data_shuffler: Data shuffler for validation

        """
        saver = self.architecture.load(self.session, self.model_from_file)

        # Loading training graph
        self.training_graph = tf.get_collection("training_graph")[0]

        # Loading other elements
        self.optimizer = tf.get_collection("optimizer")[0]
        self.learning_rate = tf.get_collection("learning_rate")[0]
        self.summaries_train = tf.get_collection("summaries_train")[0]

        if validation_data_shuffler is not None:
            self.validation_graph = tf.get_collection("validation_graph")[0]

        self.bootstrap_placeholders_fromfile(train_data_shuffler, validation_data_shuffler)

        return saver

    def bootstrap_placeholders_fromfile(self, train_data_shuffler, validation_data_shuffler):
        """
        Load placeholders from file

         ** Parameters **

           train_data_shuffler: Data shuffler for training
           validation_data_shuffler: Data shuffler for validation

        """

        train_data_shuffler.set_placeholders(tf.get_collection("train_placeholder_data")[0],
                                             tf.get_collection("train_placeholder_label")[0])

        if validation_data_shuffler is not None:
            validation_data_shuffler.set_placeholders(tf.get_collection("validation_placeholder_data")[0],
                                                      tf.get_collection("validation_placeholder_label")[0])

    def train(self, train_data_shuffler, validation_data_shuffler=None):
        """
        Train the network:

         ** Parameters **

           train_data_shuffler: Data shuffler for training
           validation_data_shuffler: Data shuffler for validation
        """

        # Creating directory
        bob.io.base.create_directories_safe(self.temp_dir)
        self.train_data_shuffler = train_data_shuffler

        logger.info("Initializing !!")

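        # GPU options: cap memory usage at roughly one third of the device and
        # allow the allocation to grow on demand instead of grabbing it upfront.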
        config = tf.ConfigProto(log_device_placement=True,
                                gpu_options=tf.GPUOptions(per_process_gpu_memory_fraction=0.333))
        config.gpu_options.allow_growth = True

        # Pickle the architecture to save
        self.architecture.pickle_net(train_data_shuffler.deployment_shape)

        #with tf.Session(config=config) as session:

        self.session = Session.instance().session

        # Loading a pretrained model
        if self.model_from_file != "":
            logger.info("Loading pretrained model from {0}".format(self.model_from_file))
            saver = self.bootstrap_graphs_fromfile(train_data_shuffler, validation_data_shuffler)
        else:
            # Bootstrapping all the graphs
            self.bootstrap_graphs(train_data_shuffler, validation_data_shuffler)

            # TODO: find an elegant way to provide this as a parameter of the trainer
            self.global_step = tf.Variable(0, trainable=False)

            # Preparing the optimizer
            self.optimizer_class._learning_rate = self.learning_rate
            self.optimizer = self.optimizer_class.minimize(self.training_graph, global_step=self.global_step)
            tf.add_to_collection("optimizer", self.optimizer)
            tf.add_to_collection("learning_rate", self.learning_rate)

            # Train summary
            self.summaries_train = self.create_general_summary()
            tf.add_to_collection("summaries_train", self.summaries_train)

            tf.initialize_all_variables().run(session=self.session)

            # Original tensorflow saver object
            saver = tf.train.Saver(var_list=tf.all_variables())

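        # Online-sampling shufflers use the network being trained as a feature
        # extractor when selecting the next samples.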
        if isinstance(train_data_shuffler, OnLineSampling):
            train_data_shuffler.set_feature_extractor(self.architecture, session=self.session)

        # Start a thread to enqueue data asynchronously, and hide I/O latency.
        if self.prefetch:
            self.thread_pool = tf.train.Coordinator()
            tf.train.start_queue_runners(coord=self.thread_pool)
            threads = self.start_thread()

        # TENSOR BOARD SUMMARY
        self.train_summary_writter = tf.train.SummaryWriter(os.path.join(self.temp_dir, 'train'), self.session.graph)
        for step in range(self.iterations):

            start = time.time()
            self.fit(step)
            end = time.time()
            summary = summary_pb2.Summary.Value(tag="elapsed_time", simple_value=float(end-start))
            self.train_summary_writter.add_summary(summary_pb2.Summary(value=[summary]), step)

            # Running validation
            if validation_data_shuffler is not None and step % self.validation_snapshot == 0:
                self.compute_validation(validation_data_shuffler, step)

                if self.analizer is not None:
                    self.validation_summary_writter.add_summary(self.analizer(
                         validation_data_shuffler, self.architecture, self.session), step)

            # Taking snapshot
            if step % self.snapshot == 0:
                logger.info("Taking snapshot")
                path = os.path.join(self.temp_dir, 'model_snapshot{0}.ckp'.format(step))
                self.architecture.save(saver, path)

        logger.info("Training finally finished")

        self.train_summary_writter.close()
        if validation_data_shuffler is not None:
            self.validation_summary_writter.close()

        # Saving the final network
        path = os.path.join(self.temp_dir, 'model.ckp')
        self.architecture.save(saver, path)

        if self.prefetch:
            # now they should definitely stop
            self.thread_pool.request_stop()
            self.thread_pool.join(threads)