Trainer.py 12.2 KB
Newer Older
1 2 3 4 5 6
#!/usr/bin/env python
# vim: set fileencoding=utf-8 :
# @author: Tiago de Freitas Pereira <tiago.pereira@idiap.ch>
# @date: Tue 09 Aug 2016 15:25:22 CEST

import tensorflow as tf
7 8 9
import threading
import os
import bob.io.base
10
import bob.core
11
from ..analyzers import SoftmaxAnalizer
12
from tensorflow.core.framework import summary_pb2
13
import time
14
from bob.learn.tensorflow.datashuffler import OnlineSampling
15
from bob.learn.tensorflow.utils.session import Session
16
from .learning_rate import constant
17

18 19 20 21 22
# Module-level logger shared by the trainer below.
# Earlier setup through bob.core, kept for reference:
#logger = bob.core.log.setup("bob.learn.tensorflow")

import logging
logger = logging.getLogger("bob.learn")

23

24 25 26
class Trainer(object):
    """
    One graph trainer.

    Use this trainer when your CNN is composed of a single graph.

    **Parameters**

    train_data_shuffler:
      The data shuffler used for batching data for training

    iterations:
      Maximum number of iterations

    snapshot:
      Will take a snapshot of the network at every `n` iterations

    validation_snapshot:
      Test with validation each `n` iterations

    analizer:
      Neural network analizer :py:mod:`bob.learn.tensorflow.analyzers`

    temp_dir: str
      The output directory

    verbosity_level:
      Verbosity level forwarded to `bob.core.log.set_verbosity_level`

    """
Tiago de Freitas Pereira's avatar
Tiago de Freitas Pereira committed
53

54
    def __init__(self,
                 train_data_shuffler,

                 # Training options
                 iterations=5000,
                 snapshot=500,
                 validation_snapshot=100,

                 # Analizer
                 # NOTE(review): this default instance is created once, when the
                 # method is defined, and is therefore shared by every trainer
                 # that does not pass its own analizer.
                 analizer=SoftmaxAnalizer(),

                 # Temporary dir
                 temp_dir="cnn",

                 verbosity_level=2):
        """
        Store the training configuration and open a new session.

        **Parameters**

          train_data_shuffler: The data shuffler used for batching data for training

          iterations: Maximum number of iterations

          snapshot: Take a snapshot of the network at every `n` iterations

          validation_snapshot: Test with validation each `n` iterations

          analizer: Neural network analizer

          temp_dir: The output directory

          verbosity_level: Forwarded to `bob.core.log.set_verbosity_level`
        """

        # Configuration supplied by the caller
        self.train_data_shuffler = train_data_shuffler
        self.temp_dir = temp_dir
        self.iterations = iterations
        self.snapshot = snapshot
        self.validation_snapshot = validation_snapshot
        self.analizer = analizer

        # State that is only declared here; it is populated later by
        # `create_network_from_scratch`, `create_network_from_file` and `train`.
        for attribute in ("summaries_train", "train_summary_writter", "thread_pool",
                          "validation_summary_writter", "summaries_validation",
                          "global_step", "session",
                          "graph", "loss", "predictor", "optimizer_class",
                          "learning_rate", "optimizer", "data_ph", "label_ph",
                          "saver"):
            setattr(self, attribute, None)

        bob.core.log.set_verbosity_level(logger, verbosity_level)

        # Creating the session
        self.session = Session.instance(new=True).session
        self.from_scratch = True
Tiago de Freitas Pereira's avatar
Tiago de Freitas Pereira committed
108

