#!/usr/bin/env python
# vim: set fileencoding=utf-8 :
# @author: Tiago de Freitas Pereira <tiago.pereira@idiap.ch>
# @date: Tue 09 Aug 2016 15:25:22 CEST

import tensorflow as tf
from tensorflow.core.framework import summary_pb2
import threading
from ..analyzers import ExperimentAnalizer
from .Trainer import Trainer
from ..analyzers import SoftmaxAnalizer
import os
from bob.learn.tensorflow.utils.session import Session
import bob.core
import logging
logger = logging.getLogger("bob.learn")


class TripletTrainer(Trainer):
    """
    Trainer for Triplet networks:

    Schroff, Florian, Dmitry Kalenichenko, and James Philbin.
    "Facenet: A unified embedding for face recognition and clustering."
    Proceedings of the IEEE Conference on Computer Vision and Pattern
    Recognition. 2015.

    **Parameters**

    train_data_shuffler:
      The data shuffler used for batching data for training

    iterations:
      Maximum number of iterations

    snapshot:
      Will take a snapshot of the network at every `n` iterations

    validation_snapshot:
      Test with validation each `n` iterations

    analizer:
      Neural network analizer :py:mod:`bob.learn.tensorflow.analyzers`.
      When not given, a fresh ``SoftmaxAnalizer`` is created per instance.

    temp_dir: str
      The output directory

    verbosity_level:
      Verbosity level forwarded to the bob logger
    """

    def __init__(self,
                 train_data_shuffler,

                 ###### training options ##########
                 iterations=5000,
                 snapshot=500,
                 validation_snapshot=100,

                 ## Analizer
                 analizer=None,

                 # Temporary dir
                 temp_dir="siamese_cnn",

                 verbosity_level=2
                 ):

        # FIX: the previous default `analizer=SoftmaxAnalizer()` was a mutable
        # default argument, instantiated once at class-definition time and
        # shared by every trainer. A None sentinel builds one per instance.
        if analizer is None:
            analizer = SoftmaxAnalizer()

        # FIX: forward the caller's arguments to the base class. Previously
        # the hard-coded defaults (5000, 500, 100, a new SoftmaxAnalizer(),
        # "siamese_cnn", 2) were passed to `super().__init__` regardless of
        # what the caller provided.
        super(TripletTrainer, self).__init__(
            train_data_shuffler,

            ###### training options ##########
            iterations=iterations,
            snapshot=snapshot,
            validation_snapshot=validation_snapshot,

            ## Analizer
            analizer=analizer,

            # Temporary dir
            temp_dir=temp_dir,

            verbosity_level=verbosity_level
        )

        self.train_data_shuffler = train_data_shuffler
        self.temp_dir = temp_dir

        self.iterations = iterations
        self.snapshot = snapshot
        self.validation_snapshot = validation_snapshot

        # Training variables used in the fit
        self.summaries_train = None
        self.train_summary_writter = None

        # Validation data
        self.validation_summary_writter = None

        # Analizer
        self.analizer = analizer
        self.global_step = None

        self.session = None

        self.graph = None
        self.loss = None
        self.predictor = None
        self.optimizer_class = None
        self.learning_rate = None
        # Training variables used in the fit
        self.optimizer = None
        self.data_ph = None
        self.label_ph = None
        self.saver = None

        # NOTE: the original called set_verbosity_level twice; once suffices.
        bob.core.log.set_verbosity_level(logger, verbosity_level)

        # Creating the session
        self.session = Session.instance(new=True).session
        self.from_scratch = True

    def create_network_from_scratch(self,
                                    graph,
                                    optimizer=None,
                                    loss=None,

                                    # Learning rate
                                    learning_rate=None,
                                    ):
        """
        Wire up a fresh triplet network: placeholders, loss, optimizer,
        savers and the graph collections used later by
        :py:meth:`create_network_from_file`.

        **Parameters**

        graph: dict
          Dictionary with the keys `anchor`, `positive` and `negative`, each
          pointing to the corresponding embedding tensor.

        optimizer:
          A ``tf.train`` optimizer. Defaults to a fresh
          ``tf.train.AdamOptimizer`` (None sentinel — see note below).

        loss:
          Callable computing the triplet loss from the three embeddings; must
          return a dict with the keys `loss`, `between_class` and
          `within_class`.

        learning_rate:
          Learning rate tensor/op; it is injected into the optimizer and
          saved in the `learning_rate` collection.
        """
        # FIX: `optimizer=tf.train.AdamOptimizer()` as a default argument was
        # created once at import time and shared across calls/instances;
        # build it lazily instead.
        if optimizer is None:
            optimizer = tf.train.AdamOptimizer()

        self.data_ph = self.train_data_shuffler("data")

        self.graph = graph
        # FIX: the original condition `if "anchor" and "positive" and
        # "negative" not in self.graph:` only ever tested the `negative` key
        # (non-empty strings are truthy); every key must be present. The old
        # message also said "two elements" while listing three.
        if not all(key in self.graph for key in ("anchor", "positive", "negative")):
            raise ValueError("`graph` should be a dictionary with three elements (`anchor`, `positive` and `negative`)")

        self.loss = loss
        self.predictor = self.loss(self.graph["anchor"],
                                   self.graph["positive"],
                                   self.graph["negative"])
        self.optimizer_class = optimizer
        self.learning_rate = learning_rate

        self.global_step = tf.contrib.framework.get_or_create_global_step()

        # Saving all the variables
        self.saver = tf.train.Saver(var_list=tf.global_variables())

        tf.add_to_collection("global_step", self.global_step)

        # Saving the pointers to the graph
        tf.add_to_collection("graph_anchor", self.graph['anchor'])
        tf.add_to_collection("graph_positive", self.graph['positive'])
        tf.add_to_collection("graph_negative", self.graph['negative'])

        # Saving pointers to the loss
        tf.add_to_collection("predictor_loss", self.predictor['loss'])
        tf.add_to_collection("predictor_between_class_loss", self.predictor['between_class'])
        tf.add_to_collection("predictor_within_class_loss", self.predictor['within_class'])

        # Saving the pointers to the placeholders
        tf.add_to_collection("data_ph_anchor", self.data_ph['anchor'])
        tf.add_to_collection("data_ph_positive", self.data_ph['positive'])
        tf.add_to_collection("data_ph_negative", self.data_ph['negative'])

        # Preparing the optimizer
        # NOTE(review): poking the private `_learning_rate` attribute relies
        # on TF 1.x optimizer internals — confirm against the TF version in use.
        self.optimizer_class._learning_rate = self.learning_rate
        self.optimizer = self.optimizer_class.minimize(self.predictor['loss'], global_step=self.global_step)
        tf.add_to_collection("optimizer", self.optimizer)
        tf.add_to_collection("learning_rate", self.learning_rate)

        self.summaries_train = self.create_general_summary()
        tf.add_to_collection("summaries_train", self.summaries_train)

        # Creating the variables
        tf.global_variables_initializer().run(session=self.session)

    def create_network_from_file(self, model_from_file, clear_devices=True):
        """
        Restore a previously-trained triplet network from `model_from_file`
        (a checkpoint prefix whose ``.meta`` graph file sits next to it),
        re-populating the same attributes that
        :py:meth:`create_network_from_scratch` sets, via the graph
        collections stored at training time.
        """
        self.saver = tf.train.import_meta_graph(model_from_file + ".meta", clear_devices=clear_devices)
        self.saver.restore(self.session, model_from_file)

        # Loading the graph from the graph pointers
        self.graph = dict()
        self.graph['anchor'] = tf.get_collection("graph_anchor")[0]
        self.graph['positive'] = tf.get_collection("graph_positive")[0]
        self.graph['negative'] = tf.get_collection("graph_negative")[0]

        # Loading the placeholders from the pointers
        self.data_ph = dict()
        self.data_ph['anchor'] = tf.get_collection("data_ph_anchor")[0]
        self.data_ph['positive'] = tf.get_collection("data_ph_positive")[0]
        self.data_ph['negative'] = tf.get_collection("data_ph_negative")[0]

        # Loading loss from the pointers
        self.predictor = dict()
        self.predictor['loss'] = tf.get_collection("predictor_loss")[0]
        self.predictor['between_class'] = tf.get_collection("predictor_between_class_loss")[0]
        self.predictor['within_class'] = tf.get_collection("predictor_within_class_loss")[0]

        # Loading other elements
        self.optimizer = tf.get_collection("optimizer")[0]
        self.learning_rate = tf.get_collection("learning_rate")[0]
        self.summaries_train = tf.get_collection("summaries_train")[0]
        self.global_step = tf.get_collection("global_step")[0]
        self.from_scratch = False

    def get_feed_dict(self, data_shuffler):
        """
        Fetch one (anchor, positive, negative) batch from `data_shuffler` and
        map it onto the three data placeholders.
        """
        [batch_anchor, batch_positive, batch_negative] = data_shuffler.get_batch()

        feed_dict = {self.data_ph['anchor']: batch_anchor,
                     self.data_ph['positive']: batch_positive,
                     self.data_ph['negative']: batch_negative}

        return feed_dict

    def fit(self, step):
        """
        Run one optimization step on a fresh training batch, log the loss
        and write the training summaries for `step`.
        """
        feed_dict = self.get_feed_dict(self.train_data_shuffler)
        _, l, bt_class, wt_class, lr, summary = self.session.run([
                                                self.optimizer,
                                                self.predictor['loss'], self.predictor['between_class'],
                                                self.predictor['within_class'],
                                                self.learning_rate, self.summaries_train], feed_dict=feed_dict)

        logger.info("Loss training set step={0} = {1}".format(step, l))
        self.train_summary_writter.add_summary(summary, step)

    def create_general_summary(self):
        """
        Register scalar summaries for the three loss components and the
        learning rate, and return the merged summary op.
        """
        # Train summary
        tf.summary.scalar('loss', self.predictor['loss'])
        tf.summary.scalar('between_class_loss', self.predictor['between_class'])
        tf.summary.scalar('within_class_loss', self.predictor['within_class'])
        tf.summary.scalar('lr', self.learning_rate)
        return tf.summary.merge_all()

    def load_and_enqueue(self):
        """
        Injecting data in the place holder queue

        **Parameters**
          session: Tensorflow session

        """
        while not self.thread_pool.should_stop():
            [train_data_anchor, train_data_positive, train_data_negative] = self.train_data_shuffler.get_batch()

            data_ph = dict()
            data_ph['anchor'] = self.train_data_shuffler("data", from_queue=False)['anchor']
            data_ph['positive'] = self.train_data_shuffler("data", from_queue=False)['positive']
            data_ph['negative'] = self.train_data_shuffler("data", from_queue=False)['negative']

            feed_dict = {data_ph['anchor']: train_data_anchor,
                         data_ph['positive']: train_data_positive,
                         data_ph['negative']: train_data_negative}

            self.session.run(self.train_data_shuffler.enqueue_op, feed_dict=feed_dict)