Trainer.py 18.7 KB
Newer Older
1 2 3 4 5 6
#!/usr/bin/env python
# vim: set fileencoding=utf-8 :
# @author: Tiago de Freitas Pereira <tiago.pereira@idiap.ch>
# @date: Tue 09 Aug 2016 15:25:22 CEST

import tensorflow as tf
7 8 9
import threading
import os
import bob.io.base
10
import bob.core
11
from ..analyzers import SoftmaxAnalizer
12
from tensorflow.core.framework import summary_pb2
13
import time
14
from bob.learn.tensorflow.datashuffler import OnlineSampling, TFRecord
15
from bob.learn.tensorflow.utils.session import Session
16
from bob.learn.tensorflow.utils import compute_embedding_accuracy
17
from .learning_rate import constant
18
import time
19

20 21 22 23 24
# Module-level logger used by the Trainer class below. Its verbosity is set in
# Trainer.__init__ through bob.core.log.set_verbosity_level.
# NOTE(review): the bob.core based setup on the next line is disabled in favour
# of the standard logging module -- confirm which logger name is intended.
#logger = bob.core.log.setup("bob.learn.tensorflow")

import logging
logger = logging.getLogger("bob.learn")

25

26 27 28 29 30 31
class Trainer(object):
    """
    One graph trainer.
    Use this trainer when your CNN is composed of a single graph.

    **Parameters**

    train_data_shuffler:
      The data shuffler used for batching data for training

    iterations:
      Maximum number of iterations

    snapshot:
      Will take a snapshot of the network at every `n` iterations

    validation_snapshot:
      Test with validation each `n` iterations

    analizer:
      Neural network analizer :py:mod:`bob.learn.tensorflow.analyzers`

    temp_dir: str
      The output directory

    verbosity_level:
      Verbosity level applied to the module logger

    """
Tiago de Freitas Pereira's avatar
Tiago de Freitas Pereira committed
54

55
    def __init__(self,
                 train_data_shuffler,
                 validation_data_shuffler=None,
                 validate_with_embeddings=False,

                 # Training options
                 iterations=5000,
                 snapshot=1000,
                 validation_snapshot=2000,
                 keep_checkpoint_every_n_hours=2,

                 # Analizer
                 analizer=SoftmaxAnalizer(),

                 # Temporary directory
                 temp_dir="cnn",

                 verbosity_level=2):
        """
        Store the training configuration and fetch a tensorflow session.

        Only the configuration is kept here. The graph related attributes are
        initialised to `None` and are filled in later by
        `create_network_from_scratch` or `create_network_from_file`.
        """

        # Data sources
        self.train_data_shuffler = train_data_shuffler
        self.validation_data_shuffler = validation_data_shuffler
        self.validate_with_embeddings = validate_with_embeddings

        # Output directory and training schedule
        self.temp_dir = temp_dir
        self.iterations = iterations
        self.snapshot = snapshot
        self.validation_snapshot = validation_snapshot
        self.keep_checkpoint_every_n_hours = keep_checkpoint_every_n_hours

        # Analizer
        self.analizer = analizer

        # Summaries and their writers (set while building the network / training)
        self.summaries_train = self.summaries_validation = None
        self.train_summary_writter = self.validation_summary_writter = None

        # Coordinator of the pre-fetching threads (set in `train`)
        self.thread_pool = None

        # Graphs, loss and predictors
        self.graph = self.validation_graph = None
        self.loss = None
        self.predictor = self.validation_predictor = None

        # Optimisation bits
        self.global_step = None
        self.optimizer_class = self.optimizer = None
        self.learning_rate = None

        # Placeholders for training and validation
        self.data_ph = self.label_ph = None
        self.validation_data_ph = self.validation_label_ph = None

        self.saver = None
        self.session = None

        bob.core.log.set_verbosity_level(logger, verbosity_level)

        # Creating the session
        self.session = Session.instance(new=True).session

        # Flipped to False by `create_network_from_file`
        self.from_scratch = True
