Trainer.py 12.8 KB
Newer Older
1 2 3 4 5 6 7
#!/usr/bin/env python
# vim: set fileencoding=utf-8 :
# @author: Tiago de Freitas Pereira <tiago.pereira@idiap.ch>
# @date: Tue 09 Aug 2016 15:25:22 CEST

import tensorflow as tf
from ..network import SequenceNetwork
8 9 10
import threading
import os
import bob.io.base
11
import bob.core
12
from ..analyzers import SoftmaxAnalizer
13
from tensorflow.core.framework import summary_pb2
14
import time
15
from bob.learn.tensorflow.datashuffler import OnlineSampling
16
from bob.learn.tensorflow.utils.session import Session
17
from .learning_rate import constant
18

19 20 21 22 23
#logger = bob.core.log.setup("bob.learn.tensorflow")

import logging
# Module-level logger shared by every Trainer; note that Trainer.__init__
# changes its verbosity level, so the last constructed trainer wins.
logger = logging.getLogger("bob.learn")

24

25 26 27 28 29 30
class Trainer(object):
    """
    One graph trainer.
    Use this trainer when your CNN is composed of a single graph

    **Parameters**
31

Tiago de Freitas Pereira's avatar
Tiago de Freitas Pereira committed
32 33 34 35 36 37 38 39 40 41 42 43 44 45 46
    architecture:
      The architecture that you want to run. Should be a :py:class:`bob.learn.tensorflow.network.SequenceNetwork`

    optimizer:
      One of the tensorflow optimizers https://www.tensorflow.org/versions/r0.10/api_docs/python/train.html

    use_gpu: bool
      Use GPUs in the training

    loss: :py:class:`bob.learn.tensorflow.loss.BaseLoss`
      Loss function

    temp_dir: str
      The output directory

47
    learning_rate: `bob.learn.tensorflow.trainers.learning_rate`
Tiago de Freitas Pereira's avatar
Tiago de Freitas Pereira committed
48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67
      Initial learning rate

    convergence_threshold:

    iterations: int
      Maximum number of iterations

    snapshot: int
      Will take a snapshot of the network at every `n` iterations

    prefetch: bool
      Use extra Threads to deal with the I/O

    model_from_file: str
      If you want to use a pretrained model

    analizer:
      Neural network analizer :py:mod:`bob.learn.tensorflow.analyzers`

    verbosity_level:
68 69

    """
Tiago de Freitas Pereira's avatar
Tiago de Freitas Pereira committed
70

71
    def __init__(self,
                 graph,
                 optimizer=None,
                 use_gpu=False,
                 loss=None,
                 temp_dir="cnn",

                 # Learning rate
                 learning_rate=None,

                 ###### training options ##########
                 convergence_threshold=0.01,
                 iterations=5000,
                 snapshot=500,
                 validation_snapshot=100,
                 prefetch=False,

                 ## Analizer
                 analizer=SoftmaxAnalizer(),

                 ### Pretrained model
                 model_from_file="",

                 verbosity_level=2):
        """
        Store the training configuration; no TensorFlow session is opened here.

        ** Parameters **
            graph: Stored as `self.graph`
            optimizer: Tensorflow optimizer instance. When `None`, a fresh
                `tf.train.AdamOptimizer()` is created for this trainer
            use_gpu: Stored as `self.use_gpu`
            loss: Stored as `self.loss`
            temp_dir: Output directory
            learning_rate: Learning rate. When `None` and no pretrained model
                is given, `constant()` is used
            convergence_threshold: Stored as `self.convergence_threshold`
            iterations: Maximum number of iterations
            snapshot: A snapshot is taken every `snapshot` iterations
            validation_snapshot: Stored as `self.validation_snapshot`
            prefetch: Use the prefetching code path
            analizer: Network analizer
            model_from_file: Path of a pretrained model ("" means none)
            verbosity_level: Verbosity applied to the module logger
        """

        #if not isinstance(graph, SequenceNetwork):
        #    raise ValueError("`architecture` should be instance of `SequenceNetwork`")

        # A default built in the signature would be a single optimizer object
        # shared by every Trainer; it is created per instance instead, because
        # the optimizer is modified when training starts.
        if optimizer is None:
            optimizer = tf.train.AdamOptimizer()

        # NOTE(review): `analizer` still defaults to one shared
        # `SoftmaxAnalizer()` instance. `None` looks like a meaningful
        # "no analizer" value, so it cannot double as the default marker here.

        self.graph = graph
        self.optimizer_class = optimizer
        self.use_gpu = use_gpu
        self.loss = loss
        self.temp_dir = temp_dir

        # With a pretrained model and no explicit value, the learning rate is
        # left as None at this point.
        if learning_rate is None and model_from_file == "":
            self.learning_rate = constant()
        else:
            self.learning_rate = learning_rate

        self.iterations = iterations
        self.snapshot = snapshot
        self.validation_snapshot = validation_snapshot
        self.convergence_threshold = convergence_threshold
        self.prefetch = prefetch

        # Training variables used in the fit
        self.optimizer = None
        self.training_graph = None
        self.train_data_shuffler = None
        self.summaries_train = None
        self.train_summary_writter = None
        self.thread_pool = None

        # Validation data
        self.validation_graph = None
        self.validation_summary_writter = None

        # Analizer
        self.analizer = analizer

        self.enqueue_op = None
        self.global_step = None

        self.model_from_file = model_from_file
        self.session = None

        bob.core.log.set_verbosity_level(logger, verbosity_level)

