Trainer.py 11.9 KB
Newer Older
1 2 3 4 5 6
#!/usr/bin/env python
# vim: set fileencoding=utf-8 :
# @author: Tiago de Freitas Pereira <tiago.pereira@idiap.ch>
# @date: Tue 09 Aug 2016 15:25:22 CEST

import tensorflow as tf
7 8 9
import threading
import os
import bob.io.base
10
import bob.core
11
from ..analyzers import SoftmaxAnalizer
12
from tensorflow.core.framework import summary_pb2
13
import time
14
from bob.learn.tensorflow.datashuffler import OnlineSampling
15
from bob.learn.tensorflow.utils.session import Session
16
from .learning_rate import constant
17

18 19 20 21 22
# Module logger. Each Trainer adjusts its verbosity in `Trainer.__init__`
# through `bob.core.log.set_verbosity_level`.
import logging

logger = logging.getLogger(name="bob.learn")

23

24 25 26
class Trainer(object):
    """
    One graph trainer.

    Use this trainer when your CNN is composed by one graph.

    **Parameters**

    train_data_shuffler:
      The data shuffler used for batching data for training. It is called as
      ``train_data_shuffler("data")`` / ``train_data_shuffler("label")`` to get
      the input tensors, and its ``get_batch()`` and ``prefetch`` members are used.

    iterations:
      Maximum number of iterations

    snapshot:
      Will take a snapshot of the network at every `n` iterations

    validation_snapshot:
      Test with validation each `n` iterations

    analizer:
      Neural network analizer :py:mod:`bob.learn.tensorflow.analyzers`
      (stored on the trainer; the training loop does not call it yet)

    temp_dir: str
      The output directory (checkpoints and tensorboard summaries go here)

    verbosity_level:
      Verbosity level applied to this module's logger through
      ``bob.core.log.set_verbosity_level``

    """

    def __init__(self,
                 train_data_shuffler,

                 ###### training options ##########
                 iterations=5000,
                 snapshot=500,
                 validation_snapshot=100,

                 ## Analizer
                 # NOTE: this default instance is created once, when the class is
                 # defined, and is therefore shared by every trainer using it.
                 analizer=SoftmaxAnalizer(),

                 # Temporary dir
                 temp_dir="cnn",

                 verbosity_level=2):

        self.train_data_shuffler = train_data_shuffler
        self.temp_dir = temp_dir

        self.iterations = iterations
        self.snapshot = snapshot
        self.validation_snapshot = validation_snapshot

        # Training variables used in the fit
        self.summaries_train = None
        self.train_summary_writter = None
        self.thread_pool = None

        # Validation data
        self.validation_summary_writter = None

        # Analizer
        self.analizer = analizer
        self.global_step = None

        self.session = None

        self.graph = None
        self.loss = None
        self.predictor = None
        self.optimizer_class = None
        self.learning_rate = None
        # Training variables used in the fit
        self.optimizer = None
        self.data_ph = None
        self.label_ph = None
        self.saver = None

        bob.core.log.set_verbosity_level(logger, verbosity_level)

        # Creating the session
        self.session = Session.instance(new=True).session
        self.from_scratch = True

    def create_network_from_scratch(self,
                                    graph,
                                    optimizer=None,
                                    loss=None,

                                    # Learning rate
                                    learning_rate=None,
                                    ):
        """
        Prepare all the tensorflow variables before training.

        **Parameters**

            graph: Input graph for training

            optimizer: Solver (a ``tf.train`` optimizer). When ``None``, a new
              ``tf.train.AdamOptimizer()`` is created for this call.

            loss: Loss function, called as ``loss(graph, labels)``; the tensor it
              returns is the one minimized. Required.

            learning_rate: Learning rate installed in the optimizer and logged
              to tensorboard. Required.

        **Raises**

            ValueError: if ``loss`` or ``learning_rate`` is ``None``.
        """
        if loss is None:
            raise ValueError("A loss function must be provided")
        if learning_rate is None:
            raise ValueError("A learning rate must be provided")
        if optimizer is None:
            # A fresh optimizer per call: its learning rate is patched below, so a
            # shared default instance would carry state between trainers.
            optimizer = tf.train.AdamOptimizer()

        self.data_ph = self.train_data_shuffler("data")
        self.label_ph = self.train_data_shuffler("label")
        self.graph = graph
        self.loss = loss
        self.predictor = self.loss(self.graph, self.train_data_shuffler("label", from_queue=False))

        self.optimizer_class = optimizer
        self.learning_rate = learning_rate

        # TODO: find an elegant way to provide this as a parameter of the trainer
        self.global_step = tf.Variable(0, trainable=False, name="global_step")

        # Preparing the optimizer.
        # NOTE: tf.train optimizers keep their rate in a private attribute whose
        # name depends on the class (e.g. GradientDescentOptimizer uses
        # `_learning_rate`, AdamOptimizer uses `_lr`), so both are overridden.
        self.optimizer_class._learning_rate = self.learning_rate
        if hasattr(self.optimizer_class, "_lr"):
            self.optimizer_class._lr = self.learning_rate
        self.optimizer = self.optimizer_class.minimize(self.predictor, global_step=self.global_step)

        self.summaries_train = self.create_general_summary()

        # Saving all the variables. The saver is built AFTER `minimize` because the
        # optimizer creates its own variables there (e.g. Adam moment slots); a saver
        # built earlier would leave them out of every checkpoint, and they would be
        # uninitialized after `create_network_from_file` restores one.
        self.saver = tf.train.Saver(var_list=tf.global_variables())

        # Registering the training elements so `create_network_from_file` can
        # fetch them back from the saved meta graph
        tf.add_to_collection("global_step", self.global_step)

        tf.add_to_collection("graph", self.graph)
        tf.add_to_collection("predictor", self.predictor)

        tf.add_to_collection("data_ph", self.data_ph)
        tf.add_to_collection("label_ph", self.label_ph)

        tf.add_to_collection("optimizer", self.optimizer)
        tf.add_to_collection("learning_rate", self.learning_rate)

        tf.add_to_collection("summaries_train", self.summaries_train)

        # Creating the variables
        tf.global_variables_initializer().run(session=self.session)

    def create_network_from_file(self, file_name, clear_devices=True):
        """
        Bootstrap a graph from a checkpoint.

        Imports the meta graph ``<file_name>.meta``, restores the variables into
        the trainer session and fetches back the training elements stored in the
        graph collections by `create_network_from_scratch`.

         ** Parameters **

           file_name: Name of the checkpoint (without the ``.meta`` suffix)

           clear_devices: Forwarded to ``tf.train.import_meta_graph``
        """
        self.saver = tf.train.import_meta_graph(file_name + ".meta", clear_devices=clear_devices)
        self.saver.restore(self.session, file_name)

        # Loading training graph
        self.data_ph = tf.get_collection("data_ph")[0]
        self.label_ph = tf.get_collection("label_ph")[0]

        self.graph = tf.get_collection("graph")[0]
        self.predictor = tf.get_collection("predictor")[0]

        # Loading other elements
        self.optimizer = tf.get_collection("optimizer")[0]
        self.learning_rate = tf.get_collection("learning_rate")[0]
        self.summaries_train = tf.get_collection("summaries_train")[0]
        self.global_step = tf.get_collection("global_step")[0]
        self.from_scratch = False

    def __del__(self):
        # Resets the process-wide tensorflow default graph when the trainer is
        # garbage collected; this also affects any other user of the default graph.
        tf.reset_default_graph()

    def get_feed_dict(self, data_shuffler):
        """
        Given a data shuffler, prepare the dictionary to be injected in the graph.

        ** Parameters **

            data_shuffler: Data shuffler :py:class:`bob.learn.tensorflow.datashuffler.Base`

        ** Returns **

            A dict mapping the trainer's data/label tensors to one batch
            obtained from ``data_shuffler.get_batch()``.
        """
        [data, labels] = data_shuffler.get_batch()

        feed_dict = {self.data_ph: data,
                     self.label_ph: labels}
        return feed_dict

    def fit(self, step):
        """
        Run one iteration (`forward` and `backward`), log the training loss and
        write the training summary.

        ** Parameters **
            step: Iteration number (used in the log and as the summary step)
        """
        # With prefetch the inputs are fed by the enqueue threads started in
        # `train`; otherwise one batch is fetched here and fed explicitly.
        if self.train_data_shuffler.prefetch:
            feed_dict = None
        else:
            feed_dict = self.get_feed_dict(self.train_data_shuffler)

        # The learning rate is not fetched: it was unused here and fetching a
        # plain Python number is rejected by `Session.run`.
        _, l, summary = self.session.run([self.optimizer, self.predictor, self.summaries_train],
                                         feed_dict=feed_dict)

        logger.info("Loss training set step={0} = {1}".format(step, l))
        self.train_summary_writter.add_summary(summary, step)

    def compute_validation(self, data_shuffler, step):
        """
        Computes the loss in the validation set.

        Validation is currently disabled: this method does nothing.

        ** Parameters **
            data_shuffler: The data shuffler to be used
            step: Iteration number

        """
        # TODO: run `self.predictor` on a batch from `data_shuffler` and write the
        # loss to `self.validation_summary_writter` under the "loss" tag.
        pass

    def create_general_summary(self):
        """
        Creates a simple tensorboard summary with the value of the loss and
        learning rate.

        Returns the op produced by ``tf.summary.merge_all()``, i.e. it merges
        every summary registered in the default graph, not only these two.
        """

        # Train summary
        tf.summary.scalar('loss', self.predictor)
        tf.summary.scalar('lr', self.learning_rate)
        return tf.summary.merge_all()

    def start_thread(self):
        """
        Start a pool of 3 daemon threads running `load_and_enqueue` for
        pre-fetching.

        **Returns**
          The list of started ``threading.Thread`` objects.
        """

        threads = []
        for _ in range(3):
            t = threading.Thread(target=self.load_and_enqueue, args=())
            t.daemon = True  # thread will close when parent quits
            t.start()
            threads.append(t)
        return threads

    def load_and_enqueue(self):
        """
        Inject data in the placeholder queue until the coordinator
        (`self.thread_pool`) requests a stop.
        """
        while not self.thread_pool.should_stop():
            [train_data, train_labels] = self.train_data_shuffler.get_batch()

            feed_dict = {self.data_ph: train_data,
                         self.label_ph: train_labels}

            # NOTE(review): this class never creates an enqueue op; it is assumed
            # to be provided by the prefetching data shuffler - TODO confirm
            # against bob.learn.tensorflow.datashuffler.
            self.session.run(self.train_data_shuffler.enqueue_op, feed_dict=feed_dict)

    def train(self, validation_data_shuffler=None):
        """
        Train the network.

        Runs `fit` from the current global step (0 when built from scratch) up
        to `iterations`. Tensorboard summaries go to ``<temp_dir>/train`` (and
        ``<temp_dir>/validation`` when a validation shuffler is given). A
        checkpoint is taken every `snapshot` steps, and the final model is saved
        to ``<temp_dir>/model.ckp``.

         ** Parameters **
           validation_data_shuffler: Data shuffler for validation
        """

        # Creating directory
        bob.io.base.create_directories_safe(self.temp_dir)

        logger.info("Initializing !!")

        # Loading a pretrained model: resume from the stored global step
        if self.from_scratch:
            start_step = 0
        else:
            start_step = self.global_step.eval(session=self.session)

        #if isinstance(train_data_shuffler, OnlineSampling):
        #    train_data_shuffler.set_feature_extractor(self.architecture, session=self.session)

        # Start a thread to enqueue data asynchronously, and hide I/O latency.
        if self.train_data_shuffler.prefetch:
            self.thread_pool = tf.train.Coordinator()
            tf.train.start_queue_runners(coord=self.thread_pool, sess=self.session)
            self.start_thread()

        # TENSOR BOARD SUMMARY
        self.train_summary_writter = tf.summary.FileWriter(os.path.join(self.temp_dir, 'train'), self.session.graph)
        if validation_data_shuffler is not None:
            self.validation_summary_writter = tf.summary.FileWriter(os.path.join(self.temp_dir, 'validation'),
                                                                    self.session.graph)

        try:
            for step in range(start_step, self.iterations):
                # Run fit in the graph, timing it for tensorboard
                start = time.time()
                self.fit(step)
                end = time.time()
                summary = summary_pb2.Summary.Value(tag="elapsed_time", simple_value=float(end - start))
                self.train_summary_writter.add_summary(summary_pb2.Summary(value=[summary]), step)

                # Running validation
                if validation_data_shuffler is not None and step % self.validation_snapshot == 0:
                    self.compute_validation(validation_data_shuffler, step)

                    #if self.analizer is not None:
                    #    self.validation_summary_writter.add_summary(self.analizer(
                    #         validation_data_shuffler, self.architecture, self.session), step)

                # Taking snapshot
                if step % self.snapshot == 0:
                    logger.info("Taking snapshot")
                    path = os.path.join(self.temp_dir, 'model_snapshot{0}.ckp'.format(step))
                    self.saver.save(self.session, path, global_step=step)

            logger.info("Training finally finished")

            # Saving the final network
            path = os.path.join(self.temp_dir, 'model.ckp')
            self.saver.save(self.session, path)
        finally:
            # Always release the event files and stop the prefetch threads, even
            # when training is interrupted by an exception.
            self.train_summary_writter.close()
            if validation_data_shuffler is not None:
                self.validation_summary_writter.close()

            if self.train_data_shuffler.prefetch:
                # now they should definitely stop
                self.thread_pool.request_stop()