127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173
        
    def train(self):
        """
        Train the network.

        Runs `self.iterations` calls to `fit`, starting at step 0 for a network
        created from scratch and at the stored global step otherwise. Along the
        way it writes the elapsed time of each step to the training summary,
        validates every `self.validation_snapshot` steps (when a validation
        data shuffler is set) and saves a snapshot every `self.snapshot` steps.
        After the loop a last validation is run and the final model is saved
        as `model.ckp` inside `self.temp_dir`.

        The summary writers are closed and the pre-fetching coordinator is
        asked to stop even when an iteration raises.
        """

        # Creating directories
        bob.io.base.create_directories_safe(self.temp_dir)
        logger.info("Initializing !!")

        # Resuming from the stored global step when a checkpoint was loaded
        if self.from_scratch:
            start_step = 0
        else:
            start_step = self.global_step.eval(session=self.session)

        # TODO: Put back the `set_feature_extractor` call for OnlineSampling
        # data shufflers as soon as possible

        # Start a thread to enqueue data asynchronously, and hide I/O latency.
        if self.train_data_shuffler.prefetch:
            self.thread_pool = tf.train.Coordinator()
            tf.train.start_queue_runners(coord=self.thread_pool, sess=self.session)
            # In case you have your own queue
            if not isinstance(self.train_data_shuffler, TFRecord):
                self.start_thread()

        has_validation = self.validation_data_shuffler is not None

        # Bootstrapping the summary writters
        self.train_summary_writter = tf.summary.FileWriter(os.path.join(self.temp_dir, 'train'),
                                                           self.session.graph)
        if has_validation:
            self.validation_summary_writter = tf.summary.FileWriter(os.path.join(self.temp_dir, 'validation'),
                                                                    self.session.graph)

        # Keeps `step` defined for the last validation when `self.iterations` is 0
        step = start_step

        try:
            for step in range(start_step, start_step + self.iterations):
                # Run fit in the graph
                start = time.time()
                self.fit(step)
                end = time.time()

                summary = summary_pb2.Summary.Value(tag="elapsed_time", simple_value=float(end-start))
                self.train_summary_writter.add_summary(summary_pb2.Summary(value=[summary]), step)

                # Running validation
                if has_validation and step % self.validation_snapshot == 0:
                    self._run_validation(step)

                # Taking snapshot
                if step % self.snapshot == 0:
                    logger.info("Taking snapshot")
                    path = os.path.join(self.temp_dir, 'model_snapshot{0}.ckp'.format(step))
                    self.saver.save(self.session, path, global_step=step)

            # Running validation for the last time
            if has_validation:
                self._run_validation(step)

            logger.info("Training finally finished")

            # Saving the final network
            path = os.path.join(self.temp_dir, 'model.ckp')
            self.saver.save(self.session, path)
        finally:
            self.train_summary_writter.close()
            if has_validation:
                self.validation_summary_writter.close()

            # The coordinator only exists when pre-fetching was started above;
            # checking it directly avoids calling `request_stop` on None for a
            # TFRecord shuffler whose `prefetch` flag is off.
            if self.thread_pool is not None:
                self.thread_pool.request_stop()

    def _run_validation(self, step):
        """
        Dispatch one validation pass to the embedding or to the loss based routine.
        """
        if self.validate_with_embeddings:
            self.compute_validation_embeddings(step)
        else:
            self.compute_validation(step)

