Trainer.py 15.9 KB
Newer Older
1 2 3 4 5 6
#!/usr/bin/env python
# vim: set fileencoding=utf-8 :
# @author: Tiago de Freitas Pereira <tiago.pereira@idiap.ch>
# @date: Tue 09 Aug 2016 15:25:22 CEST

import tensorflow as tf
7 8 9
import threading
import os
import bob.io.base
10
import bob.core
11
from ..analyzers import SoftmaxAnalizer
12
from tensorflow.core.framework import summary_pb2
13
import time
14
from bob.learn.tensorflow.datashuffler import OnlineSampling, TFRecord
15
from bob.learn.tensorflow.utils.session import Session
16
from .learning_rate import constant
17
import time
18

19 20 21 22 23
#logger = bob.core.log.setup("bob.learn.tensorflow")

import logging
logger = logging.getLogger("bob.learn")

24

25 26 27 28 29 30
class Trainer(object):
    """
    One graph trainer.
    Use this trainer when your CNN is composed by one graph

    **Parameters**

    train_data_shuffler:
      The data shuffler used for batching data for training

    validation_data_shuffler:
      Optional data shuffler used for validation. When it is `None`
      (the default) no validation is run during training.

    iterations:
      Maximum number of iterations

    snapshot:
      Will take a snapshot of the network at every `n` iterations

    validation_snapshot:
      Test with validation each `n` iterations

    keep_checkpoint_every_n_hours:
      Forwarded to `tf.train.Saver` when the network is created from scratch

    analizer:
      Neural network analizer :py:mod:`bob.learn.tensorflow.analyzers`

    temp_dir: str
      The output directory (summaries and checkpoints are written there)

    verbosity_level:
      Verbosity level applied to this module's logger

    """
Tiago de Freitas Pereira's avatar
Tiago de Freitas Pereira committed
53

54
    def __init__(self,
                 train_data_shuffler,
                 validation_data_shuffler=None,

                 ###### training options ##########
                 iterations=5000,
                 snapshot=500,
                 validation_snapshot=100,
                 keep_checkpoint_every_n_hours=2,

                 ## Analizer
                 # NOTE(review): this default instance is built once, when the
                 # class is defined, and is shared by every Trainer that does
                 # not pass its own analizer.
                 analizer=SoftmaxAnalizer(),

                 # Temporary dir
                 temp_dir="cnn",

                 verbosity_level=2):
        """
        Store the training options and open a fresh tensorflow session.

        Every graph related attribute starts as `None`; they are filled in
        later by `create_network_from_scratch` or `create_network_from_file`.
        """

        # Options given by the caller
        self.train_data_shuffler = train_data_shuffler
        self.validation_data_shuffler = validation_data_shuffler
        self.temp_dir = temp_dir
        self.iterations = iterations
        self.snapshot = snapshot
        self.validation_snapshot = validation_snapshot
        self.keep_checkpoint_every_n_hours = keep_checkpoint_every_n_hours
        self.analizer = analizer

        # Summaries, their writers and the prefetching coordinator
        self.summaries_train = self.summaries_validation = None
        self.train_summary_writter = self.validation_summary_writter = None
        self.thread_pool = None

        # Graphs, loss and the tensors built from them
        self.graph = self.validation_graph = None
        self.loss = None
        self.predictor = self.validation_predictor = None

        # Solver
        self.optimizer_class = None
        self.optimizer = None
        self.learning_rate = None
        self.global_step = None

        # Input tensors (placeholders or queue outputs)
        self.data_ph = self.label_ph = None
        self.validation_data_ph = self.validation_label_ph = None

        self.saver = None
        self.session = None

        bob.core.log.set_verbosity_level(logger, verbosity_level)

        # Creating the session
        self.session = Session.instance(new=True).session
        self.from_scratch = True
