Trainer.py 15.4 KB
Newer Older
1 2 3 4 5 6
#!/usr/bin/env python
# vim: set fileencoding=utf-8 :
# @author: Tiago de Freitas Pereira <tiago.pereira@idiap.ch>
# @date: Tue 09 Aug 2016 15:25:22 CEST

import tensorflow as tf
7 8 9
import threading
import os
import bob.io.base
10
import bob.core
11
from ..analyzers import SoftmaxAnalizer
12
from tensorflow.core.framework import summary_pb2
13
import time
14
from bob.learn.tensorflow.datashuffler import OnlineSampling, TFRecord
15
from bob.learn.tensorflow.utils.session import Session
16
from .learning_rate import constant
17
import time
18

19 20 21 22 23
#logger = bob.core.log.setup("bob.learn.tensorflow")

import logging
# Module-level logger shared by every Trainer; its verbosity is adjusted in
# Trainer.__init__ through bob.core.log.set_verbosity_level.
logger = logging.getLogger("bob.learn")

24

25 26 27 28 29 30
class Trainer(object):
    """
    One graph trainer.
    Use this trainer when your CNN is composed by one graph

    **Parameters**

    train_data_shuffler:
      The data shuffler used for batching data for training

    validation_data_shuffler:
      Optional data shuffler used for batching data for validation.
      When it is `None` no validation is run.

    iterations:
      Maximum number of iterations

    snapshot:
      Will take a snapshot of the network at every `n` iterations

    validation_snapshot:
      Test with validation each `n` iterations

    keep_checkpoint_every_n_hours:
      Forwarded to the `tf.train.Saver` created for the network

    analizer:
      Neural network analizer :py:mod:`bob.learn.tensorflow.analyzers`

    temp_dir: str
      The output directory

    verbosity_level:
      Verbosity level applied to the module logger

    """
Tiago de Freitas Pereira's avatar
Tiago de Freitas Pereira committed
53

54
    def __init__(self,
                 train_data_shuffler,
                 validation_data_shuffler=None,

                 ###### training options ##########
                 iterations=5000,
                 snapshot=500,
                 validation_snapshot=100,
                 keep_checkpoint_every_n_hours=2,

                 ## Analizer
                 analizer=SoftmaxAnalizer(),

                 # Temporary dir
                 temp_dir="cnn",

                 verbosity_level=2):
        """
        Store the training configuration and open a fresh TensorFlow session.

        No graph element is created here; every graph related attribute
        starts as `None` and is filled in later by
        `create_network_from_scratch` or `create_network_from_file`.
        """

        # User supplied configuration
        self.train_data_shuffler = train_data_shuffler
        self.validation_data_shuffler = validation_data_shuffler
        self.temp_dir = temp_dir
        self.iterations = iterations
        self.snapshot = snapshot
        self.validation_snapshot = validation_snapshot
        self.keep_checkpoint_every_n_hours = keep_checkpoint_every_n_hours
        self.analizer = analizer

        # Summaries, writers and the prefetching coordinator (set in `train`)
        self.summaries_train = self.summaries_validation = None
        self.train_summary_writter = self.validation_summary_writter = None
        self.thread_pool = None

        # Graph elements of the training and of the validation networks
        self.graph = self.validation_graph = None
        self.loss = None
        self.predictor = self.validation_predictor = None
        self.data_ph = self.label_ph = None
        self.validation_data_ph = self.validation_label_ph = None

        # Optimisation related elements
        self.global_step = None
        self.optimizer_class = None
        self.learning_rate = None
        self.optimizer = None
        self.saver = None

        bob.core.log.set_verbosity_level(logger, verbosity_level)

        # Creating the session
        self.session = Session.instance(new=True).session
        # Flipped to False once a network is restored from a checkpoint
        self.from_scratch = True