Tiago Pereira's avatar
Tiago Pereira committed
209 210
    def create_network_from_scratch(self,
                                    graph,
                                    validation_graph=None,
                                    optimizer=None,
                                    loss=None,

                                    # Learning rate
                                    learning_rate=None,
                                    ):
        """
        Prepare all the tensorflow variables before training.

        **Parameters**

            graph: Input graph for training

            validation_graph: Graph used for validation. Required when the
                              trainer has a validation data shuffler

            optimizer: Solver. When `None`, a new `tf.train.AdamOptimizer` is
                       created for this call

            loss: Loss function, called as `loss(graph, labels)`

            learning_rate: Learning rate

        **Raises**

            ValueError: if `loss` or `learning_rate` is missing, or if
                        `validation_graph` is missing while a validation data
                        shuffler is set
        """

        # Fail early with a clear message instead of deep inside tensorflow
        if loss is None:
            raise ValueError("A loss function is required to create the network")
        if learning_rate is None:
            raise ValueError("A learning rate is required to create the network")
        if self.validation_data_shuffler is not None and validation_graph is None:
            raise ValueError("A validation graph is required when a validation data shuffler is set")

        # A default built in the signature would be created once, at definition
        # time, and then shared (and mutated below) by every trainer
        if optimizer is None:
            optimizer = tf.train.AdamOptimizer()

        # Getting the pointer to the placeholders
        self.data_ph = self.train_data_shuffler("data", from_queue=True)
        self.label_ph = self.train_data_shuffler("label", from_queue=True)

        self.graph = graph
        self.loss = loss

        # Attaching the loss in the graph
        self.predictor = self.loss(self.graph, self.label_ph)

        self.optimizer_class = optimizer
        self.learning_rate = learning_rate
        self.global_step = tf.contrib.framework.get_or_create_global_step()

        # Preparing the optimizer
        # NOTE(review): this overrides a private attribute of the optimizer; an
        # optimizer that keeps its rate under another name (e.g. `_lr`) would
        # silently ignore it -- confirm against the optimizers in use
        self.optimizer_class._learning_rate = self.learning_rate
        self.optimizer = self.optimizer_class.minimize(self.predictor, global_step=self.global_step)

        # Saving all the variables
        self.saver = tf.train.Saver(var_list=tf.global_variables() + tf.local_variables(),
                                    keep_checkpoint_every_n_hours=self.keep_checkpoint_every_n_hours)

        self.summaries_train = self.create_general_summary(self.predictor, self.graph, self.label_ph)

        # Saving some variables, so that `create_network_from_file` can find them
        tf.add_to_collection("global_step", self.global_step)

        if isinstance(self.graph, dict):
            tf.add_to_collection("graph", self.graph['logits'])
            tf.add_to_collection("prelogits", self.graph['prelogits'])
        else:
            tf.add_to_collection("graph", self.graph)

        tf.add_to_collection("predictor", self.predictor)

        tf.add_to_collection("data_ph", self.data_ph)
        tf.add_to_collection("label_ph", self.label_ph)

        tf.add_to_collection("optimizer", self.optimizer)
        tf.add_to_collection("learning_rate", self.learning_rate)

        tf.add_to_collection("summaries_train", self.summaries_train)

        # Same business with the validation
        if self.validation_data_shuffler is not None:
            self.validation_data_ph = self.validation_data_shuffler("data", from_queue=True)
            self.validation_label_ph = self.validation_data_shuffler("label", from_queue=True)

            self.validation_graph = validation_graph

            # With embeddings the graph output itself is evaluated; otherwise
            # the loss is attached to the validation graph
            if self.validate_with_embeddings:
                self.validation_predictor = self.validation_graph
            else:
                self.validation_predictor = self.loss(self.validation_graph, self.validation_label_ph)

            self.summaries_validation = self.create_general_summary(self.validation_predictor,
                                                                    self.validation_graph,
                                                                    self.validation_label_ph)

            tf.add_to_collection("validation_graph", self.validation_graph)
            tf.add_to_collection("validation_data_ph", self.validation_data_ph)
            tf.add_to_collection("validation_label_ph", self.validation_label_ph)

            tf.add_to_collection("validation_predictor", self.validation_predictor)
            # Registered exactly once
            tf.add_to_collection("summaries_validation", self.summaries_validation)

        # Creating the variables
        tf.local_variables_initializer().run(session=self.session)
        tf.global_variables_initializer().run(session=self.session)

302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319
    def load_checkpoint(self, file_name, clear_devices=True):
        """
        Import a meta graph and restore the latest checkpoint next to it.

        ** Parameters **

           file_name:
                Name of the metafile to be loaded.
                If a directory is passed, the meta file of the last checkpoint
                registered in that directory is used.

           clear_devices:
                Forwarded to `tf.train.import_meta_graph`.

        """
        if os.path.isdir(file_name):
            checkpoint_dir = file_name
            meta_file = tf.train.get_checkpoint_state(file_name).model_checkpoint_path + ".meta"
        else:
            checkpoint_dir = os.path.dirname(file_name)
            meta_file = file_name

        # In both cases the weights come from the latest checkpoint of the directory
        self.saver = tf.train.import_meta_graph(meta_file, clear_devices=clear_devices)
        self.saver.restore(self.session, tf.train.latest_checkpoint(checkpoint_dir))
