Trainer.py 12.3 KB
Newer Older
1 2 3 4 5 6 7
#!/usr/bin/env python
# vim: set fileencoding=utf-8 :
# @author: Tiago de Freitas Pereira <tiago.pereira@idiap.ch>
# @date: Tue 09 Aug 2016 15:25:22 CEST

import tensorflow as tf
from ..network import SequenceNetwork
8 9 10
import threading
import os
import bob.io.base
11
import bob.core
12
from ..analyzers import SoftmaxAnalizer
13
from tensorflow.core.framework import summary_pb2
14
import time
15
from bob.learn.tensorflow.datashuffler import OnlineSampling
16
from bob.learn.tensorflow.utils.session import Session
17
from .learning_rate import constant
18

19 20 21 22 23
# Earlier logger setup through bob.core, kept for reference:
#logger = bob.core.log.setup("bob.learn.tensorflow")

# Module-level logger used by the trainer; its verbosity is adjusted in `Trainer.__init__`
import logging
logger = logging.getLogger("bob.learn")

24

25 26 27 28 29 30
class Trainer(object):
    """
    One graph trainer.
    Use this trainer when your CNN is composed of one graph

    **Parameters**

    architecture:
      The architecture that you want to run. Should be a :py:class:`bob.learn.tensorflow.network.SequenceNetwork`

    optimizer:
      One of the tensorflow optimizers https://www.tensorflow.org/versions/r0.10/api_docs/python/train.html

    use_gpu: bool
      Use GPUs in the training

    loss: :py:class:`bob.learn.tensorflow.loss.BaseLoss`
      Loss function

    temp_dir: str
      The output directory

    learning_rate: `bob.learn.tensorflow.trainers.learning_rate`
      Initial learning rate

    convergence_threshold:

    iterations: int
      Maximum number of iterations

    snapshot: int
      Will take a snapshot of the network at every `n` iterations

    prefetch: bool
      Use extra Threads to deal with the I/O

    model_from_file: str
      If you want to use a pretrained model

    analizer:
      Neural network analyzer :py:mod:`bob.learn.tensorflow.analyzers`

    verbosity_level:

    """
Tiago de Freitas Pereira's avatar
Tiago de Freitas Pereira committed
70

71
    def __init__(self, train_data_shuffler, iterations=5000, snapshot=500,
                 validation_snapshot=100, analizer=SoftmaxAnalizer(),
                 temp_dir="cnn", verbosity_level=2):
        """
        Stores the training options and opens a new tensorflow session.

        **Parameters**

        train_data_shuffler:
          Data shuffler used to fetch the training batches

        iterations: int
          Upper bound of the step range used by `train`

        snapshot: int
          A snapshot of the model is saved every `snapshot` steps

        validation_snapshot: int
          Validation is triggered every `validation_snapshot` steps

        analizer:
          Analyzer object kept in `self.analizer`

        temp_dir: str
          Output directory for summaries and checkpoints

        verbosity_level:
          Verbosity level applied to this module's logger
        """

        # Options given by the caller
        self.train_data_shuffler = train_data_shuffler
        self.iterations = iterations
        self.snapshot = snapshot
        self.validation_snapshot = validation_snapshot
        self.analizer = analizer
        self.temp_dir = temp_dir

        # Summary writers and the prefetching coordinator; assigned in `train`
        self.train_summary_writter = self.validation_summary_writter = None
        self.thread_pool = None

        # Graph related members; assigned by `create_network_from_scratch`
        # or `create_network_from_file`
        self.graph = self.loss = self.predictor = None
        self.data_ph = self.label_ph = None
        self.optimizer_class = self.optimizer = None
        self.learning_rate = self.global_step = None
        self.summaries_train = None
        self.saver = None
        self.session = None

        bob.core.log.set_verbosity_level(logger, verbosity_level)

        # A new session is requested for every trainer
        self.session = Session.instance(new=True).session
        # Switched to False by `create_network_from_file`
        self.from_scratch = True
Tiago de Freitas Pereira's avatar
Tiago de Freitas Pereira committed
124