124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199
        
        
    def train(self):
        """
        Train the network.

        Runs `self.iterations` calls to `fit`, starting from step 0 for a
        network created from scratch or from the restored global step
        otherwise. Along the way it:

          - writes the elapsed time of every iteration to the train summary;
          - runs `compute_validation` every `self.validation_snapshot` steps
            (and once more after the loop) when a validation data shuffler
            was given;
          - saves a snapshot every `self.snapshot` steps and the final model
            in `<temp_dir>/model.ckp`.

        The network must have been prepared beforehand with
        `create_network_from_scratch` or `create_network_from_file`.
        """

        # Creating directories
        bob.io.base.create_directories_safe(self.temp_dir)
        logger.info("Initializing !!")

        # Loading a pretrained model
        if self.from_scratch:
            start_step = 0
        else:
            start_step = self.global_step.eval(session=self.session)

        # TODO: Put this back as soon as possible
        #if isinstance(train_data_shuffler, OnlineSampling):
        #    train_data_shuffler.set_feature_extractor(self.architecture, session=self.session)

        # Start a thread to enqueue data asynchronously, and hide I/O latency.
        if self.train_data_shuffler.prefetch:
            self.thread_pool = tf.train.Coordinator()
            tf.train.start_queue_runners(coord=self.thread_pool, sess=self.session)
            # In case you have your own queue
            if not isinstance(self.train_data_shuffler, TFRecord):
                self.start_thread()

        # Bootstrapping the summary writters
        self.train_summary_writter = tf.summary.FileWriter(os.path.join(self.temp_dir, 'train'), self.session.graph)
        if self.validation_data_shuffler is not None:
            self.validation_summary_writter = tf.summary.FileWriter(os.path.join(self.temp_dir, 'validation'),
                                                                    self.session.graph)

        # Remember the last executed step explicitly: the loop variable is
        # never bound when `self.iterations` is 0, which used to break the
        # final validation below.
        last_step = start_step

        ######################### Loop for #################
        for step in range(start_step, start_step+self.iterations):
            # Run fit in the graph
            start = time.time()
            self.fit(step)
            end = time.time()

            summary = summary_pb2.Summary.Value(tag="elapsed_time", simple_value=float(end-start))
            self.train_summary_writter.add_summary(summary_pb2.Summary(value=[summary]), step)

            # Running validation
            if self.validation_data_shuffler is not None and step % self.validation_snapshot == 0:
                self.compute_validation(step)

            # Taking snapshot
            if step % self.snapshot == 0:
                logger.info("Taking snapshot")
                path = os.path.join(self.temp_dir, 'model_snapshot{0}.ckp'.format(step))
                self.saver.save(self.session, path, global_step=step)

            last_step = step

        # Running validation for the last time
        if self.validation_data_shuffler is not None:
            self.compute_validation(last_step)

        logger.info("Training finally finished")

        self.train_summary_writter.close()
        if self.validation_data_shuffler is not None:
            self.validation_summary_writter.close()

        # Saving the final network
        path = os.path.join(self.temp_dir, 'model.ckp')
        self.saver.save(self.session, path)

        # The coordinator only exists when prefetching was enabled above, so
        # test for it directly instead of re-deriving the condition.
        if self.thread_pool is not None:
            # now they should definetely stop
            self.thread_pool.request_stop()

Tiago de Freitas Pereira's avatar
Tiago de Freitas Pereira committed
200

Tiago Pereira's avatar
Tiago Pereira committed
201 202
    def create_network_from_scratch(self,
                                    graph,
                                    validation_graph=None,
                                    optimizer=None,
                                    loss=None,

                                    # Learning rate
                                    learning_rate=None,
                                    ):

        """
        Prepare all the tensorflow variables before training.

        **Parameters**

            graph: Input graph for training

            validation_graph: Input graph for validation. Only used when the
                trainer was built with a validation data shuffler.

            optimizer: Solver. When `None`, a new `tf.train.AdamOptimizer()`
                is created for this call.

            loss: Loss function. It is called as `loss(graph, labels)`, so a
                callable has to be given even though the default is `None`.

            learning_rate: Learning rate. It is also fetched in every
                `session.run` of `fit` and written to the summaries.
        """

        # Build the default optimizer per call: a default instance created in
        # the signature would be shared by every trainer and mutated below.
        if optimizer is None:
            optimizer = tf.train.AdamOptimizer()

        # Putting together the training data + graph + loss

        # Getting the pointer to the placeholders
        self.data_ph = self.train_data_shuffler("data", from_queue=True)
        self.label_ph = self.train_data_shuffler("label", from_queue=True)

        self.graph = graph
        self.loss = loss

        # Attaching the loss in the graph
        self.predictor = self.loss(self.graph, self.label_ph)

        self.optimizer_class = optimizer
        self.learning_rate = learning_rate
        self.global_step = tf.contrib.framework.get_or_create_global_step()

        # Preparing the optimizer
        # NOTE(review): `_learning_rate` is a private attribute; optimizers
        # that keep the rate under another name will presumably ignore this
        # assignment - confirm for the optimizers in use.
        self.optimizer_class._learning_rate = self.learning_rate
        self.optimizer = self.optimizer_class.minimize(self.predictor, global_step=self.global_step)

        # Saving all the variables
        self.saver = tf.train.Saver(var_list=tf.global_variables() + tf.local_variables(),
                                    keep_checkpoint_every_n_hours=self.keep_checkpoint_every_n_hours)

        self.summaries_train = self.create_general_summary(self.predictor, self.graph, self.label_ph)

        # Saving some variables, so that `create_network_from_file` can find
        # them again by collection name
        tf.add_to_collection("global_step", self.global_step)
        tf.add_to_collection("graph", self.graph)
        tf.add_to_collection("predictor", self.predictor)

        tf.add_to_collection("data_ph", self.data_ph)
        tf.add_to_collection("label_ph", self.label_ph)

        tf.add_to_collection("optimizer", self.optimizer)
        tf.add_to_collection("learning_rate", self.learning_rate)

        tf.add_to_collection("summaries_train", self.summaries_train)

        # Same business with the validation
        if self.validation_data_shuffler is not None:
            self.validation_data_ph = self.validation_data_shuffler("data", from_queue=True)
            self.validation_label_ph = self.validation_data_shuffler("label", from_queue=True)

            self.validation_graph = validation_graph

            self.validation_predictor = self.loss(self.validation_graph, self.validation_label_ph)

            self.summaries_validation = self.create_general_summary(self.validation_predictor,
                                                                    self.validation_graph,
                                                                    self.validation_label_ph)

            tf.add_to_collection("validation_graph", self.validation_graph)
            tf.add_to_collection("validation_data_ph", self.validation_data_ph)
            tf.add_to_collection("validation_label_ph", self.validation_label_ph)

            tf.add_to_collection("validation_predictor", self.validation_predictor)
            # Registered once only; it used to be added to the collection twice
            tf.add_to_collection("summaries_validation", self.summaries_validation)

        # Creating the variables
        tf.local_variables_initializer().run(session=self.session)
        tf.global_variables_initializer().run(session=self.session)