320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341
            
    def load_variables_from_external_model(self, file_name, var_list):
        """
        Restore a subset of the current variables from another saved model.

        ** Parameters **

          file_name:
            Name of the tensorflow model to be loaded
          var_list:
            List of scopes; the global variables found under each scope are
            the ones restored. Must not be empty.

        """

        assert len(var_list)>0

        # Collecting the global variables that live under each requested scope
        selected = []
        for scope in var_list:
            selected.extend(tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=scope))

        tf.train.Saver(selected).restore(self.session, file_name)
342

343
    def create_network_from_file(self, file_name, clear_devices=True):
        """
        Bootstrap a graph from a checkpoint

         ** Parameters **

           file_name: Name of the checkpoint (meta file, or a directory
                      containing checkpoints)

           clear_devices: Forwarded to `load_checkpoint`
        """

        logger.info("Loading last checkpoint !!")
        # Forward the caller's choice instead of always forcing True
        self.load_checkpoint(file_name, clear_devices=clear_devices)

        # Loading training graph
        self.data_ph = tf.get_collection("data_ph")[0]
        self.label_ph = tf.get_collection("label_ph")[0]

        self.graph = tf.get_collection("graph")[0]
        self.predictor = tf.get_collection("predictor")[0]

        # Loading other elements
        self.optimizer = tf.get_collection("optimizer")[0]
        self.learning_rate = tf.get_collection("learning_rate")[0]
        self.summaries_train = tf.get_collection("summaries_train")[0]
        self.global_step = tf.get_collection("global_step")[0]

        # Makes `train` resume from the stored global step
        self.from_scratch = False

        # Loading the validation bits
        if self.validation_data_shuffler is not None:
            self.summaries_validation = tf.get_collection("summaries_validation")[0]

            self.validation_graph = tf.get_collection("validation_graph")[0]
            self.validation_data_ph = tf.get_collection("validation_data_ph")[0]

            # The rest of the class reads `validation_label_ph`; the old
            # `validation_label` name is kept as an alias for existing callers
            self.validation_label_ph = tf.get_collection("validation_label_ph")[0]
            self.validation_label = self.validation_label_ph

            self.validation_predictor = tf.get_collection("validation_predictor")[0]

Tiago Pereira's avatar
Tiago Pereira committed
380 381
    def __del__(self):
        """
        Reset the default tensorflow graph when the trainer is destroyed.
        """
        # NOTE(review): this resets the process-wide default graph, so it also
        # affects any other object still relying on it -- confirm this is intended
        tf.reset_default_graph()
382 383 384

    def get_feed_dict(self, data_shuffler):
        """
385
        Given a data shuffler prepared the dictionary to be injected in the graph
386 387

        ** Parameters **
388 389

            data_shuffler: Data shuffler :py:class:`bob.learn.tensorflow.datashuffler.Base`
390

391
        """
392
        [data, labels] = data_shuffler.get_batch()
393

Tiago Pereira's avatar
Tiago Pereira committed
394 395
        feed_dict = {self.data_ph: data,
                     self.label_ph: labels}
396 397
        return feed_dict

398
    def fit(self, step):
        """
        Run one iteration (`forward` and `backward`)

        The optimizer, the predictor, the learning rate and the training
        summaries are evaluated in a single `session.run`. The loss is logged
        and the summaries are written for the given step.

        ** Parameters **
            step: Iteration number

        """

        fetches = [self.optimizer, self.predictor,
                   self.learning_rate, self.summaries_train]

        if self.train_data_shuffler.prefetch:
            # Nothing to feed when pre-fetching is on
            outputs = self.session.run(fetches)
        else:
            outputs = self.session.run(fetches,
                                       feed_dict=self.get_feed_dict(self.train_data_shuffler))

        _, loss_value, _, summary = outputs

        logger.info("Loss training set step={0} = {1}".format(step, loss_value))
        self.train_summary_writter.add_summary(summary, step)
418

