Trainer.py 12.5 KB
Newer Older
1 2 3 4 5 6
#!/usr/bin/env python
# vim: set fileencoding=utf-8 :
# @author: Tiago de Freitas Pereira <tiago.pereira@idiap.ch>
# @date: Tue 09 Aug 2016 15:25:22 CEST

import tensorflow as tf
7 8 9
import threading
import os
import bob.io.base
10
import bob.core
11
from ..analyzers import SoftmaxAnalizer
12
from tensorflow.core.framework import summary_pb2
13
import time
14
from bob.learn.tensorflow.datashuffler import OnlineSampling
15
from bob.learn.tensorflow.utils.session import Session
16
from .learning_rate import constant
17
import time
18

19 20 21 22 23
#logger = bob.core.log.setup("bob.learn.tensorflow")

import logging
logger = logging.getLogger("bob.learn")

24

25 26 27
class Trainer(object):
    """
    One graph trainer.

    Use this trainer when your CNN is composed by one graph

    **Parameters**

    train_data_shuffler:
      The data shuffler used for batching data for training

    iterations:
      Maximum number of iterations

    snapshot:
      Will take a snapshot of the network at every `n` iterations

    validation_snapshot:
      Test with validation each `n` iterations

    analizer:
      Neural network analizer :py:mod:`bob.learn.tensorflow.analyzers`.
      When omitted, a new `SoftmaxAnalizer` is created for this trainer.

    temp_dir: str
      The output directory

    verbosity_level:
      Verbosity level forwarded to `bob.core.log.set_verbosity_level`
      for this module's logger

    """

    # Private sentinel marking "analizer not provided", so that an explicit
    # `analizer=None` is still stored as `None`.
    _DEFAULT_ANALIZER = object()

    # Seconds to wait after starting the prefetch threads. As suggested in
    # https://stackoverflow.com/questions/39840323/benchmark-of-howto-reading-data/39842628#39842628
    _PREFETCH_WARMUP_SECONDS = 20

    def __init__(self,
                 train_data_shuffler,

                 # Training options
                 iterations=5000,
                 snapshot=500,
                 validation_snapshot=100,

                 # Analizer
                 analizer=_DEFAULT_ANALIZER,

                 # Temporary dir
                 temp_dir="cnn",

                 verbosity_level=2):

        # Build the default analizer per instance instead of sharing one
        # object created when the class was defined.
        if analizer is Trainer._DEFAULT_ANALIZER:
            analizer = SoftmaxAnalizer()

        self.train_data_shuffler = train_data_shuffler
        self.temp_dir = temp_dir

        self.iterations = iterations
        self.snapshot = snapshot
        self.validation_snapshot = validation_snapshot

        # Training variables used in the fit
        self.summaries_train = None
        self.train_summary_writter = None
        self.thread_pool = None

        # Validation data
        self.validation_summary_writter = None
        self.summaries_validation = None

        # Analizer
        self.analizer = analizer
        self.global_step = None

        # Graph elements, filled by `create_network_from_scratch` or
        # `create_network_from_file`
        self.graph = None
        self.loss = None
        self.predictor = None
        self.optimizer_class = None
        self.learning_rate = None
        self.optimizer = None
        self.data_ph = None
        self.label_ph = None
        self.saver = None

        bob.core.log.set_verbosity_level(logger, verbosity_level)

        # Creating the session
        self.session = Session.instance(new=True).session
        self.from_scratch = True

    def create_network_from_scratch(self,
                                    graph,
                                    optimizer=None,
                                    loss=None,

                                    # Learning rate
                                    learning_rate=None,
                                    ):
        """
        Prepare all the tensorflow variables before training.

        **Parameters**

            graph: Input graph for training

            optimizer: Solver. When omitted, a new `tf.train.AdamOptimizer()`
                       is created for this call.

            loss: Loss function; called as `loss(graph, labels)`

            learning_rate: Learning rate. It is assigned to the optimizer,
                           logged as the `lr` summary and stored in the
                           `learning_rate` collection.

        **Raises**

            TypeError: if `loss` is not callable
        """

        if not callable(loss):
            raise TypeError("`loss` must be a callable taking (graph, labels); "
                            "got {0!r}".format(loss))

        # A fresh optimizer per call: its learning rate is overwritten below,
        # so a shared default instance would leak state between trainers.
        if optimizer is None:
            optimizer = tf.train.AdamOptimizer()

        self.data_ph = self.train_data_shuffler("data", from_queue=True)
        self.label_ph = self.train_data_shuffler("label", from_queue=True)
        self.graph = graph
        self.loss = loss
        self.predictor = self.loss(self.graph, self.label_ph)

        self.optimizer_class = optimizer
        self.learning_rate = learning_rate

        # TODO: find an elegant way to provide this as a parameter of the trainer
        self.global_step = tf.Variable(0, trainable=False, name="global_step")

        # Saving all the variables
        # NOTE(review): the saver is created before `minimize`, so variables
        # created by the optimizer are presumably not in `var_list` - confirm
        # this is intended.
        self.saver = tf.train.Saver(var_list=tf.global_variables())

        tf.add_to_collection("global_step", self.global_step)

        tf.add_to_collection("graph", self.graph)
        tf.add_to_collection("predictor", self.predictor)

        tf.add_to_collection("data_ph", self.data_ph)
        tf.add_to_collection("label_ph", self.label_ph)

        # Preparing the optimizer
        self.optimizer_class._learning_rate = self.learning_rate
        self.optimizer = self.optimizer_class.minimize(self.predictor, global_step=self.global_step)
        tf.add_to_collection("optimizer", self.optimizer)
        tf.add_to_collection("learning_rate", self.learning_rate)

        self.summaries_train = self.create_general_summary()
        tf.add_to_collection("summaries_train", self.summaries_train)

        # `tf.add_to_collection` returns nothing, so its result must not be
        # assigned back to the attribute.
        self.summaries_validation = self.create_general_summary()
        tf.add_to_collection("summaries_validation", self.summaries_validation)

        # Creating the variables
        tf.global_variables_initializer().run(session=self.session)

    def create_network_from_file(self, file_name, clear_devices=True):
        """
        Bootstrap a graph from a checkpoint

         ** Parameters **

           file_name: Name of the checkpoint (the meta graph is read from
                      `file_name + ".meta"`)

           clear_devices: Forwarded to `tf.train.import_meta_graph`
        """
        self.saver = tf.train.import_meta_graph(file_name + ".meta", clear_devices=clear_devices)
        self.saver.restore(self.session, file_name)

        # Loading training graph
        self.data_ph = tf.get_collection("data_ph")[0]
        self.label_ph = tf.get_collection("label_ph")[0]

        self.graph = tf.get_collection("graph")[0]
        self.predictor = tf.get_collection("predictor")[0]

        # Loading other elements
        self.optimizer = tf.get_collection("optimizer")[0]
        self.learning_rate = tf.get_collection("learning_rate")[0]
        self.summaries_train = tf.get_collection("summaries_train")[0]
        self.summaries_validation = tf.get_collection("summaries_validation")[0]
        self.global_step = tf.get_collection("global_step")[0]

        # `train` will resume from the restored global step
        self.from_scratch = False

    def __del__(self):
        """Reset the tensorflow default graph when the trainer is destroyed."""
        tf.reset_default_graph()

    def get_feed_dict(self, data_shuffler):
        """
        Given a data shuffler prepared the dictionary to be injected in the graph

        ** Parameters **

            data_shuffler: Data shuffler :py:class:`bob.learn.tensorflow.datashuffler.Base`

        ** Returns **

            Dictionary mapping the data and label placeholders of this trainer
            to one batch fetched with `data_shuffler.get_batch()`
        """
        [data, labels] = data_shuffler.get_batch()

        return {self.data_ph: data,
                self.label_ph: labels}

    def fit(self, step):
        """
        Run one iteration (`forward` and `backward`)

        The training summary writer must already exist (it is created by
        `train`).

        ** Parameters **
            step: Iteration number

        """
        fetches = [self.optimizer, self.predictor,
                   self.learning_rate, self.summaries_train]

        if self.train_data_shuffler.prefetch:
            # Data comes from the queue filled by the prefetch threads
            _, loss_value, _, summary = self.session.run(fetches)
        else:
            feed_dict = self.get_feed_dict(self.train_data_shuffler)
            _, loss_value, _, summary = self.session.run(fetches, feed_dict=feed_dict)

        logger.info("Loss training set step={0} = {1}".format(step, loss_value))
        self.train_summary_writter.add_summary(summary, step)

    def compute_validation(self, data_shuffler, step):
        """
        Computes the loss in the validation set

        Currently a placeholder: the method does nothing.

        ** Parameters **
            data_shuffler: The data shuffler to be used
            step: Iteration number

        """
        # TODO: run `self.predictor` and the validation summaries on a batch of
        # `data_shuffler` and write them with `self.validation_summary_writter`.
        pass

    def create_general_summary(self):
        """
        Creates a simple tensorboard summary with the value of the loss and learning rate

        ** Returns **

            The result of `tf.summary.merge_all()` after registering the
            `loss` and `lr` scalar summaries
        """

        # Train summary
        tf.summary.scalar('loss', self.predictor)
        tf.summary.scalar('lr', self.learning_rate)
        return tf.summary.merge_all()

    def start_thread(self):
        """
        Start pool of threads for pre-fetching

        One daemon thread running `load_and_enqueue` is started for each of
        `train_data_shuffler.prefetch_threads`.

        ** Returns **

            List with the started threads
        """

        threads = []
        for _ in range(self.train_data_shuffler.prefetch_threads):
            t = threading.Thread(target=self.load_and_enqueue, args=())
            t.daemon = True  # thread will close when parent quits
            t.start()
            threads.append(t)
        return threads

    def load_and_enqueue(self):
        """
        Injecting data in the place holder queue

        Loops until `self.thread_pool.should_stop()` is true; each iteration
        fetches one batch from the training shuffler and runs its enqueue
        operation.
        """
        while not self.thread_pool.should_stop():
            [train_data, train_labels] = self.train_data_shuffler.get_batch()

            data_ph = self.train_data_shuffler("data", from_queue=False)
            label_ph = self.train_data_shuffler("label", from_queue=False)

            feed_dict = {data_ph: train_data,
                         label_ph: train_labels}

            self.session.run(self.train_data_shuffler.enqueue_op, feed_dict=feed_dict)

    def train(self, validation_data_shuffler=None):
        """
        Train the network:

        Summaries are written under `temp_dir`, a snapshot is saved every
        `snapshot` iterations and the final model is saved as `model.ckp`.
        The summary writers are closed and the prefetch threads are asked to
        stop even when an iteration raises.

         ** Parameters **
           validation_data_shuffler: Data shuffler for validation
        """

        # Creating directory
        bob.io.base.create_directories_safe(self.temp_dir)

        logger.info("Initializing !!")

        # Resuming from the stored step when bootstrapped from a checkpoint
        if self.from_scratch:
            start_step = 0
        else:
            start_step = self.global_step.eval(session=self.session)

        # TODO: support OnlineSampling shufflers (set their feature extractor)

        # Start a thread to enqueue data asynchronously, and hide I/O latency.
        prefetch = self.train_data_shuffler.prefetch
        if prefetch:
            self.thread_pool = tf.train.Coordinator()
            tf.train.start_queue_runners(coord=self.thread_pool, sess=self.session)
            self.start_thread()
            time.sleep(self._PREFETCH_WARMUP_SECONDS)

        train_writer = None
        validation_writer = None
        try:
            # TENSOR BOARD SUMMARY
            train_writer = tf.summary.FileWriter(os.path.join(self.temp_dir, 'train'), self.session.graph)
            self.train_summary_writter = train_writer
            if validation_data_shuffler is not None:
                validation_writer = tf.summary.FileWriter(os.path.join(self.temp_dir, 'validation'),
                                                          self.session.graph)
                self.validation_summary_writter = validation_writer

            for step in range(start_step, self.iterations):
                # Run fit in the graph
                start = time.time()
                self.fit(step)
                end = time.time()
                summary = summary_pb2.Summary.Value(tag="elapsed_time", simple_value=float(end-start))
                train_writer.add_summary(summary_pb2.Summary(value=[summary]), step)

                # Running validation
                if validation_data_shuffler is not None and step % self.validation_snapshot == 0:
                    self.compute_validation(validation_data_shuffler, step)
                    # TODO: also write the analizer summary for the validation set

                # Taking snapshot
                if step % self.snapshot == 0:
                    logger.info("Taking snapshot")
                    path = os.path.join(self.temp_dir, 'model_snapshot{0}.ckp'.format(step))
                    self.saver.save(self.session, path, global_step=step)

            logger.info("Training finally finished")

            # Saving the final network
            path = os.path.join(self.temp_dir, 'model.ckp')
            self.saver.save(self.session, path)
        finally:
            if prefetch:
                # now they should definitely stop
                self.thread_pool.request_stop()
            if train_writer is not None:
                train_writer.close()
            if validation_writer is not None:
                validation_writer.close()