124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198
        
    def train(self):
        """
        Train the network

        Runs `fit` once per step for `self.iterations` steps. Counting starts
        at 0 for a network created from scratch and at the restored global
        step otherwise. Along the way it writes an `elapsed_time` summary per
        step, runs the validation every `self.validation_snapshot` steps (when
        a validation data shuffler is set) and saves a checkpoint every
        `self.snapshot` steps. At the end it runs a last validation and saves
        the final model as `model.ckp` in `self.temp_dir`.

        The summary writers are closed and the prefetching coordinator is
        asked to stop even if one iteration raises.
        """

        # Creating directories
        bob.io.base.create_directories_safe(self.temp_dir)
        logger.info("Initializing !!")

        # Loading a pretrained model
        if self.from_scratch:
            start_step = 0
        else:
            start_step = self.global_step.eval(session=self.session)

        # TODO: Put this back as soon as possible
        #if isinstance(train_data_shuffler, OnlineSampling):
        #    train_data_shuffler.set_feature_extractor(self.architecture, session=self.session)

        # Start a thread to enqueue data asynchronously, and hide I/O latency.
        if self.train_data_shuffler.prefetch:
            self.thread_pool = tf.train.Coordinator()
            tf.train.start_queue_runners(coord=self.thread_pool, sess=self.session)
            # In case you have your own queue
            if not isinstance(self.train_data_shuffler, TFRecord):
                self.start_thread()

        has_validation = self.validation_data_shuffler is not None

        # Bootstrapping the summary writters
        self.train_summary_writter = tf.summary.FileWriter(os.path.join(self.temp_dir, 'train'), self.session.graph)
        if has_validation:
            self.validation_summary_writter = tf.summary.FileWriter(os.path.join(self.temp_dir, 'validation'),
                                                                    self.session.graph)

        try:
            # Keeps `step` defined for the final validation when there is
            # nothing to iterate over (iterations == 0)
            step = start_step

            ######################### Loop for #################
            for step in range(start_step, start_step + self.iterations):
                # Run fit in the graph
                start = time.time()
                self.fit(step)
                end = time.time()

                summary = summary_pb2.Summary.Value(tag="elapsed_time", simple_value=float(end - start))
                self.train_summary_writter.add_summary(summary_pb2.Summary(value=[summary]), step)

                # Running validation
                if has_validation and step % self.validation_snapshot == 0:
                    self.compute_validation(step)

                # Taking snapshot
                if step % self.snapshot == 0:
                    logger.info("Taking snapshot")
                    path = os.path.join(self.temp_dir, 'model_snapshot{0}.ckp'.format(step))
                    self.saver.save(self.session, path, global_step=step)

            # Running validation for the last time
            if has_validation:
                self.compute_validation(step)

            logger.info("Training finally finished")

            # Saving the final network
            path = os.path.join(self.temp_dir, 'model.ckp')
            self.saver.save(self.session, path)
        finally:
            self.train_summary_writter.close()
            if has_validation:
                self.validation_summary_writter.close()

            # The coordinator only exists when prefetching was started above;
            # checking it directly avoids calling `request_stop` on `None`
            if self.thread_pool is not None:
                # now they should definetely stop
                self.thread_pool.request_stop()

