Trainer.py 15.8 KB
Newer Older
1 2 3 4 5 6 7
#!/usr/bin/env python
# vim: set fileencoding=utf-8 :
# @author: Tiago de Freitas Pereira <tiago.pereira@idiap.ch>
# @date: Tue 09 Aug 2016 15:25:22 CEST

import tensorflow as tf
from ..network import SequenceNetwork
8 9 10
import threading
import os
import bob.io.base
11
import bob.core
12
from ..analyzers import SoftmaxAnalizer
13
from tensorflow.core.framework import summary_pb2
14
import time
15
from bob.learn.tensorflow.datashuffler.OnlineSampling import OnLineSampling
16
from .learning_rate import constant
17

18
# Module logger for the "bob.learn.tensorflow" package; its verbosity is
# adjusted by each Trainer from its `verbosity_level` argument.
logger = bob.core.log.setup("bob.learn.tensorflow")
19

20

21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43
class Trainer(object):
    """
    One graph trainer.
    Use this trainer when your CNN is composed by one graph

    **Parameters**
      architecture: The architecture that you want to run. Should be a :py:class:`bob.learn.tensorflow.network.SequenceNetwork`
      optimizer: One of the tensorflow optimizers https://www.tensorflow.org/versions/r0.10/api_docs/python/train.html
        (when omitted, a fresh ``tf.train.AdamOptimizer()`` is created for this trainer)
      use_gpu: Use GPUs in the training (stored; not read by this class)
      loss: Loss callable, invoked as ``loss(network_graph, label_batch)``
      temp_dir: The output directory (snapshots and tensorboard summaries)

      learning_rate: Learning-rate tensor used by the optimizer
        (when omitted, a fresh ``constant()`` is created for this trainer)
      convergence_threshold: Stored; not read by this class

      iterations: Maximum number of iterations
      snapshot: Will take a snapshot of the network at every `n` iterations
      validation_snapshot: Will compute the validation loss at every `n` iterations
      prefetch: Use extra Threads to deal with the I/O
      analizer: Neural network analizer :py:mod:`bob.learn.tensorflow.analyzers`
        (when omitted, a fresh ``SoftmaxAnalizer()``; pass ``None`` to disable it)
      model_from_file: Pretrained model to resume from ("" trains from scratch)
      verbosity_level: Verbosity level handed to ``bob.core.log.set_verbosity_level``

    """

    # Sentinel that tells "argument omitted" apart from an explicit ``None``
    # (an explicit ``None`` disables the analizer).
    _DEFAULT_ANALIZER = object()

    def __init__(self,
                 architecture,
                 optimizer=None,
                 use_gpu=False,
                 loss=None,
                 temp_dir="cnn",

                 # Learning rate
                 learning_rate=None,

                 ###### training options ##########
                 convergence_threshold=0.01,
                 iterations=5000,
                 snapshot=500,
                 validation_snapshot=100,
                 prefetch=False,

                 ## Analizer
                 analizer=_DEFAULT_ANALIZER,

                 ### Pretrained model
                 model_from_file="",

                 verbosity_level=2):

        if not isinstance(architecture, SequenceNetwork):
            raise ValueError("`architecture` should be instance of `SequenceNetwork`")

        # Defaults are built per instance instead of in the signature: default
        # expressions run once at import time, so every Trainer would share
        # (and `train` would mutate) the same optimizer, and the import-time
        # learning-rate tensor would outlive the graph reset done in `__del__`.
        if optimizer is None:
            optimizer = tf.train.AdamOptimizer()
        if learning_rate is None:
            learning_rate = constant()
        if analizer is self._DEFAULT_ANALIZER:
            analizer = SoftmaxAnalizer()

        self.architecture = architecture
        self.optimizer_class = optimizer
        self.use_gpu = use_gpu
        self.loss = loss
        self.temp_dir = temp_dir

        self.learning_rate = learning_rate

        self.iterations = iterations
        self.snapshot = snapshot
        self.validation_snapshot = validation_snapshot
        self.convergence_threshold = convergence_threshold
        self.prefetch = prefetch

        # Training variables used in the fit
        self.optimizer = None
        self.training_graph = None
        self.train_data_shuffler = None
        self.summaries_train = None
        self.train_summary_writter = None

        # Validation data
        self.validation_graph = None
        self.validation_summary_writter = None

        # Analizer
        self.analizer = analizer

        # Prefetch machinery: coordinator of the loader threads and the op
        # that pushes one batch into the input queue.
        self.thread_pool = None
        self.enqueue_op = None

        self.global_step = None

        self.model_from_file = model_from_file

        bob.core.log.set_verbosity_level(logger, verbosity_level)

    def __del__(self):
        # Clears TensorFlow's default graph when the trainer is collected.
        tf.reset_default_graph()

    def compute_graph(self, data_shuffler, prefetch=False, name="", training=True):
        """
        Computes the graph for the trainer.

        ** Parameters **

            data_shuffler: Data shuffler
            prefetch: Uses prefetch (inputs come from a FIFO queue filled by `enqueue_op`)
            name: Name of the graph
            training: Is it a training graph?

        ** Returns **
            The loss tensor, ``self.loss(network_graph, label_batch)``.
        """

        # Defining place holders
        if prefetch:
            [placeholder_data, placeholder_labels] = data_shuffler.get_placeholders_forprefetch(name=name)

            # Defining a placeholder queue for prefetching
            queue = tf.FIFOQueue(capacity=10,
                                 dtypes=[tf.float32, tf.int64],
                                 shapes=[placeholder_data.get_shape().as_list()[1:], []])

            # Fetching the place holders from the queue
            self.enqueue_op = queue.enqueue_many([placeholder_data, placeholder_labels])
            feature_batch, label_batch = queue.dequeue_many(data_shuffler.batch_size)

            # Creating the architecture for train and validation
            if not isinstance(self.architecture, SequenceNetwork):
                raise ValueError("The variable `architecture` must be an instance of "
                                 "`bob.learn.tensorflow.network.SequenceNetwork`")
        else:
            [feature_batch, label_batch] = data_shuffler.get_placeholders(name=name)

        # Creating graphs and defining the loss
        network_graph = self.architecture.compute_graph(feature_batch, training=training)
        graph = self.loss(network_graph, label_batch)

        return graph

    def get_feed_dict(self, data_shuffler):
        """
        Build the feed dictionary for one batch drawn from `data_shuffler`.

        ** Parameters **
            data_shuffler: Data shuffler providing `get_batch` and `get_placeholders`

        ** Returns **
            dict mapping the shuffler's data/label placeholders to the batch values.
        """
        [data, labels] = data_shuffler.get_batch()
        [data_placeholder, label_placeholder] = data_shuffler.get_placeholders()

        feed_dict = {data_placeholder: data,
                     label_placeholder: labels}
        return feed_dict

    def fit(self, session, step):
        """
        Run one iteration (`forward` and `backward`)

        ** Parameters **
            session: Tensorflow session
            step: Iteration number

        """

        # With prefetch the graph reads its inputs from the queue filled by the
        # loader threads, so nothing is fed here (feed_dict=None).
        feed_dict = None if self.prefetch else self.get_feed_dict(self.train_data_shuffler)
        _, l, lr, summary = session.run([self.optimizer, self.training_graph,
                                         self.learning_rate, self.summaries_train],
                                        feed_dict=feed_dict)

        logger.info("Loss training set step={0} = {1}".format(step, l))
        self.train_summary_writter.add_summary(summary, step)

    def compute_validation(self,  session, data_shuffler, step):
        """
        Computes the loss in the validation set

        ** Parameters **
            session: Tensorflow session
            data_shuffler: The data shuffler to be used
            step: Iteration number

        """
        # Evaluate the validation loss on one batch, reusing the training session
        feed_dict = self.get_feed_dict(data_shuffler)
        l = session.run(self.validation_graph, feed_dict=feed_dict)

        # The validation writer is created lazily on first use
        if self.validation_summary_writter is None:
            self.validation_summary_writter = tf.train.SummaryWriter(os.path.join(self.temp_dir, 'validation'), session.graph)

        summaries = [summary_pb2.Summary.Value(tag="loss", simple_value=float(l))]
        self.validation_summary_writter.add_summary(summary_pb2.Summary(value=summaries), step)
        logger.info("Loss VALIDATION set step={0} = {1}".format(step, l))

    def create_general_summary(self):
        """
        Creates a simple tensorboard summary with the value of the loss and learning rate

        ** Returns **
            The op returned by ``tf.merge_all_summaries()``.
        """
        # Train summary
        tf.scalar_summary('loss', self.training_graph, name="train")
        tf.scalar_summary('lr', self.learning_rate, name="train")
        return tf.merge_all_summaries()

    def start_thread(self, session):
        """
        Start pool of threads for pre-fetching

        **Parameters**
          session: Tensorflow session

        **Returns**
          The list of started (daemon) threads running `load_and_enqueue`.
        """

        threads = []
        for n in range(3):
            t = threading.Thread(target=self.load_and_enqueue, args=(session,))
            t.daemon = True  # thread will close when parent quits
            t.start()
            threads.append(t)
        return threads

    def load_and_enqueue(self, session):
        """
        Injecting data in the place holder queue

        **Parameters**
          session: Tensorflow session
        """

        # The placeholders do not change between batches: fetch them once.
        [train_placeholder_data, train_placeholder_labels] = self.train_data_shuffler.get_placeholders()

        while not self.thread_pool.should_stop():
            [train_data, train_labels] = self.train_data_shuffler.get_batch()

            feed_dict = {train_placeholder_data: train_data,
                         train_placeholder_labels: train_labels}

            # NOTE(review): presumably blocks while the FIFO queue is full, so a
            # stop request is only noticed once this enqueue completes.
            session.run(self.enqueue_op, feed_dict=feed_dict)

    def bootstrap_graphs(self, train_data_shuffler, validation_data_shuffler):
        """
        Create all the necessary graphs for training, validation and inference graphs
        """

        # Creating train graph
        self.training_graph = self.compute_graph(train_data_shuffler, prefetch=self.prefetch, name="train")
        tf.add_to_collection("training_graph", self.training_graph)

        # Creating inference graph
        self.architecture.compute_inference_placeholder(train_data_shuffler.deployment_shape)
        self.architecture.compute_inference_graph()
        tf.add_to_collection("inference_placeholder", self.architecture.inference_placeholder)
        tf.add_to_collection("inference_graph", self.architecture.inference_graph)

        # Creating validation graph
        if validation_data_shuffler is not None:
            self.validation_graph = self.compute_graph(validation_data_shuffler, name="validation", training=False)
            tf.add_to_collection("validation_graph", self.validation_graph)

        self.bootstrap_placeholders(train_data_shuffler, validation_data_shuffler)

    def bootstrap_placeholders(self, train_data_shuffler, validation_data_shuffler):
        """
        Persist the placeholders in graph collections so they can be reloaded
        by `bootstrap_placeholders_fromfile`.
        """

        # Persisting the placeholders
        if self.prefetch:
            batch, label = train_data_shuffler.get_placeholders_forprefetch()
        else:
            batch, label = train_data_shuffler.get_placeholders()

        tf.add_to_collection("train_placeholder_data", batch)
        tf.add_to_collection("train_placeholder_label", label)

        # Persisting the validation placeholders
        if validation_data_shuffler is not None:
            batch, label = validation_data_shuffler.get_placeholders()
            tf.add_to_collection("validation_placeholder_data", batch)
            tf.add_to_collection("validation_placeholder_label", label)

    def bootstrap_graphs_fromfile(self, session, train_data_shuffler, validation_data_shuffler):
        """
        Bootstrap all the necessary data from file

        ** Returns **
            The saver returned by ``self.architecture.load``.
        """
        saver = self.architecture.load(session, self.model_from_file)

        # Loading training graph
        self.training_graph = tf.get_collection("training_graph")[0]

        # Loading other elements
        self.optimizer = tf.get_collection("optimizer")[0]
        self.learning_rate = tf.get_collection("learning_rate")[0]
        self.summaries_train = tf.get_collection("summaries_train")[0]

        if validation_data_shuffler is not None:
            self.validation_graph = tf.get_collection("validation_graph")[0]

        self.bootstrap_placeholders_fromfile(train_data_shuffler, validation_data_shuffler)

        return saver

    def bootstrap_placeholders_fromfile(self, train_data_shuffler, validation_data_shuffler):
        """
        Load placeholders from file and hand each shuffler its own pair
        """

        train_data_shuffler.set_placeholders(tf.get_collection("train_placeholder_data")[0],
                                             tf.get_collection("train_placeholder_label")[0])

        # Validation placeholders belong to the validation shuffler (previously
        # they were written into the training shuffler, clobbering its pair).
        if validation_data_shuffler is not None:
            validation_data_shuffler.set_placeholders(tf.get_collection("validation_placeholder_data")[0],
                                                      tf.get_collection("validation_placeholder_label")[0])

    def train(self, train_data_shuffler, validation_data_shuffler=None):
        """
        Train the network

        ** Parameters **
            train_data_shuffler: Data shuffler providing the training batches
            validation_data_shuffler: Optional data shuffler for the validation set
        """

        # Creating directory
        bob.io.base.create_directories_safe(self.temp_dir)
        self.train_data_shuffler = train_data_shuffler

        logger.info("Initializing !!")

        config = tf.ConfigProto(log_device_placement=True,
                                gpu_options=tf.GPUOptions(per_process_gpu_memory_fraction=0.333))
        config.gpu_options.allow_growth = True

        # Pickle the architecture to save
        self.architecture.pickle_net(train_data_shuffler.deployment_shape)

        with tf.Session(config=config) as session:

            # Loading a pretrained model
            if self.model_from_file != "":
                logger.info("Loading pretrained model from {0}".format(self.model_from_file))
                saver = self.bootstrap_graphs_fromfile(session, train_data_shuffler, validation_data_shuffler)
            else:
                # Bootstraping all the graphs
                self.bootstrap_graphs(train_data_shuffler, validation_data_shuffler)

                # TODO: find an elegant way to provide this as a parameter of the trainer
                self.global_step = tf.Variable(0, trainable=False)

                # Preparing the optimizer: its private learning-rate attribute is
                # overridden so it uses this trainer's learning-rate tensor.
                self.optimizer_class._learning_rate = self.learning_rate
                self.optimizer = self.optimizer_class.minimize(self.training_graph, global_step=self.global_step)
                tf.add_to_collection("optimizer", self.optimizer)
                tf.add_to_collection("learning_rate", self.learning_rate)

                # Train summary
                self.summaries_train = self.create_general_summary()
                tf.add_to_collection("summaries_train", self.summaries_train)

                tf.initialize_all_variables().run()

                # Original tensorflow saver object
                saver = tf.train.Saver(var_list=tf.all_variables())

            if isinstance(train_data_shuffler, OnLineSampling):
                train_data_shuffler.set_feature_extractor(self.architecture, session=session)

            # Start a thread to enqueue data asynchronously, and hide I/O latency.
            if self.prefetch:
                self.thread_pool = tf.train.Coordinator()
                tf.train.start_queue_runners(coord=self.thread_pool)
                threads = self.start_thread(session)

            # TENSOR BOARD SUMMARY
            self.train_summary_writter = tf.train.SummaryWriter(os.path.join(self.temp_dir, 'train'), session.graph)
            for step in range(self.iterations):

                start = time.time()
                self.fit(session, step)
                end = time.time()
                summary = summary_pb2.Summary.Value(tag="elapsed_time", simple_value=float(end-start))
                self.train_summary_writter.add_summary(summary_pb2.Summary(value=[summary]), step)

                # Running validation
                if validation_data_shuffler is not None and step % self.validation_snapshot == 0:
                    self.compute_validation(session, validation_data_shuffler, step)

                    if self.analizer is not None:
                        self.validation_summary_writter.add_summary(self.analizer(
                             validation_data_shuffler, self.architecture, session), step)

                # Taking snapshot
                if step % self.snapshot == 0:
                    logger.info("Taking snapshot")
                    path = os.path.join(self.temp_dir, 'model_snapshot{0}.ckp'.format(step))
                    self.architecture.save(session, saver, path)

            logger.info("Training finally finished")

            self.train_summary_writter.close()
            # The validation writer only exists once `compute_validation` ran
            # (it does not when `iterations == 0`).
            if self.validation_summary_writter is not None:
                self.validation_summary_writter.close()

            # Saving the final network
            path = os.path.join(self.temp_dir, 'model.ckp')
            self.architecture.save(session, saver, path)

            if self.prefetch:
                # now they should definetely stop
                self.thread_pool.request_stop()
                self.thread_pool.join(threads)