Tiago de Freitas Pereira's avatar
Tiago de Freitas Pereira committed
140 141 142
    def __del__(self):
        # Resets TensorFlow's default graph when this trainer is finalized.
        # NOTE(review): the default graph is not owned by this instance, so
        # this also affects any other object still using it — confirm that
        # this global side effect is intended.
        tf.reset_default_graph()

Tiago de Freitas Pereira's avatar
Tiago de Freitas Pereira committed
143
    """
144
    def compute_graph(self, data_shuffler, prefetch=False, name="", training=True):
145 146 147 148 149
        Computes the graph for the trainer.

        ** Parameters **

            data_shuffler: Data shuffler
Tiago de Freitas Pereira's avatar
Tiago de Freitas Pereira committed
150
            prefetch: Uses prefetch
151
            name: Name of the graph
Tiago de Freitas Pereira's avatar
Tiago de Freitas Pereira committed
152
            training: Is it a training graph?
153 154

        # Defining place holders
155
        if prefetch:
156
            [placeholder_data, placeholder_labels] = data_shuffler.get_placeholders_forprefetch(name=name)
157 158 159 160 161 162 163

            # Defining a placeholder queue for prefetching
            queue = tf.FIFOQueue(capacity=10,
                                 dtypes=[tf.float32, tf.int64],
                                 shapes=[placeholder_data.get_shape().as_list()[1:], []])

            # Fetching the place holders from the queue
164
            self.enqueue_op = queue.enqueue_many([placeholder_data, placeholder_labels])
165 166 167 168 169 170 171
            feature_batch, label_batch = queue.dequeue_many(data_shuffler.batch_size)

            # Creating the architecture for train and validation
            if not isinstance(self.architecture, SequenceNetwork):
                raise ValueError("The variable `architecture` must be an instance of "
                                 "`bob.learn.tensorflow.network.SequenceNetwork`")
        else:
172
            [feature_batch, label_batch] = data_shuffler.get_placeholders(name=name)
173 174

        # Creating graphs and defining the loss
175
        network_graph = self.architecture.compute_graph(feature_batch, training=training)
176 177 178
        graph = self.loss(network_graph, label_batch)

        return graph
Tiago de Freitas Pereira's avatar
Tiago de Freitas Pereira committed
179
    """
180 181 182

    def get_feed_dict(self, data_shuffler):
        """
        Build the dictionary that feeds one batch of `data_shuffler` into the graph

        ** Parameters **
            data_shuffler: Object providing `get_batch()` and `get_placeholders()`,
                           each returning a (data, labels) pair

        ** Returns **
            Dictionary mapping the data placeholder to the batch data and the
            label placeholder to the batch labels
        """
        batch_data, batch_labels = data_shuffler.get_batch()
        data_ph, labels_ph = data_shuffler.get_placeholders()

        return {data_ph: batch_data, labels_ph: batch_labels}