Tiago Pereira's avatar
Tiago Pereira committed
199 200
    def create_network_from_scratch(self,
                                    graph,
                                    validation_graph=None,
                                    optimizer=tf.train.AdamOptimizer(),
                                    loss=None,

                                    # Learning rate
                                    learning_rate=None,
                                    ):

        """
        Prepare all the tensorflow variables before training.

        Attaches the loss to `graph`, builds the training operation, the saver
        and the summaries, registers everything in named collections (so that
        `create_network_from_file` can find them again) and initializes the
        variables. The same is done for `validation_graph` when a validation
        data shuffler was given to the constructor.

        **Parameters**

            graph: Input graph for training

            validation_graph: Input graph for validation. Required when a
              validation data shuffler was given to the constructor

            optimizer: Solver. The default instance is created once, when the
              class is defined, and its `_learning_rate` is overwritten here

            loss: Loss function. Called as `loss(graph, labels)`

            learning_rate: Learning rate

        **Raises**

            TypeError: if `loss` is not callable

            ValueError: if validation is enabled and `validation_graph` is `None`
        """

        # Fail before touching the tensorflow graph instead of half way through
        if not callable(loss):
            raise TypeError("`loss` must be a callable taking (graph, labels); got {0!r}".format(loss))
        if self.validation_data_shuffler is not None and validation_graph is None:
            raise ValueError("`validation_graph` is required when a validation data shuffler is set")

        # Getting the pointer to the placeholders
        self.data_ph = self.train_data_shuffler("data", from_queue=True)
        self.label_ph = self.train_data_shuffler("label", from_queue=True)

        self.graph = graph
        self.loss = loss

        # Attaching the loss in the graph
        self.predictor = self.loss(self.graph, self.label_ph)

        self.optimizer_class = optimizer
        self.learning_rate = learning_rate
        self.global_step = tf.contrib.framework.get_or_create_global_step()

        # Preparing the optimizer
        self.optimizer_class._learning_rate = self.learning_rate
        self.optimizer = self.optimizer_class.minimize(self.predictor, global_step=self.global_step)

        # Saving all the variables
        self.saver = tf.train.Saver(var_list=tf.global_variables() + tf.local_variables(),
                                    keep_checkpoint_every_n_hours=self.keep_checkpoint_every_n_hours)

        self.summaries_train = self.create_general_summary(self.predictor, self.graph, self.label_ph)

        # Saving some variables
        tf.add_to_collection("global_step", self.global_step)
        tf.add_to_collection("graph", self.graph)
        tf.add_to_collection("predictor", self.predictor)

        tf.add_to_collection("data_ph", self.data_ph)
        tf.add_to_collection("label_ph", self.label_ph)

        tf.add_to_collection("optimizer", self.optimizer)
        tf.add_to_collection("learning_rate", self.learning_rate)

        tf.add_to_collection("summaries_train", self.summaries_train)

        # Same business with the validation
        if self.validation_data_shuffler is not None:
            self.validation_data_ph = self.validation_data_shuffler("data", from_queue=True)
            self.validation_label_ph = self.validation_data_shuffler("label", from_queue=True)

            self.validation_graph = validation_graph

            self.validation_predictor = self.loss(self.validation_graph, self.validation_label_ph)

            self.summaries_validation = self.create_general_summary(self.validation_predictor,
                                                                    self.validation_graph,
                                                                    self.validation_label_ph)

            tf.add_to_collection("validation_graph", self.validation_graph)
            tf.add_to_collection("validation_data_ph", self.validation_data_ph)
            tf.add_to_collection("validation_label_ph", self.validation_label_ph)

            tf.add_to_collection("validation_predictor", self.validation_predictor)
            # Registered only once, the collection used to get this entry twice
            tf.add_to_collection("summaries_validation", self.summaries_validation)

        # Creating the variables
        tf.local_variables_initializer().run(session=self.session)
        tf.global_variables_initializer().run(session=self.session)

283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300
    def load_checkpoint(self, file_name, clear_devices=True):
        """
        Load a checkpoint

        Imports the meta graph into the default graph, keeps the resulting
        saver in `self.saver` and restores the variables in `self.session`.

        ** Parameters **

           file_name:
                Name of the metafile to be loaded.
                If a directory is passed, the last checkpoint will be loaded

           clear_devices:
                Forwarded to `tf.train.import_meta_graph`

        ** Raises **

           ValueError: if no checkpoint can be found for `file_name`

        """
        if os.path.isdir(file_name):
            checkpoint_dir = file_name
            # `get_checkpoint_state` gives back `None` when the directory has
            # no valid checkpoint state
            checkpoint_state = tf.train.get_checkpoint_state(checkpoint_dir)
            if checkpoint_state is None or not checkpoint_state.model_checkpoint_path:
                raise ValueError("No checkpoint found in the directory `{0}`".format(checkpoint_dir))
            meta_file = checkpoint_state.model_checkpoint_path + ".meta"
        else:
            checkpoint_dir = os.path.dirname(file_name)
            meta_file = file_name

        # Looked up before importing the meta graph, so a missing checkpoint
        # does not leave a half loaded default graph behind
        latest_checkpoint = tf.train.latest_checkpoint(checkpoint_dir)
        if latest_checkpoint is None:
            raise ValueError("No checkpoint found in the directory `{0}`".format(checkpoint_dir))

        self.saver = tf.train.import_meta_graph(meta_file, clear_devices=clear_devices)
        self.saver.restore(self.session, latest_checkpoint)
