#!/usr/bin/env python
# vim: set fileencoding=utf-8 :
# @author: Tiago de Freitas Pereira <tiago.pereira@idiap.ch>
# @date: Tue 09 Aug 2016 15:25:22 CEST

import tensorflow as tf
7 8 9
import threading
import os
import bob.io.base
10
import bob.core
11
from ..analyzers import SoftmaxAnalizer
12
from tensorflow.core.framework import summary_pb2
13
import time
14
from bob.learn.tensorflow.datashuffler import OnlineSampling, TFRecord
15
from bob.learn.tensorflow.utils.session import Session
16
from .learning_rate import constant
17
import time
18

19 20 21 22 23
#logger = bob.core.log.setup("bob.learn.tensorflow")

import logging
logger = logging.getLogger("bob.learn")

24

25 26 27
class Trainer(object):
    """
    One graph trainer.

    Use this trainer when your CNN is composed of one graph

    **Parameters**

    train_data_shuffler:
      The data shuffler used for batching data for training

    iterations:
      Maximum number of iterations

    snapshot:
      Will take a snapshot of the network at every `n` iterations

    validation_snapshot:
      Test with validation each `n` iterations

    analizer:
      Neural network analizer :py:mod:`bob.learn.tensorflow.analyzers`

    temp_dir: str
      The output directory

    verbosity_level:
      Verbosity level applied to the module logger

    """
Tiago de Freitas Pereira's avatar
Tiago de Freitas Pereira committed
54

55
    def __init__(self,
                 train_data_shuffler,

                 # Training options
                 iterations=5000,
                 snapshot=500,
                 validation_snapshot=100,

                 # Analizer
                 analizer=SoftmaxAnalizer(),

                 # Output directory
                 temp_dir="cnn",

                 verbosity_level=2):
        """
        Stores the training options, declares every attribute that is filled in
        later (by `create_network_from_scratch`, `create_network_from_file` or
        `train`) and opens a fresh tensorflow session.

        See the class documentation for the meaning of each parameter.
        """

        # User supplied options
        self.train_data_shuffler = train_data_shuffler
        self.temp_dir = temp_dir
        self.iterations = iterations
        self.snapshot = snapshot
        self.validation_snapshot = validation_snapshot
        self.analizer = analizer

        # Tensorboard summaries and their writers (training and validation)
        self.summaries_train = self.summaries_validation = None
        self.train_summary_writter = self.validation_summary_writter = None

        # Coordinator for the data loading threads, created in `train`
        self.thread_pool = None

        # Graph related elements, set when the network is created or loaded
        self.global_step = None
        self.session = None
        self.graph = None
        self.loss = None
        self.predictor = None
        self.optimizer_class = None
        self.learning_rate = None
        self.optimizer = None
        self.data_ph = None
        self.label_ph = None
        self.saver = None

        bob.core.log.set_verbosity_level(logger, verbosity_level)

        # A new session is requested for every trainer; until a checkpoint is
        # loaded the trainer assumes that training starts from scratch
        self.session = Session.instance(new=True).session
        self.from_scratch = True
Tiago de Freitas Pereira's avatar
Tiago de Freitas Pereira committed
109