419
    def compute_validation(self, step):
        """
        Computes the loss in the validation set

        The validation predictor, the learning rate and the validation
        summaries are evaluated in a single `session.run`. The loss is logged
        and the summaries are written for the given step.

        ** Parameters **
            step: Iteration number

        """

        fetches = [self.validation_predictor,
                   self.learning_rate, self.summaries_validation]

        if self.validation_data_shuffler.prefetch:
            # Nothing to feed when pre-fetching is on
            outputs = self.session.run(fetches)
        else:
            outputs = self.session.run(fetches,
                                       feed_dict=self.get_feed_dict(self.validation_data_shuffler))

        loss_value, _, summary = outputs

        logger.info("Loss VALIDATION set step={0} = {1}".format(step, loss_value))
        self.validation_summary_writter.add_summary(summary, step)
Tiago Pereira's avatar
Tiago Pereira committed
441

442 443 444 445 446 447 448 449 450 451
    def compute_validation_embeddings(self, step):
        """
        Computes the accuracy in the validation set with embeddings

        The validation predictor and the validation labels are evaluated, the
        accuracy is computed with `compute_embedding_accuracy`, logged and
        written to the validation summary for the given step.

        ** Parameters **
            step: Iteration number

        """

        fetches = [self.validation_predictor, self.validation_label_ph]

        if self.validation_data_shuffler.prefetch:
            # Nothing to feed when pre-fetching is on
            embedding, labels = self.session.run(fetches)
        else:
            embedding, labels = self.session.run(
                fetches, feed_dict=self.get_feed_dict(self.validation_data_shuffler))

        accuracy = compute_embedding_accuracy(embedding, labels)

        accuracy_value = summary_pb2.Summary.Value(tag="accuracy", simple_value=accuracy)
        logger.info("VALIDATION Accuracy set step={0} = {1}".format(step, accuracy))
        self.validation_summary_writter.add_summary(summary_pb2.Summary(value=[accuracy_value]), step)


467
    def create_general_summary(self, average_loss, output, label):
        """
        Creates a simple tensorboard summary with the loss, the learning rate
        and the accuracy

        ** Parameters **
            average_loss: Tensor logged under the `loss` tag
            output: Network output, or a dictionary holding it under the `logits` key
            label: Labels compared with the argmax of the output

        """
        # Scalars for the loss and for the learning rate
        tf.summary.scalar('loss', average_loss)
        tf.summary.scalar('lr', self.learning_rate)

        # Computing accuracy
        logits = output['logits'] if isinstance(output, dict) else output
        predicted_label = tf.argmax(logits, 1)
        hits = tf.cast(tf.equal(predicted_label, label), tf.float32)
        tf.summary.scalar('accuracy', tf.reduce_mean(hits))

        # NOTE(review): `merge_all` gathers every summary registered so far,
        # not only the three created above -- confirm this is wanted when the
        # method is called a second time for the validation graph
        return tf.summary.merge_all()
484

485
    def start_thread(self):
Tiago Pereira's avatar
Tiago Pereira committed
486
        """
487 488 489 490
        Start pool of threads for pre-fetching

        **Parameters**
          session: Tensorflow session
Tiago Pereira's avatar
Tiago Pereira committed
491
        """
492

493
        threads = []
494
        for n in range(self.train_data_shuffler.prefetch_threads):
495
            t = threading.Thread(target=self.load_and_enqueue, args=())
496 497 498 499
            t.daemon = True  # thread will close when parent quits
            t.start()
            threads.append(t)
        return threads
500

501
    def load_and_enqueue(self):
Tiago Pereira's avatar
Tiago Pereira committed
502
        """
503
        Injecting data in the place holder queue
504 505 506

        **Parameters**
          session: Tensorflow session
Tiago de Freitas Pereira's avatar
Tiago de Freitas Pereira committed
507

Tiago Pereira's avatar
Tiago Pereira committed
508
        """
509
        while not self.thread_pool.should_stop():
510
            [train_data, train_labels] = self.train_data_shuffler.get_batch()
511

512 513
            data_ph = self.train_data_shuffler("data", from_queue=False)
            label_ph = self.train_data_shuffler("label", from_queue=False)
514

515 516 517 518
            feed_dict = {data_ph: train_data,
                         label_ph: train_labels}

            self.session.run(self.train_data_shuffler.enqueue_op, feed_dict=feed_dict)
519

520