Trainer.py 12.2 KB
Newer Older
1 2 3 4 5 6 7
#!/usr/bin/env python
# vim: set fileencoding=utf-8 :
# @author: Tiago de Freitas Pereira <tiago.pereira@idiap.ch>
# @date: Tue 09 Aug 2016 15:25:22 CEST

import tensorflow as tf
from ..network import SequenceNetwork
8 9 10
import threading
import os
import bob.io.base
11
import bob.core
12
from ..analyzers import SoftmaxAnalizer
13
from tensorflow.core.framework import summary_pb2
14
import time
15
from bob.learn.tensorflow.datashuffler import OnlineSampling
16
from bob.learn.tensorflow.utils.session import Session
17
from .learning_rate import constant
18

19 20 21 22 23
#logger = bob.core.log.setup("bob.learn.tensorflow")

import logging
logger = logging.getLogger("bob.learn")

24

25 26 27 28 29 30
class Trainer(object):
    """
    One graph trainer.
    Use this trainer when your CNN is composed of a single graph.

    **Parameters**

    architecture:
      The architecture that you want to run. Should be a :py:class:`bob.learn.tensorflow.network.SequenceNetwork`

    optimizer:
      One of the tensorflow optimizers https://www.tensorflow.org/versions/r0.10/api_docs/python/train.html

    use_gpu: bool
      Use GPUs in the training

    loss: :py:class:`bob.learn.tensorflow.loss.BaseLoss`
      Loss function

    temp_dir: str
      The output directory

    learning_rate: `bob.learn.tensorflow.trainers.learning_rate`
      Initial learning rate

    convergence_threshold:

    iterations: int
      Maximum number of iterations

    snapshot: int
      Will take a snapshot of the network at every `n` iterations

    prefetch: bool
      Use extra threads to deal with the I/O

    model_from_file: str
      If you want to use a pretrained model

    analizer:
      Neural network analyzer :py:mod:`bob.learn.tensorflow.analyzers`

    verbosity_level:

    """
Tiago de Freitas Pereira's avatar
Tiago de Freitas Pereira committed
70

71
    def __init__(self,
                 train_data_shuffler,

                 ###### training options ##########
                 iterations=5000,
                 snapshot=500,
                 validation_snapshot=100,

                 ## Analizer
                 analizer=SoftmaxAnalizer(),

                 # Temporary dir
                 temp_dir="cnn",

                 verbosity_level=2):
        """
        Store the training options and open a new session.

        The graph itself is not built here; call `create_network_from_scratch`
        or `create_network_from_file` before `train`.

        ** Parameters **
            train_data_shuffler: Data shuffler that provides the training batches
            iterations: Upper bound (exclusive) of the step counter used by `train`
            snapshot: A checkpoint is saved every `snapshot` steps
            validation_snapshot: Validation is triggered every `validation_snapshot` steps
            analizer: Analyzer object; only stored on the instance here
            temp_dir: Directory where summaries and checkpoints are written
            verbosity_level: Forwarded to `bob.core.log.set_verbosity_level`
        """
        # NOTE: the default `analizer` instance is created once, when this
        # method is defined, and is therefore shared by every trainer that
        # does not pass its own.

        # Options given by the caller
        self.train_data_shuffler = train_data_shuffler
        self.temp_dir = temp_dir
        self.iterations = iterations
        self.snapshot = snapshot
        self.validation_snapshot = validation_snapshot
        self.analizer = analizer

        # Summary writers and prefetch coordinator, created by `train`
        self.summaries_train = None
        self.train_summary_writter = None
        self.validation_summary_writter = None
        self.thread_pool = None

        # Graph elements, filled by `create_network_from_scratch` or
        # `create_network_from_file`
        self.global_step = None
        self.graph = None
        self.loss = None
        self.predictor = None
        self.optimizer_class = None
        self.learning_rate = None
        self.optimizer = None
        self.data_ph = None
        self.label_ph = None
        self.saver = None

        bob.core.log.set_verbosity_level(logger, verbosity_level)

        # Creating the session
        self.session = Session.instance(new=True).session

        # Flipped to False by `create_network_from_file`, so that `train`
        # resumes from the restored global step
        self.from_scratch = True
Tiago de Freitas Pereira's avatar
Tiago de Freitas Pereira committed
124