Tiago Pereira's avatar
Tiago Pereira committed
125 126 127 128
    def create_network_from_scratch(self,
                                    graph,
                                    optimizer=None,
                                    loss=None,

                                    # Learning rate
                                    learning_rate=None,
                                    ):
        """
        Prepares all the tensorflow elements needed to train a new network.

        The relevant tensors and operations are also registered in graph
        collections ("data_ph", "label_ph", "graph", "predictor", "optimizer",
        "learning_rate", "summaries_train" and "global_step") so that
        `create_network_from_file` can fetch them back from a saved model.

        ** Parameters **
           graph: Output of the network; handed to `loss` together with the label placeholder
           optimizer: Optimizer object exposing `minimize`. When omitted, a new
             `tf.train.AdamOptimizer` is created for this call
           loss: Callable invoked as `loss(graph, labels)`; its result is the tensor to be minimized
           learning_rate: Learning rate stored in the trainer, logged in the
             summaries and assigned to the optimizer

        ** Raises **
           TypeError: if `loss` is not given
        """

        if loss is None:
            raise TypeError("`loss` is required to create a network from scratch")

        # A default optimizer instance in the signature would be shared by every
        # trainer and mutated below, so a fresh one is created per call instead
        if optimizer is None:
            optimizer = tf.train.AdamOptimizer()

        self.data_ph = self.train_data_shuffler("data")
        self.label_ph = self.train_data_shuffler("label")
        self.graph = graph
        self.loss = loss
        self.predictor = self.loss(self.graph, self.train_data_shuffler("label", from_queue=False))

        self.optimizer_class = optimizer
        self.learning_rate = learning_rate

        # TODO: find an elegant way to provide this as a parameter of the trainer
        self.global_step = tf.Variable(0, trainable=False, name="global_step")

        tf.add_to_collection("global_step", self.global_step)

        tf.add_to_collection("graph", self.graph)
        tf.add_to_collection("predictor", self.predictor)

        tf.add_to_collection("data_ph", self.data_ph)
        tf.add_to_collection("label_ph", self.label_ph)

        # Preparing the optimizer
        # NOTE(review): this writes a private attribute; optimizers that keep their
        # rate under a different attribute name would ignore it - confirm for the
        # optimizers in use
        self.optimizer_class._learning_rate = self.learning_rate
        self.optimizer = self.optimizer_class.minimize(self.predictor, global_step=self.global_step)
        tf.add_to_collection("optimizer", self.optimizer)
        tf.add_to_collection("learning_rate", self.learning_rate)

        # Saving all the variables. The saver is built only after `minimize`, so the
        # variables created by the optimizer are part of `tf.global_variables()` and
        # therefore end up in the checkpoints too
        self.saver = tf.train.Saver(var_list=tf.global_variables())

        self.summaries_train = self.create_general_summary()
        tf.add_to_collection("summaries_train", self.summaries_train)

        # Creating the variables
        tf.global_variables_initializer().run(session=self.session)

    def create_network_from_file(self, model_from_file):
        """
        Bootstrap all the necessary data from a saved model

         ** Parameters **
           model_from_file: Path of the saved model. The graph is imported from
             `model_from_file + ".meta"` and the variables are restored from `model_from_file`

        """
        meta_graph_path = model_from_file + ".meta"
        self.saver = tf.train.import_meta_graph(meta_graph_path)
        self.saver.restore(self.session, model_from_file)

        # Every element is stored in a graph collection named after the
        # attribute it belongs to; the first entry of the collection is used
        restored_attributes = (
            # Training graph
            "data_ph", "label_ph",
            "graph", "predictor",
            # Other elements
            "optimizer", "learning_rate", "summaries_train", "global_step",
        )
        for attribute in restored_attributes:
            setattr(self, attribute, tf.get_collection(attribute)[0])

        self.from_scratch = False

    def __del__(self):
        """
        Resets the default tensorflow graph when the trainer is destroyed
        """
        # NOTE(review): the default graph is not owned by this instance; presumably
        # other objects relying on it are affected as well - confirm this is intended
        tf.reset_default_graph()
201 202 203

    def get_feed_dict(self, data_shuffler):
        """
        Fetches one batch from a data shuffler and maps it to the trainer placeholders

        ** Parameters **
            data_shuffler: Object whose `get_batch()` returns the pair (data, labels)

        ** Returns **
            Dictionary mapping `self.data_ph` to the data and `self.label_ph` to the labels
        """
        batch_data, batch_labels = data_shuffler.get_batch()
        return {self.data_ph: batch_data,
                self.label_ph: batch_labels}

216
    def fit(self, step):
        """
        Run one iteration (`forward` and `backward`)

        The loss is logged and the training summary is written for the given step.

        ** Parameters **
            step: Iteration number

        """

        # With prefetching the batches are not fed from here, so no feed
        # dictionary is passed to the session
        if self.train_data_shuffler.prefetch:
            feed_dict = None
        else:
            feed_dict = self.get_feed_dict(self.train_data_shuffler)

        fetches = [self.optimizer, self.predictor,
                   self.learning_rate, self.summaries_train]
        _, loss_value, _, summary = self.session.run(fetches, feed_dict=feed_dict)

        logger.info("Loss training set step={0} = {1}".format(step, loss_value))
        self.train_summary_writter.add_summary(summary, step)
236

Tiago Pereira's avatar
Tiago Pereira committed
237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257
    def compute_validation(self, data_shuffler, step):
        """
        Computes the loss in the validation set

        The computation is currently disabled: the body is a no-op, both
        arguments are ignored and nothing is returned.

        ** Parameters **
            data_shuffler: The data shuffler to be used
            step: Iteration number

        """
        pass
        # Disabled implementation, kept for reference
        #feed_dict = self.get_feed_dict(data_shuffler)
        #l, summary = self.session.run(self.predictor, self.summaries_train, feed_dict=feed_dict)
        #train_summary_writter.add_summary(summary, step)


        #summaries = [summary_pb2.Summary.Value(tag="loss", simple_value=float(l))]
        #self.validation_summary_writter.add_summary(summary_pb2.Summary(value=summaries), step)
        #logger.info("Loss VALIDATION set step={0} = {1}".format(step, l))