196
    def fit(self, step):
        """
        Run one iteration (`forward` and `backward`), log the loss and write
        the training summary for that iteration

        ** Parameters **
            step: Iteration number

        """

        # Without prefetching, one batch is fed by hand; with prefetching the
        # session is run without any feed dictionary.
        run_kwargs = {}
        if not self.prefetch:
            run_kwargs["feed_dict"] = self.get_feed_dict(self.train_data_shuffler)

        fetches = [self.optimizer, self.training_graph,
                   self.learning_rate, self.summaries_train]
        _, loss_value, _, train_summary = self.session.run(fetches, **run_kwargs)

        logger.info("Loss training set step={0} = {1}".format(step, loss_value))
        self.train_summary_writter.add_summary(train_summary, step)
216

Tiago de Freitas Pereira's avatar
Tiago de Freitas Pereira committed
217
    """
218
    def create_general_summary(self):
Tiago de Freitas Pereira's avatar
Tiago de Freitas Pereira committed
219

220
        Creates a simple tensorboard summary with the value of the loss and learning rate
Tiago de Freitas Pereira's avatar
Tiago de Freitas Pereira committed
221

222
        # Train summary
223 224 225
        tf.summary.scalar('loss', self.training_graph)
        tf.summary.scalar('lr', self.learning_rate)
        return tf.summary.merge_all()
226

Tiago de Freitas Pereira's avatar
Tiago de Freitas Pereira committed
227

228
    def start_thread(self):
Tiago de Freitas Pereira's avatar
Tiago de Freitas Pereira committed
229

230 231 232 233
        Start pool of threads for pre-fetching

        **Parameters**
          session: Tensorflow session
Tiago de Freitas Pereira's avatar
Tiago de Freitas Pereira committed
234

235

236
        threads = []
237
        for n in range(3):
238
            t = threading.Thread(target=self.load_and_enqueue, args=())
239 240 241 242
            t.daemon = True  # thread will close when parent quits
            t.start()
            threads.append(t)
        return threads
243

244
    def load_and_enqueue(self):
Tiago de Freitas Pereira's avatar
Tiago de Freitas Pereira committed
245

246
        Injecting data in the place holder queue
247 248 249

        **Parameters**
          session: Tensorflow session
Tiago de Freitas Pereira's avatar
Tiago de Freitas Pereira committed
250

251

252
        while not self.thread_pool.should_stop():
253 254
            [train_data, train_labels] = self.train_data_shuffler.get_batch()
            [train_placeholder_data, train_placeholder_labels] = self.train_data_shuffler.get_placeholders()
255

256 257 258
            feed_dict = {train_placeholder_data: train_data,
                         train_placeholder_labels: train_labels}

259
            self.session.run(self.enqueue_op, feed_dict=feed_dict)
260

Tiago de Freitas Pereira's avatar
Tiago de Freitas Pereira committed
261
    """
Tiago de Freitas Pereira's avatar
Tiago de Freitas Pereira committed
262

263
    def bootstrap_graphs_fromfile(self, train_data_shuffler, validation_data_shuffler):
        """
        Bootstrap all the necessary data from `self.model_from_file`

         ** Parameters **
           train_data_shuffler: Data shuffler for training
           validation_data_shuffler: Data shuffler for validation; when it is
                                     `None` the validation graph is not fetched

         ** Returns **
           The object returned by `self.architecture.load`

        """
        saver = self.architecture.load(self.model_from_file, clear_devices=False)

        # Every training element is read back from the tensorflow collection
        # that carries the same name as the attribute (first entry only).
        for attribute in ("training_graph", "optimizer", "learning_rate",
                          "summaries_train", "global_step"):
            setattr(self, attribute, tf.get_collection(attribute)[0])

        if validation_data_shuffler is not None:
            self.validation_graph = tf.get_collection("validation_graph")[0]

        self.bootstrap_placeholders_fromfile(train_data_shuffler, validation_data_shuffler)

        return saver

