Trainer.py 11.8 KB
Newer Older
1 2 3 4 5 6
#!/usr/bin/env python
# vim: set fileencoding=utf-8 :
# @author: Tiago de Freitas Pereira <tiago.pereira@idiap.ch>
# @date: Tue 09 Aug 2016 15:25:22 CEST

import tensorflow as tf
7 8 9
import threading
import os
import bob.io.base
10
import bob.core
11
from ..analyzers import SoftmaxAnalizer
12
from tensorflow.core.framework import summary_pb2
13
import time
14
from bob.learn.tensorflow.datashuffler import OnlineSampling
15
from bob.learn.tensorflow.utils.session import Session
16
from .learning_rate import constant
17

18 19 20 21 22
#logger = bob.core.log.setup("bob.learn.tensorflow")

import logging
logger = logging.getLogger("bob.learn")

23

24 25 26 27 28 29
class Trainer(object):
    """
    One graph trainer.
    Use this trainer when your CNN is composed by one graph

    **Parameters**

    train_data_shuffler:
      The data shuffler used for batching data for training

    iterations:
      Maximum number of iterations

    snapshot:
      Will take a snapshot of the network at every `n` iterations

    validation_snapshot:
      Test with validation each `n` iterations

    analizer:
      Neural network analizer :py:mod:`bob.learn.tensorflow.analyzers`

    temp_dir: str
      The output directory (summaries and checkpoints are written here)

    verbosity_level:
      Verbosity level applied to this module's logger through
      ``bob.core.log.set_verbosity_level``

    """

    def __init__(self,
                 train_data_shuffler,

                 ###### training options ##########
                 iterations=5000,
                 snapshot=500,
                 validation_snapshot=100,

                 ## Analizer
                 # NOTE: this default is evaluated once, when the class body runs, so every
                 # Trainer that does not pass its own analizer shares the same instance.
                 analizer=SoftmaxAnalizer(),

                 # Temporary dir
                 temp_dir="cnn",

                 verbosity_level=2):

        self.train_data_shuffler = train_data_shuffler
        self.temp_dir = temp_dir

        self.iterations = iterations
        self.snapshot = snapshot
        self.validation_snapshot = validation_snapshot

        # Training variables used in the fit
        self.summaries_train = None
        self.train_summary_writter = None
        self.thread_pool = None

        # Validation data
        self.validation_summary_writter = None

        # Analizer
        self.analizer = analizer
        self.global_step = None

        # Graph elements, filled by `create_network_from_scratch` or `create_network_from_file`
        self.graph = None
        self.loss = None
        self.predictor = None
        self.optimizer_class = None
        self.learning_rate = None
        self.optimizer = None
        self.data_ph = None
        self.label_ph = None
        self.saver = None

        bob.core.log.set_verbosity_level(logger, verbosity_level)

        # Creating the session
        self.session = Session.instance(new=True).session
        # Cleared by `create_network_from_file`; `train` then resumes from `global_step`
        self.from_scratch = True

    def create_network_from_scratch(self,
                                    graph,
                                    optimizer=None,
                                    loss=None,

                                    # Learning rate
                                    learning_rate=None,
                                    ):
        """
        Build the training graph on top of ``graph`` and initialize all variables

        **Parameters**

        graph:
          Output of the network to be trained

        optimizer:
          A ``tf.train`` optimizer instance. When ``None`` (default) a fresh
          ``tf.train.AdamOptimizer()`` is created for this call.

        loss:
          Callable invoked as ``loss(graph, labels)``; its result is the tensor
          that gets minimized. Required.

        learning_rate:
          Learning rate assigned to the optimizer and logged under the ``lr``
          summary. Required.

        **Raises**

        ValueError:
          If ``loss`` or ``learning_rate`` is ``None``
        """
        if loss is None:
            raise ValueError("`loss` must be provided to create the network")
        if learning_rate is None:
            raise ValueError("`learning_rate` must be provided to create the network")
        if optimizer is None:
            # Built per call: an instance in the signature would be created at import
            # time and shared (and mutated below) by every trainer using the default.
            optimizer = tf.train.AdamOptimizer()

        self.data_ph = self.train_data_shuffler("data")
        self.label_ph = self.train_data_shuffler("label")
        self.graph = graph
        self.loss = loss
        self.predictor = self.loss(self.graph, self.train_data_shuffler("label", from_queue=False))

        self.optimizer_class = optimizer
        self.learning_rate = learning_rate

        # TODO: find an elegant way to provide this as a parameter of the trainer
        self.global_step = tf.Variable(0, trainable=False, name="global_step")

        # Registering the graph elements so `create_network_from_file` can recover them
        tf.add_to_collection("global_step", self.global_step)
        tf.add_to_collection("graph", self.graph)
        tf.add_to_collection("predictor", self.predictor)
        tf.add_to_collection("data_ph", self.data_ph)
        tf.add_to_collection("label_ph", self.label_ph)

        # Preparing the optimizer
        # NOTE(review): `_learning_rate` is a private optimizer attribute; confirm the
        # optimizer class in use actually reads it.
        self.optimizer_class._learning_rate = self.learning_rate
        self.optimizer = self.optimizer_class.minimize(self.predictor, global_step=self.global_step)
        tf.add_to_collection("optimizer", self.optimizer)
        tf.add_to_collection("learning_rate", self.learning_rate)

        self.summaries_train = self.create_general_summary()
        tf.add_to_collection("summaries_train", self.summaries_train)

        # Saving all the variables. The saver is built only after `minimize`.
        # `var_list` is captured at construction time, so variables created by the
        # optimizer (e.g. Adam's slots) must already exist for a resumed run to
        # restore them.
        self.saver = tf.train.Saver(var_list=tf.global_variables())

        # Creating the variables
        tf.global_variables_initializer().run(session=self.session)

    def create_network_from_file(self, model_from_file):
        """
        Bootstrap all the necessary data from file

        **Parameters**

        model_from_file: str
          Checkpoint prefix. The meta graph is read from ``model_from_file + ".meta"``
          and the variable values from ``model_from_file``.
        """
        self.saver = tf.train.import_meta_graph(model_from_file + ".meta")
        self.saver.restore(self.session, model_from_file)

        # Loading the training graph elements registered by `create_network_from_scratch`
        self.data_ph = tf.get_collection("data_ph")[0]
        self.label_ph = tf.get_collection("label_ph")[0]

        self.graph = tf.get_collection("graph")[0]
        self.predictor = tf.get_collection("predictor")[0]

        # Loading other elements
        self.optimizer = tf.get_collection("optimizer")[0]
        self.learning_rate = tf.get_collection("learning_rate")[0]
        self.summaries_train = tf.get_collection("summaries_train")[0]
        self.global_step = tf.get_collection("global_step")[0]
        self.from_scratch = False

    def __del__(self):
        # NOTE(review): this resets the process-wide default graph whenever a Trainer
        # is garbage-collected, which also affects any other user of that graph.
        tf.reset_default_graph()

    def get_feed_dict(self, data_shuffler):
        """
        Given a data shuffler, prepare the dictionary to be injected in the graph

        **Parameters**

        data_shuffler:
          Shuffler whose ``get_batch()`` returns ``[data, labels]``

        **Returns**

        A dict mapping ``self.data_ph`` and ``self.label_ph`` to one batch
        """
        [data, labels] = data_shuffler.get_batch()

        return {self.data_ph: data,
                self.label_ph: labels}

    def fit(self, step):
        """
        Run one iteration (`forward` and `backward`)

        **Parameters**

        step:
          Iteration number, used for logging and as the summary step
        """
        fetches = [self.optimizer, self.predictor, self.learning_rate, self.summaries_train]

        if self.train_data_shuffler.prefetch:
            # Nothing to feed: `train` starts background threads that enqueue the data
            feed_dict = None
        else:
            feed_dict = self.get_feed_dict(self.train_data_shuffler)

        _, loss_value, _, summary = self.session.run(fetches, feed_dict=feed_dict)

        logger.info("Loss training set step={0} = {1}".format(step, loss_value))
        self.train_summary_writter.add_summary(summary, step)

    def compute_validation(self, data_shuffler, step):
        """
        Computes the loss in the validation set

        Not implemented yet: this method currently does nothing.

        **Parameters**

        data_shuffler:
          The data shuffler to be used

        step:
          Iteration number
        """
        # TODO: evaluate `self.predictor` on a batch from `data_shuffler` and write the
        # value to `self.validation_summary_writter` (tag "loss") at `step`.
        pass

    def create_general_summary(self):
        """
        Creates a simple tensorboard summary with the value of the loss and learning rate

        Returns the result of ``tf.summary.merge_all()``, i.e. every summary registered
        in the default graph, including the two scalars added here.
        """
        # Train summary
        tf.summary.scalar('loss', self.predictor)
        tf.summary.scalar('lr', self.learning_rate)
        return tf.summary.merge_all()

    def start_thread(self, n_threads=3):
        """
        Start pool of threads for pre-fetching

        **Parameters**

        n_threads: int
          Number of daemon threads running `load_and_enqueue` (default 3)

        **Returns**

        The list of started threads
        """
        threads = []
        for _ in range(n_threads):
            t = threading.Thread(target=self.load_and_enqueue)
            t.daemon = True  # thread will close when parent quits
            t.start()
            threads.append(t)
        return threads

    def load_and_enqueue(self):
        """
        Injecting data in the place holder queue

        Loops until ``self.thread_pool`` (the coordinator created by `train`) requests
        a stop. Each pass fetches one training batch and runs the enqueue op with it.
        """
        # This class never defines an `inputs` attribute, so the enqueue op is taken
        # from the data shuffler.
        # NOTE(review): assumes the shuffler exposes `enqueue_op` when `prefetch` is
        # enabled — confirm against the shuffler implementation.
        enqueue_op = self.train_data_shuffler.enqueue_op
        while not self.thread_pool.should_stop():
            feed_dict = self.get_feed_dict(self.train_data_shuffler)
            self.session.run(enqueue_op, feed_dict=feed_dict)

    def train(self, validation_data_shuffler=None):
        """
        Train the network

        Runs `fit` from the start step up to ``self.iterations``. Checkpoints are
        written every ``self.snapshot`` steps, and a final ``model.ckp`` is written
        to ``self.temp_dir``.

        **Parameters**

        validation_data_shuffler:
          Optional data shuffler for validation. When given, a `validation` summary
          writer is opened and `compute_validation` runs every
          ``self.validation_snapshot`` steps.
        """

        # Creating directory
        bob.io.base.create_directories_safe(self.temp_dir)

        logger.info("Initializing !!")

        # Resume from the restored step counter when the network was loaded from file
        if self.from_scratch:
            start_step = 0
        else:
            start_step = self.global_step.eval(session=self.session)

        # TENSOR BOARD SUMMARY
        self.train_summary_writter = tf.summary.FileWriter(os.path.join(self.temp_dir, 'train'), self.session.graph)
        if validation_data_shuffler is not None:
            self.validation_summary_writter = tf.summary.FileWriter(os.path.join(self.temp_dir, 'validation'),
                                                                    self.session.graph)

        # Start threads to enqueue data asynchronously, and hide I/O latency.
        if self.train_data_shuffler.prefetch:
            self.thread_pool = tf.train.Coordinator()
            tf.train.start_queue_runners(coord=self.thread_pool, sess=self.session)
            self.start_thread()

        try:
            for step in range(start_step, self.iterations):
                # Run fit in the graph
                start = time.time()
                self.fit(step)
                end = time.time()
                summary = summary_pb2.Summary.Value(tag="elapsed_time", simple_value=float(end - start))
                self.train_summary_writter.add_summary(summary_pb2.Summary(value=[summary]), step)

                # Running validation
                if validation_data_shuffler is not None and step % self.validation_snapshot == 0:
                    self.compute_validation(validation_data_shuffler, step)
                    # TODO: once validation is implemented, also write `self.analizer`
                    # results to `self.validation_summary_writter`.

                # Taking snapshot
                if step % self.snapshot == 0:
                    logger.info("Taking snapshot")
                    path = os.path.join(self.temp_dir, 'model_snapshot{0}.ckp'.format(step))
                    self.saver.save(self.session, path, global_step=step)

            logger.info("Training finally finished")

            # Saving the final network
            path = os.path.join(self.temp_dir, 'model.ckp')
            self.saver.save(self.session, path)
        finally:
            # Released even when training raises, so summaries are flushed and the
            # prefetch threads are told to stop.
            self.train_summary_writter.close()
            if validation_data_shuffler is not None:
                self.validation_summary_writter.close()

            if self.train_data_shuffler.prefetch:
                # Now they should definitely stop. The threads are daemons and are not
                # joined: presumably an enqueue call can stay blocked on a full queue
                # once nothing dequeues, which would make a join hang.
                self.thread_pool.request_stop()