258
    def create_general_summary(self):
        """
        Registers the scalar summaries 'loss' and 'lr' for tensorboard and
        returns the result of `tf.summary.merge_all()`
        """

        # Train summary: tag -> value to be tracked
        tracked_scalars = (
            ('loss', self.predictor),
            ('lr', self.learning_rate),
        )
        for tag, value in tracked_scalars:
            tf.summary.scalar(tag, value)

        merged = tf.summary.merge_all()
        return merged
267

268
    def start_thread(self, number_of_threads=3):
        """
        Start pool of threads for pre-fetching

        Every thread runs `load_and_enqueue`.

        **Parameters**
          number_of_threads: Number of threads to be started (3 by default)

        **Returns**
          List with the started `threading.Thread` objects
        """

        threads = []
        for _ in range(number_of_threads):
            worker = threading.Thread(target=self.load_and_enqueue)
            worker.daemon = True  # thread will close when parent quits
            worker.start()
            threads.append(worker)
        return threads
283

284
    def load_and_enqueue(self):
        """
        Injecting data in the place holder queue

        Keeps fetching batches from the training data shuffler and running the
        enqueue operation until `self.thread_pool` reports that it should stop.
        """
        # NOTE(review): `self.inputs` is not assigned anywhere in the visible part
        # of this class - confirm where the enqueue operation is meant to come from
        while not self.thread_pool.should_stop():
            batch_data, batch_labels = self.train_data_shuffler.get_batch()

            placeholders_to_batch = {self.data_ph: batch_data,
                                     self.label_ph: batch_labels}

            self.session.run(self.inputs.enqueue_op, feed_dict=placeholders_to_batch)
299

Tiago Pereira's avatar
Tiago Pereira committed
300
    def train(self, validation_data_shuffler=None):
        """
        Train the network:

        Runs `fit` for every step between the starting step and `self.iterations`.
        Summaries are written under `<temp_dir>/train` (and `<temp_dir>/validation`
        when a validation shuffler is given), a snapshot is saved every
        `self.snapshot` steps and the final model is saved as `<temp_dir>/model.ckp`.

        The summary writers are closed and the prefetching threads are asked to
        stop even when the training is interrupted by an exception.

         ** Parameters **

           validation_data_shuffler: Data shuffler for validation. When given,
             `compute_validation` is called every `self.validation_snapshot` steps
        """

        # Creating directory
        bob.io.base.create_directories_safe(self.temp_dir)

        logger.info("Initializing !!")

        # A trainer loaded from file continues from the stored global step
        if self.from_scratch:
            start_step = 0
        else:
            start_step = self.global_step.eval(session=self.session)

        prefetch = self.train_data_shuffler.prefetch
        train_writer = None
        validation_writer = None

        try:
            # Start a thread to enqueue data asynchronously, and hide I/O latency.
            if prefetch:
                self.thread_pool = tf.train.Coordinator()
                tf.train.start_queue_runners(coord=self.thread_pool, sess=self.session)
                self.start_thread()

            # TENSOR BOARD SUMMARY
            train_writer = tf.summary.FileWriter(os.path.join(self.temp_dir, 'train'), self.session.graph)
            self.train_summary_writter = train_writer
            if validation_data_shuffler is not None:
                validation_writer = tf.summary.FileWriter(os.path.join(self.temp_dir, 'validation'),
                                                          self.session.graph)
                self.validation_summary_writter = validation_writer

            for step in range(start_step, self.iterations):
                # Run fit in the graph and report how long it took
                start = time.time()
                self.fit(step)
                end = time.time()
                summary = summary_pb2.Summary.Value(tag="elapsed_time", simple_value=float(end-start))
                self.train_summary_writter.add_summary(summary_pb2.Summary(value=[summary]), step)

                # Running validation
                if validation_data_shuffler is not None and step % self.validation_snapshot == 0:
                    self.compute_validation(validation_data_shuffler, step)

                # Taking snapshot
                if step % self.snapshot == 0:
                    logger.info("Taking snapshot")
                    path = os.path.join(self.temp_dir, 'model_snapshot{0}.ckp'.format(step))
                    self.saver.save(self.session, path, global_step=step)

            logger.info("Training finally finished")

            # Saving the final network
            path = os.path.join(self.temp_dir, 'model.ckp')
            self.saver.save(self.session, path)
        finally:
            # Releasing the resources opened above, whatever the outcome was
            for writer in (train_writer, validation_writer):
                if writer is not None:
                    writer.close()

            if prefetch and self.thread_pool is not None:
                # now they should definetely stop
                self.thread_pool.request_stop()