Tiago Pereira's avatar
Tiago Pereira committed
110 111 112 113
    def create_network_from_scratch(self,
                                    graph,
                                    optimizer=tf.train.AdamOptimizer(),
                                    loss=None,

                                    # Learning rate
                                    learning_rate=None,
                                    ):

        """
        Prepare all the tensorflow variables before training.

        **Parameters**

            graph: Input graph for training. It is used as the logits of the
              sparse softmax cross entropy built in this method

            optimizer: Solver

            loss: Loss function. NOTE: currently ignored, the sparse softmax
              cross entropy is hard-coded while the integration is being
              debugged (see the TODO below)

            learning_rate: Learning rate. Must be provided

        **Raises**

            ValueError: If `learning_rate` is `None`
        """

        # Without a learning rate the 'lr' summary cannot be created; fail
        # here with a clear message instead of deep inside tensorflow
        if learning_rate is None:
            raise ValueError("`learning_rate` must be provided to "
                             "`create_network_from_scratch`")

        self.data_ph = self.train_data_shuffler("data", from_queue=True)
        self.label_ph = self.train_data_shuffler("label", from_queue=True)
        self.graph = graph

        # TODO: DEBUG
        # self.predictor = self.loss(self.graph, self.label_ph)
        self.predictor = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=self.graph,
                                                                        labels=self.label_ph)
        self.loss = tf.reduce_mean(self.predictor)

        self.optimizer_class = optimizer
        self.learning_rate = learning_rate

        self.global_step = tf.contrib.framework.get_or_create_global_step()

        # Saving all the variables
        # NOTE(review): the saver is built before `minimize` is called, so
        # variables created by the optimizer are presumably not part of
        # `var_list` - confirm that this is intended
        self.saver = tf.train.Saver(var_list=tf.global_variables() + tf.local_variables())

        # Everything needed to resume the training is registered in a
        # collection, so that `create_network_from_file` can fetch it back
        tf.add_to_collection("global_step", self.global_step)

        tf.add_to_collection("graph", self.graph)
        tf.add_to_collection("predictor", self.predictor)

        tf.add_to_collection("data_ph", self.data_ph)
        tf.add_to_collection("label_ph", self.label_ph)

        # Preparing the optimizer
        # NOTE(review): the per-sample `self.predictor` is minimized here, not
        # the averaged `self.loss` - kept as it was
        self.optimizer_class._learning_rate = self.learning_rate
        self.optimizer = self.optimizer_class.minimize(self.predictor, global_step=self.global_step)
        tf.add_to_collection("optimizer", self.optimizer)
        tf.add_to_collection("learning_rate", self.learning_rate)

        self.summaries_train = self.create_general_summary()
        tf.add_to_collection("summaries_train", self.summaries_train)

        # `tf.add_to_collection` returns None, so its result must not be
        # assigned back to `self.summaries_validation`
        self.summaries_validation = self.create_general_summary()
        tf.add_to_collection("summaries_validation", self.summaries_validation)

        # Creating the variables
        tf.global_variables_initializer().run(session=self.session)

178
    def create_network_from_file(self, file_name, clear_devices=True):
        """
        Bootstrap a graph from a checkpoint

         ** Parameters **

           file_name: Name of the checkpoint. The graph definition is read
             from `file_name + ".meta"`

           clear_devices: Forwarded to `tf.train.import_meta_graph`
        """
        meta_graph_file = file_name + ".meta"
        self.saver = tf.train.import_meta_graph(meta_graph_file, clear_devices=clear_devices)
        self.saver.restore(self.session, file_name)

        # Every element is stored in a collection named after the attribute
        # that holds it; the first entry of each collection is taken
        restored_attributes = ("data_ph", "label_ph",
                               "graph", "predictor",
                               "optimizer", "learning_rate",
                               "summaries_train", "summaries_validation",
                               "global_step")
        for attribute in restored_attributes:
            setattr(self, attribute, tf.get_collection(attribute)[0])

        # `train` will resume from the restored global step
        self.from_scratch = False

    def __del__(self):
        """
        Resets the tensorflow default graph when the trainer is destroyed
        """
        tf.reset_default_graph()
206 207 208

    def get_feed_dict(self, data_shuffler):
        """
        Fetches one batch from a data shuffler and maps it to the data and
        label placeholders of the trainer

        ** Parameters **

            data_shuffler: Data shuffler :py:class:`bob.learn.tensorflow.datashuffler.Base`

        ** Returns **

            Dictionary `{self.data_ph: data, self.label_ph: labels}`

        """
        batch = data_shuffler.get_batch()

        # The batch must hold exactly two elements: the data and the labels
        data, labels = batch

        return {self.data_ph: data, self.label_ph: labels}

222
    def fit(self, step):
        """
        Run one iteration (`forward` and `backward`)

        The value of `self.predictor` is logged and the training summary is
        written to `self.train_summary_writter`.

        ** Parameters **
            step: Iteration number

        """
        shuffler = self.train_data_shuffler
        fetches = [self.optimizer, self.predictor,
                   self.learning_rate, self.summaries_train]

        if shuffler.prefetch or isinstance(shuffler, TFRecord):
            # The batch comes from the graph itself, nothing has to be fed
            outputs = self.session.run(fetches)
        else:
            outputs = self.session.run(fetches, feed_dict=self.get_feed_dict(shuffler))

        _, l, lr, summary = outputs

        logger.info("Loss training set step={0} = {1}".format(step, l))
        self.train_summary_writter.add_summary(summary, step)
242

Tiago Pereira's avatar
Tiago Pereira committed
243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263
    def compute_validation(self, data_shuffler, step):
        """
        Computes the loss in the validation set

        The computation is currently disabled: nothing is run and nothing is
        written to the validation summary.

        ** Parameters **
            data_shuffler: The data shuffler to be used
            step: Iteration number

        """
        # TODO: run `self.predictor` on a validation batch, write the loss to
        # `self.validation_summary_writter` and log it
        return None