289

290
    def create_network_from_file(self, file_name, clear_devices=True):
        """
        Bootstrap a graph from a checkpoint

        ** Parameters **

           file_name: Name of the meta graph file. It is passed as is to
               `tf.train.import_meta_graph`; the variables are then restored
               from the latest checkpoint found in the directory of this file.

           clear_devices: Forwarded to `tf.train.import_meta_graph`
        """
        self.saver = tf.train.import_meta_graph(file_name, clear_devices=clear_devices)
        self.saver.restore(self.session, tf.train.latest_checkpoint(os.path.dirname(file_name)))

        # Loading training graph
        self.data_ph = tf.get_collection("data_ph")[0]
        self.label_ph = tf.get_collection("label_ph")[0]

        self.graph = tf.get_collection("graph")[0]
        self.predictor = tf.get_collection("predictor")[0]

        # Loading other elements
        self.optimizer = tf.get_collection("optimizer")[0]
        self.learning_rate = tf.get_collection("learning_rate")[0]
        self.summaries_train = tf.get_collection("summaries_train")[0]
        self.global_step = tf.get_collection("global_step")[0]

        # `train` resumes from the restored global step when this is False
        self.from_scratch = False

        # Loading the validation bits
        if self.validation_data_shuffler is not None:
            self.validation_graph = tf.get_collection("validation_graph")[0]
            self.validation_data_ph = tf.get_collection("validation_data_ph")[0]

            # The rest of the class reads `validation_label_ph`; the restored
            # placeholder used to be stored under `validation_label` only.
            self.validation_label_ph = tf.get_collection("validation_label_ph")[0]
            # Old attribute name kept as an alias for existing callers
            self.validation_label = self.validation_label_ph

            self.validation_predictor = tf.get_collection("validation_predictor")[0]
            self.summaries_validation = tf.get_collection("summaries_validation")[0]

Tiago Pereira's avatar
Tiago Pereira committed
328 329 330

    def __del__(self):
        """
        Reset the TensorFlow default graph when the trainer is destroyed.
        """
        # NOTE(review): this resets the default graph itself, so it presumably
        # also affects any other object still using that graph - confirm that
        # only one trainer is expected to be alive at a time.
        tf.reset_default_graph()