Tiago Pereira's avatar
Tiago Pereira committed
125 126 127 128
    def create_network_from_scratch(self,
                                    graph,
                                    optimizer=None,
                                    loss=None,
                                    # Learning rate
                                    learning_rate=None,
                                    ):
        """
        Build the training graph elements and initialize all the variables.

        ** Parameters **
            graph: First argument handed to `loss` (presumably the output
              tensor of the network -- TODO confirm)
            optimizer: A tensorflow optimizer instance. When None, a fresh
              `tf.train.AdamOptimizer()` is created for this call
            loss: Callable invoked as `loss(graph, labels)`; its result is
              stored in `self.predictor` and is what the optimizer minimizes
            learning_rate: Value assigned to the optimizer's `_learning_rate`
              and logged as the 'lr' summary (presumably a scalar tensor from
              `bob.learn.tensorflow.trainers.learning_rate` -- TODO confirm)

        ** Raises **
            TypeError: if `loss` is None
        """
        # A default built in the signature would be created only once and
        # shared by every call, and its `_learning_rate` is overwritten below.
        if optimizer is None:
            optimizer = tf.train.AdamOptimizer()

        if loss is None:
            raise TypeError("`loss` must be a callable taking (graph, labels); got None")

        self.data_ph = self.train_data_shuffler("data")
        self.label_ph = self.train_data_shuffler("label")
        self.graph = graph
        self.loss = loss
        self.predictor = self.loss(self.graph, self.train_data_shuffler("label", from_queue=True))

        self.optimizer_class = optimizer
        self.learning_rate = learning_rate

        # TODO: find an elegant way to provide this as a parameter of the trainer
        self.global_step = tf.Variable(0, trainable=False, name="global_step")

        # Preparing the optimizer
        self.optimizer_class._learning_rate = self.learning_rate
        self.optimizer = self.optimizer_class.minimize(self.predictor, global_step=self.global_step)

        # Saving all the variables. The saver is built only after `minimize`,
        # because the variable list is captured at this point and the
        # optimizer's own variables must be part of the checkpoint for a
        # restored model to keep training.
        self.saver = tf.train.Saver(var_list=tf.global_variables())

        self.summaries_train = self.create_general_summary()

        # Register every element under its own name so that
        # `create_network_from_file` can fetch it back from the meta graph
        tf.add_to_collection("global_step", self.global_step)
        tf.add_to_collection("graph", self.graph)
        tf.add_to_collection("predictor", self.predictor)
        tf.add_to_collection("data_ph", self.data_ph)
        tf.add_to_collection("label_ph", self.label_ph)
        tf.add_to_collection("optimizer", self.optimizer)
        tf.add_to_collection("learning_rate", self.learning_rate)
        tf.add_to_collection("summaries_train", self.summaries_train)

        # Creating the variables
        tf.global_variables_initializer().run(session=self.session)

    def create_network_from_file(self, model_from_file):
        """
        Bootstrap all the necessary data from file

         ** Parameters **
           model_from_file: Path of the checkpoint to restore; the meta graph
             is read from `model_from_file + ".meta"`

        """
        self.saver = tf.train.import_meta_graph(model_from_file + ".meta")
        self.saver.restore(self.session, model_from_file)

        # Each element is looked up in the collection that carries the same
        # name as the attribute it is assigned to; the first entry is used.
        restored_names = ("data_ph", "label_ph",
                          "graph", "predictor",
                          "optimizer", "learning_rate",
                          "summaries_train", "global_step")
        for name in restored_names:
            setattr(self, name, tf.get_collection(name)[0])

        # `train` will start from the restored global step instead of 0
        self.from_scratch = False

    def __del__(self):
        """
        Reset the tensorflow default graph when the trainer is destroyed.
        """
        # NOTE(review): this resets the global default graph, not only what
        # this trainer created -- confirm that no other live object relies on it.
        tf.reset_default_graph()
200 201 202

    def get_feed_dict(self, data_shuffler):
        """
        Given a data shuffler, prepare the dictionary to be injected in the graph

        ** Parameters **
            data_shuffler: Object whose `get_batch()` returns a pair (data, labels)

        ** Returns **
            Dictionary mapping `self.data_ph` to the data and `self.label_ph`
            to the labels of one batch
        """
        data, labels = data_shuffler.get_batch()
        return {self.data_ph: data,
                self.label_ph: labels}

215
    def fit(self, step):
        """
        Run one iteration (`forward` and `backward`)

        ** Parameters **
            step: Iteration number; used in the log message and as the step
              of the summary written to `self.train_summary_writter`

        """
        # When prefetching, batches are pushed by the worker threads running
        # `load_and_enqueue`, so nothing is fed here.
        feed_dict = None
        if not self.train_data_shuffler.prefetch:
            feed_dict = self.get_feed_dict(self.train_data_shuffler)

        fetches = [self.optimizer, self.predictor,
                   self.learning_rate, self.summaries_train]
        _, loss_value, _, summary = self.session.run(fetches, feed_dict=feed_dict)

        logger.info("Loss training set step={0} = {1}".format(step, loss_value))
        self.train_summary_writter.add_summary(summary, step)
235

Tiago Pereira's avatar
Tiago Pereira committed
236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256
    def compute_validation(self, data_shuffler, step):
        """
        Placeholder for computing the loss in the validation set.

        Currently does nothing and returns None.

        ** Parameters **
            data_shuffler: The data shuffler to be used
            step: Iteration number

        """
        # TODO: not implemented yet. The previous draft fed one batch of
        # `data_shuffler` through `self.predictor` and wrote the result as a
        # "loss" summary to `self.validation_summary_writter` at `step`.
        return None