264
    def create_general_summary(self):
        """
        Creates a simple tensorboard summary with the value of the loss and learning rate

        ** Returns **

            The result of `tf.summary.merge_all()`, called after the `loss`
            and `lr` scalar summaries have been added

        """

        # A scalar summary accepts a single value only, while `self.predictor`
        # may hold one cross entropy value per sample; it is averaged first
        # (the mean of an already scalar value is that same value)
        tf.summary.scalar('loss', tf.reduce_mean(self.predictor))
        tf.summary.scalar('lr', self.learning_rate)
        return tf.summary.merge_all()
273

274
    def start_thread(self):
        """
        Start pool of threads for pre-fetching

        One daemon thread running `self.load_and_enqueue` is started for each
        of the `prefetch_threads` of the training data shuffler.

        ** Returns **
          List with the started threads
        """
        number_of_workers = self.train_data_shuffler.prefetch_threads

        workers = []
        for _ in range(number_of_workers):
            worker = threading.Thread(target=self.load_and_enqueue, args=())
            # Daemon threads do not keep the process alive when the parent quits
            worker.daemon = True
            worker.start()
            workers.append(worker)

        return workers
289

290
    def load_and_enqueue(self):
        """
        Injecting data in the place holder queue

        Keeps fetching batches from the training data shuffler and running its
        `enqueue_op` until the coordinator stored in `self.thread_pool`
        signals that the threads should stop.
        """
        shuffler = self.train_data_shuffler

        while True:
            if self.thread_pool.should_stop():
                break

            batch_data, batch_labels = shuffler.get_batch()

            # Placeholders that feed the queue (not the ones read from it)
            queue_data_ph = shuffler("data", from_queue=False)
            queue_label_ph = shuffler("label", from_queue=False)

            self.session.run(shuffler.enqueue_op,
                             feed_dict={queue_data_ph: batch_data,
                                        queue_label_ph: batch_labels})
308

Tiago Pereira's avatar
Tiago Pereira committed
309
    def train(self, validation_data_shuffler=None):
        """
        Train the network:

        Runs `fit` from the starting step up to `self.iterations`, writing
        tensorboard summaries and snapshots of the model in `self.temp_dir`.
        The final model is saved as `model.ckp` in the same directory.

         ** Parameters **
           validation_data_shuffler: Data shuffler for validation
        """
        shuffler = self.train_data_shuffler
        uses_tfrecord = isinstance(shuffler, TFRecord)
        with_validation = validation_data_shuffler is not None

        # Creating directory
        bob.io.base.create_directories_safe(self.temp_dir)

        logger.info("Initializing !!")

        # A pretrained model resumes from its stored global step
        start_step = 0 if self.from_scratch else self.global_step.eval(session=self.session)

        # Start a thread to enqueue data asynchronously, and hide I/O latency.
        if shuffler.prefetch:
            self.thread_pool = tf.train.Coordinator()
            tf.train.start_queue_runners(coord=self.thread_pool, sess=self.session)
            threads = self.start_thread()

        # TODO: JUST FOR TESTING THE INTEGRATION
        if uses_tfrecord:
            tf.local_variables_initializer().run(session=self.session)
            self.thread_pool = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(coord=self.thread_pool, sess=self.session)

        # TENSOR BOARD SUMMARY
        train_log_dir = os.path.join(self.temp_dir, 'train')
        self.train_summary_writter = tf.summary.FileWriter(train_log_dir, self.session.graph)
        if with_validation:
            validation_log_dir = os.path.join(self.temp_dir, 'validation')
            self.validation_summary_writter = tf.summary.FileWriter(validation_log_dir,
                                                                    self.session.graph)

        for step in range(start_step, self.iterations):
            # Run fit in the graph and keep track of how long it took
            tic = time.time()
            self.fit(step)
            elapsed_time = float(time.time() - tic)

            timing = summary_pb2.Summary.Value(tag="elapsed_time", simple_value=elapsed_time)
            self.train_summary_writter.add_summary(summary_pb2.Summary(value=[timing]), step)

            # Running validation (the analizer summary is currently disabled)
            if with_validation and step % self.validation_snapshot == 0:
                self.compute_validation(validation_data_shuffler, step)

            # Taking snapshot
            if step % self.snapshot == 0:
                logger.info("Taking snapshot")
                snapshot_path = os.path.join(self.temp_dir, 'model_snapshot{0}.ckp'.format(step))
                self.saver.save(self.session, snapshot_path, global_step=step)

        logger.info("Training finally finished")

        self.train_summary_writter.close()
        if with_validation:
            self.validation_summary_writter.close()

        # Saving the final network
        self.saver.save(self.session, os.path.join(self.temp_dir, 'model.ckp'))

        if shuffler.prefetch or uses_tfrecord:
            # now they should definitely stop
            self.thread_pool.request_stop()
            self.thread_pool.join(threads)