331 332 333

    def get_feed_dict(self, data_shuffler):
        """
        Given a data shuffler, prepare the dictionary to be injected in the graph

        ** Parameters **

            data_shuffler: Data shuffler :py:class:`bob.learn.tensorflow.datashuffler.Base`

        ** Returns **

            A dictionary mapping the data and label placeholders to one batch
            taken from `data_shuffler`. The validation placeholders are used
            when `data_shuffler` is the validation data shuffler; the training
            placeholders are used in every other case.
        """
        [data, labels] = data_shuffler.get_batch()

        # The validation graph is built on its own placeholders, so a batch
        # of the validation shuffler has to be fed to them. Fall back to the
        # training placeholders when the validation ones are not available
        # or when the very same shuffler object serves both purposes.
        feeds_validation = (data_shuffler is self.validation_data_shuffler
                            and data_shuffler is not self.train_data_shuffler
                            and self.validation_data_ph is not None
                            and self.validation_label_ph is not None)

        if feeds_validation:
            return {self.validation_data_ph: data,
                    self.validation_label_ph: labels}

        return {self.data_ph: data,
                self.label_ph: labels}

347
    def fit(self, step):
        """
        Run one iteration (`forward` and `backward`)

        The optimizer, the loss, the learning rate and the train summaries
        are evaluated in a single `session.run`; the loss is logged and the
        summaries are written for `step`.

        ** Parameters **
            step: Iteration number

        """

        fetches = [self.optimizer, self.predictor,
                   self.learning_rate, self.summaries_train]

        if self.train_data_shuffler.prefetch:
            # The batch comes from the input queue, nothing to feed
            outputs = self.session.run(fetches)
        else:
            batch = self.get_feed_dict(self.train_data_shuffler)
            outputs = self.session.run(fetches, feed_dict=batch)

        loss_value = outputs[1]
        train_summary = outputs[3]

        logger.info("Loss training set step={0} = {1}".format(step, loss_value))
        self.train_summary_writter.add_summary(train_summary, step)
367

368
    def compute_validation(self, step):
        """
        Computes the loss in the validation set

        The validation loss, the learning rate and the validation summaries
        are evaluated in a single `session.run`; the loss is logged and the
        summaries are written for `step`.

        ** Parameters **
            step: Iteration number

        """

        fetches = [self.validation_predictor,
                   self.learning_rate, self.summaries_validation]

        if self.validation_data_shuffler.prefetch:
            # The batch comes from the input queue, nothing to feed
            outputs = self.session.run(fetches)
        else:
            batch = self.get_feed_dict(self.validation_data_shuffler)
            outputs = self.session.run(fetches, feed_dict=batch)

        loss_value = outputs[0]
        validation_summary = outputs[2]

        logger.info("Loss VALIDATION set step={0} = {1}".format(step, loss_value))
        self.validation_summary_writter.add_summary(validation_summary, step)
Tiago Pereira's avatar
Tiago Pereira committed
390

391
    def create_general_summary(self, average_loss, output, label):
        """
        Creates a simple tensorboard summary with the value of the loss,
        the learning rate and the accuracy

        ** Parameters **
            average_loss: Value reported under the tag `loss`
            output: Network output; the arg max over axis 1 is compared to `label`
            label: Expected labels

        Returns the result of `tf.summary.merge_all()`.
        """
        tf.summary.scalar('loss', average_loss)
        tf.summary.scalar('lr', self.learning_rate)

        # Accuracy: fraction of samples whose arg max matches the label
        predicted_label = tf.argmax(output, 1)
        hits = tf.cast(tf.equal(predicted_label, label), tf.float32)
        tf.summary.scalar('accuracy', tf.reduce_mean(hits))

        # NOTE(review): merge_all gathers every summary registered so far,
        # so the result of a second call (the validation one) presumably
        # also contains the training summaries - confirm this is intended.
        return tf.summary.merge_all()
404

405
    def start_thread(self):
        """
        Start pool of threads for pre-fetching

        One daemon thread running `load_and_enqueue` is started for each of
        the `prefetch_threads` of the train data shuffler.

        Returns the list of started threads.
        """

        workers = []
        for _ in range(self.train_data_shuffler.prefetch_threads):
            worker = threading.Thread(target=self.load_and_enqueue, args=())
            # Daemon threads do not keep the process alive when the parent quits
            worker.daemon = True
            worker.start()
            workers.append(worker)

        return workers
420

421
    def load_and_enqueue(self):
        """
        Injecting data in the place holder queue

        Until the coordinator in `self.thread_pool` asks to stop, a batch is
        taken from the train data shuffler and pushed to its queue by running
        its `enqueue_op`.
        """
        while True:
            if self.thread_pool.should_stop():
                break

            batch_data, batch_labels = self.train_data_shuffler.get_batch()

            # Placeholders that feed the queue (not the ones read from it)
            queue_feed = {
                self.train_data_shuffler("data", from_queue=False): batch_data,
                self.train_data_shuffler("label", from_queue=False): batch_labels,
            }

            self.session.run(self.train_data_shuffler.enqueue_op, feed_dict=queue_feed)
439

440