301

302
    def create_network_from_file(self, file_name, clear_devices=True):
        """
        Bootstrap a graph from a checkpoint

        Loads the checkpoint and then recovers, from the named collections
        stored with it, the tensors and operations needed by `train`. After
        this call the training resumes from the restored global step.

         ** Parameters **

           file_name: Name of the checkpoint (forwarded to `load_checkpoint`)

           clear_devices: Forwarded to `load_checkpoint`
        """

        logger.info("Loading last checkpoint !!")
        # The caller's choice is forwarded; it used to be hard-coded to `True`
        self.load_checkpoint(file_name, clear_devices=clear_devices)

        # Loading training graph
        self.data_ph = tf.get_collection("data_ph")[0]
        self.label_ph = tf.get_collection("label_ph")[0]

        self.graph = tf.get_collection("graph")[0]
        self.predictor = tf.get_collection("predictor")[0]

        # Loading other elements
        self.optimizer = tf.get_collection("optimizer")[0]
        self.learning_rate = tf.get_collection("learning_rate")[0]
        self.summaries_train = tf.get_collection("summaries_train")[0]
        self.global_step = tf.get_collection("global_step")[0]
        self.from_scratch = False

        # Loading the validation bits
        if self.validation_data_shuffler is not None:
            self.validation_graph = tf.get_collection("validation_graph")[0]
            self.validation_data_ph = tf.get_collection("validation_data_ph")[0]
            # Stored under the attribute declared in `__init__`
            self.validation_label_ph = tf.get_collection("validation_label_ph")[0]
            # Misspelled attribute that used to receive this tensor; kept as an
            # alias in case external code reads it
            self.validation_label = self.validation_label_ph

            self.validation_predictor = tf.get_collection("validation_predictor")[0]
            self.summaries_validation = tf.get_collection("summaries_validation")[0]

Tiago Pereira's avatar
Tiago Pereira committed
339 340
    def __del__(self):
        """
        Resets the tensorflow default graph when the trainer is destroyed
        """
        tf.reset_default_graph()
341 342 343

    def get_feed_dict(self, data_shuffler):
        """
        Given a data shuffler prepared the dictionary to be injected in the graph

        The batch is mapped to the validation placeholders when `data_shuffler`
        is the validation data shuffler and those placeholders are set, and to
        the training placeholders in every other case.

        ** Parameters **

            data_shuffler: Data shuffler :py:class:`bob.learn.tensorflow.datashuffler.Base`

        ** Returns **

            A dictionary with two entries: data placeholder -> data batch and
            label placeholder -> label batch

        """
        data, labels = data_shuffler.get_batch()

        # The validation predictor is built on the validation placeholders, so
        # feeding the training ones would leave its inputs without a value.
        # If one shuffler is used for both roles the training placeholders win.
        is_validation = (data_shuffler is not self.train_data_shuffler
                         and data_shuffler is self.validation_data_shuffler
                         and self.validation_data_ph is not None
                         and self.validation_label_ph is not None)

        if is_validation:
            return {self.validation_data_ph: data,
                    self.validation_label_ph: labels}

        return {self.data_ph: data,
                self.label_ph: labels}