292 293
    def train(self, train_data_shuffler, validation_data_shuffler=None):
        """
        Train the network:

        Creates the output directory, opens a new session, either restores a
        pretrained model or prepares the optimizer from scratch, then runs
        `fit` from the starting step up to `self.iterations`, saving a
        snapshot every `self.snapshot` steps and a final model at the end.

         ** Parameters **

           train_data_shuffler: Data shuffler for training
           validation_data_shuffler: Data shuffler for validation

         ** Raises **

           NotImplementedError: if `self.graph` is not a `tf.Tensor`
        """

        # Creating directory
        bob.io.base.create_directories_safe(self.temp_dir)
        self.train_data_shuffler = train_data_shuffler

        logger.info("Initializing !!")

        # Pickle the architecture to save
        #self.architecture.pickle_net(train_data_shuffler.deployment_shape)

        # isinstance takes the object first and the type second, and
        # NotImplemented is a comparison constant, not an exception class.
        if not isinstance(self.graph, tf.Tensor):
            raise NotImplementedError("Not tensor still not implemented")

        self.session = Session.instance(new=True).session

        # Loading a pretrained model
        if self.model_from_file != "":
            logger.info("Loading pretrained model from {0}".format(self.model_from_file))
            saver = self.bootstrap_graphs_fromfile(train_data_shuffler, validation_data_shuffler)

            start_step = self.global_step.eval(session=self.session)

        else:
            start_step = 0

            # TODO: find an elegant way to provide this as a parameter of the trainer
            self.global_step = tf.Variable(0, trainable=False, name="global_step")
            tf.add_to_collection("global_step", self.global_step)

            # Preparing the optimizer
            # NOTE(review): nothing visible in this file assigns
            # `self.training_graph` or `self.summaries_train` on this path
            # (both are None after __init__) — confirm where they are built.
            self.optimizer_class._learning_rate = self.learning_rate
            self.optimizer = self.optimizer_class.minimize(self.training_graph, global_step=self.global_step)
            tf.add_to_collection("optimizer", self.optimizer)
            tf.add_to_collection("learning_rate", self.learning_rate)

            # Train summary
            tf.global_variables_initializer().run(session=self.session)

            # Original tensorflow saver object
            saver = tf.train.Saver(var_list=tf.global_variables())

        #if isinstance(train_data_shuffler, OnlineSampling):
        #    train_data_shuffler.set_feature_extractor(self.architecture, session=self.session)

        # Start a thread to enqueue data asynchronously, and hide I/O latency.
        # The start-up is currently disabled, so no thread is tracked; the
        # list exists so that the shutdown code below always has a value.
        threads = []
        #if self.prefetch:
        #    self.thread_pool = tf.train.Coordinator()
        #    tf.train.start_queue_runners(coord=self.thread_pool, sess=self.session)
        #    threads = self.start_thread()

        # TENSOR BOARD SUMMARY
        self.train_summary_writter = tf.summary.FileWriter(os.path.join(self.temp_dir, 'train'), self.session.graph)
        for step in range(start_step, self.iterations):
            start = time.time()
            self.fit(step)
            end = time.time()
            summary = summary_pb2.Summary.Value(tag="elapsed_time", simple_value=float(end-start))
            self.train_summary_writter.add_summary(summary_pb2.Summary(value=[summary]), step)

            # Running validation
            #if validation_data_shuffler is not None and step % self.validation_snapshot == 0:
            #    self.compute_validation(validation_data_shuffler, step)

            #    if self.analizer is not None:
            #        self.validation_summary_writter.add_summary(self.analizer(
            #             validation_data_shuffler, self.architecture, self.session), step)

            # Taking snapshot
            if step % self.snapshot == 0:
                logger.info("Taking snapshot")
                path = os.path.join(self.temp_dir, 'model_snapshot{0}.ckp'.format(step))
                # NOTE(review): `self.architecture` is not assigned in
                # __init__ — confirm who is expected to provide it.
                self.architecture.save(saver, path)

        logger.info("Training finally finished")

        self.train_summary_writter.close()
        # The validation writer is never created while validation is
        # disabled above, so only close it when it actually exists.
        if validation_data_shuffler is not None and self.validation_summary_writter is not None:
            self.validation_summary_writter.close()

        # Saving the final network
        path = os.path.join(self.temp_dir, 'model.ckp')
        self.architecture.save(saver, path)

        # The coordinator only exists once the prefetch threads have been
        # started; without it there is nothing to stop.
        if self.prefetch and self.thread_pool is not None:
            # now they should definitely stop
            self.thread_pool.request_stop()
            self.thread_pool.join(threads)