257
    def create_general_summary(self):
        """
        Creates a simple tensorboard summary with the value of the loss and learning rate

        ** Returns **
            The result of `tf.summary.merge_all()`. Note that this merges every
            summary collected in the default graph, not only the two scalars
            registered here.
        """

        # Train summary
        tf.summary.scalar('loss', self.predictor)
        tf.summary.scalar('lr', self.learning_rate)
        return tf.summary.merge_all()
266

267
    def start_thread(self, n_threads=3):
        """
        Start pool of threads for pre-fetching

        **Parameters**
          n_threads: Number of worker threads to start (default 3)

        **Returns**
          List with the started `threading.Thread` objects, each one running
          `self.load_and_enqueue`
        """
        threads = []
        for _ in range(n_threads):
            worker = threading.Thread(target=self.load_and_enqueue)
            worker.daemon = True  # thread will close when parent quits
            worker.start()
            threads.append(worker)
        return threads
282

283
    def load_and_enqueue(self):
        """
        Injecting data in the place holder queue

        Keeps fetching batches from `self.train_data_shuffler` and running the
        enqueue operation until `self.thread_pool.should_stop()` is true.
        Meant to be the target of the threads created by `start_thread`.
        """
        while not self.thread_pool.should_stop():
            train_data, train_labels = self.train_data_shuffler.get_batch()

            feed_dict = {self.data_ph: train_data,
                         self.label_ph: train_labels}

            # NOTE(review): `self.inputs` is not assigned anywhere in the
            # visible part of this class -- confirm where the enqueue op is
            # supposed to come from.
            self.session.run(self.inputs.enqueue_op, feed_dict=feed_dict)
298

Tiago Pereira's avatar
Tiago Pereira committed
299
    def train(self, validation_data_shuffler=None):
        """
        Train the network:

         ** Parameters **

           validation_data_shuffler: Data shuffler for validation. When None,
             no validation writer is created and `compute_validation` is
             never called.

        Summaries are written under `temp_dir`/train (and `temp_dir`/validation),
        a checkpoint is saved every `snapshot` steps and the final model is
        saved as `temp_dir`/model.ckp.
        """

        # Creating directory
        bob.io.base.create_directories_safe(self.temp_dir)

        logger.info("Initializing !!")

        # A model restored from file resumes at its saved global step
        if self.from_scratch:
            start_step = 0
        else:
            start_step = self.global_step.eval(session=self.session)

        #if isinstance(train_data_shuffler, OnlineSampling):
        #    train_data_shuffler.set_feature_extractor(self.architecture, session=self.session)

        # Start a thread to enqueue data asynchronously, and hide I/O latency.
        prefetch = self.train_data_shuffler.prefetch
        if prefetch:
            self.thread_pool = tf.train.Coordinator()
            tf.train.start_queue_runners(coord=self.thread_pool, sess=self.session)
            self.start_thread()

        # TENSOR BOARD SUMMARY
        self.train_summary_writter = tf.summary.FileWriter(os.path.join(self.temp_dir, 'train'), self.session.graph)
        if validation_data_shuffler is not None:
            self.validation_summary_writter = tf.summary.FileWriter(os.path.join(self.temp_dir, 'validation'),
                                                                    self.session.graph)

        # The finally clause makes sure that the writers are closed and the
        # prefetch workers are asked to stop even if an iteration raises.
        try:
            for step in range(start_step, self.iterations):
                # Run fit in the graph and record how long it took
                start = time.time()
                self.fit(step)
                end = time.time()
                summary = summary_pb2.Summary.Value(tag="elapsed_time", simple_value=float(end-start))
                self.train_summary_writter.add_summary(summary_pb2.Summary(value=[summary]), step)

                # Running validation
                if validation_data_shuffler is not None and step % self.validation_snapshot == 0:
                    self.compute_validation(validation_data_shuffler, step)

                    #if self.analizer is not None:
                    #    self.validation_summary_writter.add_summary(self.analizer(
                    #         validation_data_shuffler, self.architecture, self.session), step)

                # Taking snapshot
                if step % self.snapshot == 0:
                    logger.info("Taking snapshot")
                    path = os.path.join(self.temp_dir, 'model_snapshot{0}.ckp'.format(step))
                    self.saver.save(self.session, path, global_step=step)

            logger.info("Training finally finished")

            # Saving the final network
            path = os.path.join(self.temp_dir, 'model.ckp')
            self.saver.save(self.session, path)
        finally:
            self.train_summary_writter.close()
            if validation_data_shuffler is not None:
                self.validation_summary_writter.close()

            if prefetch:
                # now they should definitely stop
                self.thread_pool.request_stop()
                # The workers are not joined here (the join was already
                # disabled); presumably they may still be blocked inside an
                # enqueue call -- TODO confirm before re-enabling a join.