357
    def fit(self, step):
        """
        Run one iteration (`forward` and `backward`)

        The batch comes from the input queue when the training data shuffler
        prefetches, and from a feed dictionary otherwise. The loss is logged
        and the merged summary is written under `step`.

        ** Parameters **
            step: Iteration number

        """

        fetches = [self.optimizer, self.predictor,
                   self.learning_rate, self.summaries_train]

        if self.train_data_shuffler.prefetch:
            outputs = self.session.run(fetches)
        else:
            outputs = self.session.run(fetches,
                                       feed_dict=self.get_feed_dict(self.train_data_shuffler))

        # The optimizer output and the learning rate value are not used here
        _, loss_value, _, merged_summary = outputs

        logger.info("Loss training set step={0} = {1}".format(step, loss_value))
        self.train_summary_writter.add_summary(merged_summary, step)
377

378
    def compute_validation(self, step):
        """
        Computes the loss in the validation set

        The batch comes from the input queue when the validation data shuffler
        prefetches, and from a feed dictionary otherwise. The loss is logged
        and the merged summary is written under `step`.

        ** Parameters **
            step: Iteration number

        """

        fetches = [self.validation_predictor,
                   self.learning_rate, self.summaries_validation]

        if self.validation_data_shuffler.prefetch:
            outputs = self.session.run(fetches)
        else:
            outputs = self.session.run(fetches,
                                       feed_dict=self.get_feed_dict(self.validation_data_shuffler))

        # The learning rate value is not used here
        loss_value, _, merged_summary = outputs

        logger.info("Loss VALIDATION set step={0} = {1}".format(step, loss_value))
        self.validation_summary_writter.add_summary(merged_summary, step)
Tiago Pereira's avatar
Tiago Pereira committed
400

401
    def create_general_summary(self, average_loss, output, label):
        """
        Creates a simple tensorboard summary with the loss, the learning rate
        and the accuracy

        **Parameters**
          average_loss: Tensor reported as `loss`
          output: Tensor whose argmax over axis 1 is compared with `label`
          label: Tensor with the expected labels

        **Returns**
          The operation given by `tf.summary.merge_all()`
        """
        tf.summary.scalar('loss', average_loss)
        tf.summary.scalar('lr', self.learning_rate)

        # Accuracy: fraction of the samples whose argmax matches the label
        predicted_label = tf.argmax(output, 1)
        hits = tf.cast(tf.equal(predicted_label, label), tf.float32)
        tf.summary.scalar('accuracy', tf.reduce_mean(hits))

        # NOTE(review): `merge_all` merges every summary collected so far, so
        # the second call (validation) presumably also carries the training
        # summaries - confirm this is intended.
        return tf.summary.merge_all()
414

415
    def start_thread(self):
        """
        Start pool of threads for pre-fetching

        One daemon thread running `load_and_enqueue` is started for each of
        the `prefetch_threads` of the training data shuffler.

        **Returns**
          The list with the started threads
        """

        workers = []
        for _ in range(self.train_data_shuffler.prefetch_threads):
            worker = threading.Thread(target=self.load_and_enqueue, args=())
            # daemon threads are closed when the parent quits
            worker.daemon = True
            worker.start()
            workers.append(worker)

        return workers
430

431
    def load_and_enqueue(self):
        """
        Injecting data in the place holder queue

        Keeps fetching batches from the training data shuffler and running its
        `enqueue_op` until the coordinator in `self.thread_pool` asks to stop.
        """
        while not self.thread_pool.should_stop():
            batch_data, batch_labels = self.train_data_shuffler.get_batch()

            # Keys are the tensors given by the shuffler with `from_queue=False`
            feed = {
                self.train_data_shuffler("data", from_queue=False): batch_data,
                self.train_data_shuffler("label", from_queue=False): batch_labels,
            }

            self.session.run(self.train_data_shuffler.enqueue_op, feed_dict=feed)
449

450