Tiago Pereira's avatar
Tiago Pereira committed
109 110 111 112
    def create_network_from_scratch(self,
                                    graph,
                                    optimizer=None,
                                    loss=None,

                                    # Learning rate
                                    learning_rate=None,
                                    ):
        """
        Prepare all the tensorflow variables before training.

        **Parameters**

            graph: Input graph for training

            optimizer: Solver. When omitted (`None`) a new
                       `tf.train.AdamOptimizer()` is created for this call.

            loss: Loss function; it is called as `loss(graph, labels)` and its
                  result is what gets minimized. Required.

            learning_rate: Learning rate. Required.

        **Raises**

            ValueError: if `loss` or `learning_rate` is `None`.
        """

        # Both arguments default to None only to keep the keyword-style
        # signature; the rest of this method cannot work without them.
        if loss is None:
            raise ValueError("`loss` must be a callable taking (graph, labels); got None")
        if learning_rate is None:
            raise ValueError("`learning_rate` must be provided; got None")

        # A default argument would be built once and shared (and mutated below)
        # by every trainer, so the default solver is created per call instead.
        if optimizer is None:
            optimizer = tf.train.AdamOptimizer()

        self.data_ph = self.train_data_shuffler("data")
        self.label_ph = self.train_data_shuffler("label")
        self.graph = graph
        self.loss = loss
        self.predictor = self.loss(self.graph, self.train_data_shuffler("label", from_queue=False))

        self.optimizer_class = optimizer
        self.learning_rate = learning_rate

        # TODO: find an elegant way to provide this as a parameter of the trainer
        self.global_step = tf.Variable(0, trainable=False, name="global_step")

        # Registering every element that `create_network_from_file` reads back
        tf.add_to_collection("global_step", self.global_step)

        tf.add_to_collection("graph", self.graph)
        tf.add_to_collection("predictor", self.predictor)

        tf.add_to_collection("data_ph", self.data_ph)
        tf.add_to_collection("label_ph", self.label_ph)

        # Preparing the optimizer
        # NOTE(review): this writes a private attribute of the solver; confirm
        # that the chosen optimizer class actually reads `_learning_rate`.
        self.optimizer_class._learning_rate = self.learning_rate
        self.optimizer = self.optimizer_class.minimize(self.predictor, global_step=self.global_step)
        tf.add_to_collection("optimizer", self.optimizer)
        tf.add_to_collection("learning_rate", self.learning_rate)

        self.summaries_train = self.create_general_summary()
        tf.add_to_collection("summaries_train", self.summaries_train)

        # `tf.add_to_collection` returns None, so its result must not be
        # assigned back to the attribute.
        # NOTE(review): the second `create_general_summary` call registers the
        # scalar summaries a second time - confirm this duplication is wanted.
        self.summaries_validation = self.create_general_summary()
        tf.add_to_collection("summaries_validation", self.summaries_validation)

        # Saving all the variables. The saver is built only now so that its
        # `var_list` also contains the variables created by `minimize` above.
        self.saver = tf.train.Saver(var_list=tf.global_variables())

        # Creating the variables
        tf.global_variables_initializer().run(session=self.session)

170
    def create_network_from_file(self, file_name, clear_devices=True):
        """
        Bootstrap a graph from a checkpoint

         ** Parameters **

           file_name: Name of the checkpoint; the graph definition is imported
                      from `file_name + ".meta"`

           clear_devices: Forwarded to `tf.train.import_meta_graph`
        """
        meta_graph_path = file_name + ".meta"
        self.saver = tf.train.import_meta_graph(meta_graph_path, clear_devices=clear_devices)
        self.saver.restore(self.session, file_name)

        # Every attribute below is read back from the collection that carries
        # the same name (first entry of that collection).
        restored_names = ("data_ph", "label_ph",
                          "graph", "predictor",
                          "optimizer", "learning_rate",
                          "summaries_train", "summaries_validation",
                          "global_step")
        for collection_name in restored_names:
            setattr(self, collection_name, tf.get_collection(collection_name)[0])

        # Tells `train` to resume from the restored global step
        self.from_scratch = False

    def __del__(self):
        # Resets TensorFlow's default graph when the trainer is destroyed.
        # NOTE(review): the default graph is not owned by this instance, so
        # this presumably affects any other code using it - confirm intended.
        tf.reset_default_graph()
198 199 200

    def get_feed_dict(self, data_shuffler):
        """
        Given a data shuffler, prepare the dictionary to be injected in the graph

        ** Parameters **

            data_shuffler: Data shuffler :py:class:`bob.learn.tensorflow.datashuffler.Base`

        ** Returns **

            A dictionary mapping `self.data_ph` and `self.label_ph` to the two
            elements returned by `data_shuffler.get_batch()`
        """
        batch_data, batch_labels = data_shuffler.get_batch()
        return {self.data_ph: batch_data, self.label_ph: batch_labels}

214
    def fit(self, step):
        """
        Run one iteration (`forward` and `backward`)

        The loss is logged and the merged training summary is written to
        `self.train_summary_writter`.

        ** Parameters **
            step: Iteration number

        """

        # With prefetching nothing is fed explicitly; otherwise a batch is
        # fetched from the shuffler and fed to the session.
        run_options = {}
        if not self.train_data_shuffler.prefetch:
            run_options["feed_dict"] = self.get_feed_dict(self.train_data_shuffler)

        fetches = [self.optimizer, self.predictor,
                   self.learning_rate, self.summaries_train]
        _, loss_value, _, summary = self.session.run(fetches, **run_options)

        logger.info("Loss training set step={0} = {1}".format(step, loss_value))
        self.train_summary_writter.add_summary(summary, step)
234

Tiago Pereira's avatar
Tiago Pereira committed
235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255
    def compute_validation(self, data_shuffler, step):
        """
        Placeholder for computing the loss in the validation set.

        The method currently does nothing: both arguments are ignored and
        `None` is returned.

        ** Parameters **
            data_shuffler: The data shuffler to be used
            step: Iteration number

        """
        # TODO: re-enable validation. The implementation that was disabled here
        # built a feed dict with `get_feed_dict(data_shuffler)`, ran the loss in
        # the session, wrote a "loss" summary to `validation_summary_writter`
        # and logged "Loss VALIDATION set step=... = ...".
        return None

256
    def create_general_summary(self):
        """
        Creates a simple tensorboard summary with the value of the loss and learning rate

        ** Returns **

            The result of `tf.summary.merge_all()`
        """

        # Train summary: one scalar per tracked value, then merged
        tracked_values = (('loss', self.predictor),
                          ('lr', self.learning_rate))
        for tag, tensor in tracked_values:
            tf.summary.scalar(tag, tensor)

        merged_summary = tf.summary.merge_all()
        return merged_summary
265

266
    def start_thread(self, n_threads=3):
        """
        Start pool of threads for pre-fetching

        Every thread runs `self.load_and_enqueue` and is started immediately.

        **Parameters**
          n_threads: Number of threads to start (default: 3)

        **Returns**
          The list with the started `threading.Thread` objects
        """

        threads = []
        for _ in range(n_threads):
            worker = threading.Thread(target=self.load_and_enqueue, args=())
            worker.daemon = True  # thread will close when parent quits
            worker.start()
            threads.append(worker)
        return threads
281

282
    def load_and_enqueue(self):
        """
        Injecting data in the place holder queue

        Loops until `self.thread_pool.should_stop()` is true; on every
        iteration a batch is taken from `self.train_data_shuffler` and fed to
        the enqueue operation through `self.session`.
        """

        # The original code read `self.inputs`, an attribute this class never
        # assigns. Keep honouring it when something else has set it, otherwise
        # fall back to the training shuffler.
        # NOTE(review): assumes the shuffler exposes `enqueue_op` - confirm.
        queue_owner = getattr(self, "inputs", self.train_data_shuffler)

        while not self.thread_pool.should_stop():
            [train_data, train_labels] = self.train_data_shuffler.get_batch()

            feed_dict = {self.data_ph: train_data,
                         self.label_ph: train_labels}

            self.session.run(queue_owner.enqueue_op, feed_dict=feed_dict)
297

Tiago Pereira's avatar
Tiago Pereira committed
298
    def train(self, validation_data_shuffler=None):
        """
        Train the network:

        Runs `fit` for every step from the starting step up to
        `self.iterations`, writes tensorboard summaries under `self.temp_dir`,
        saves a snapshot every `self.snapshot` steps and saves the final model
        as `model.ckp`.

        The summary writers are closed and the prefetching coordinator is asked
        to stop even when an iteration raises.

         ** Parameters **
           validation_data_shuffler: Data shuffler for validation
        """

        # Creating directory
        bob.io.base.create_directories_safe(self.temp_dir)

        logger.info("Initializing !!")

        # Loading a pretrained model: resume from its global step
        if self.from_scratch:
            start_step = 0
        else:
            start_step = self.global_step.eval(session=self.session)

        coordinator = None
        train_writer = None
        validation_writer = None
        try:
            # Start a thread to enqueue data asynchronously, and hide I/O latency.
            if self.train_data_shuffler.prefetch:
                coordinator = tf.train.Coordinator()
                self.thread_pool = coordinator
                tf.train.start_queue_runners(coord=self.thread_pool, sess=self.session)
                self.start_thread()

            # TENSOR BOARD SUMMARY
            train_writer = tf.summary.FileWriter(os.path.join(self.temp_dir, 'train'), self.session.graph)
            self.train_summary_writter = train_writer
            if validation_data_shuffler is not None:
                validation_writer = tf.summary.FileWriter(os.path.join(self.temp_dir, 'validation'),
                                                          self.session.graph)
                self.validation_summary_writter = validation_writer

            for step in range(start_step, self.iterations):
                # Run fit in the graph and record how long the step took
                start = time.time()
                self.fit(step)
                end = time.time()
                summary = summary_pb2.Summary.Value(tag="elapsed_time", simple_value=float(end-start))
                self.train_summary_writter.add_summary(summary_pb2.Summary(value=[summary]), step)

                # Running validation
                if validation_data_shuffler is not None and step % self.validation_snapshot == 0:
                    self.compute_validation(validation_data_shuffler, step)

                # Taking snapshot
                if step % self.snapshot == 0:
                    logger.info("Taking snapshot")
                    path = os.path.join(self.temp_dir, 'model_snapshot{0}.ckp'.format(step))
                    self.saver.save(self.session, path, global_step=step)

            logger.info("Training finally finished")

            # Saving the final network
            path = os.path.join(self.temp_dir, 'model.ckp')
            self.saver.save(self.session, path)
        finally:
            # Release the writers opened by this call, whatever happened above
            for writer in (train_writer, validation_writer):
                if writer is not None:
                    writer.close()

            if coordinator is not None:
                # now they should definitely stop
                # NOTE(review): the threads are only asked to stop, not joined;
                # the original code had the join disabled - confirm why.
                coordinator.request_stop()