#!/usr/bin/env python
# vim: set fileencoding=utf-8 :
# @author: Tiago de Freitas Pereira <tiago.pereira@idiap.ch>
# @date: Tue 09 Aug 2016 15:25:22 CEST

import os
import logging

import tensorflow as tf
from tensorflow.core.framework import summary_pb2

import bob.core
from bob.learn.tensorflow.utils.session import Session

from .Trainer import Trainer
from ..analyzers import SoftmaxAnalizer
from .learning_rate import constant

# Module logger; its verbosity is set by SiameseTrainer.__init__.
logger = logging.getLogger("bob.learn")


class SiameseTrainer(Trainer):
    """
    Trainer for siamese networks:

    Chopra, Sumit, Raia Hadsell, and Yann LeCun. "Learning a similarity metric discriminatively, with application to
    face verification." 2005 IEEE Computer Society Conference on Computer Vision and Pattern Recognition (CVPR'05). Vol. 1. IEEE, 2005.

    **Parameters**

    train_data_shuffler:
      The data shuffler used for batching data for training.
      It is called as ``train_data_shuffler("data")`` (must return a dict with the
      ``left`` and ``right`` placeholders) and ``train_data_shuffler("label")``, and its
      ``get_batch()`` must return ``[batch_left, batch_right, labels]``.

    iterations:
      Maximum number of iterations

    snapshot:
      Will take a snapshot of the network at every `n` iterations

    validation_snapshot:
      Test with validation each `n` iterations

    analizer:
      Neural network analizer :py:mod:`bob.learn.tensorflow.analyzers`.
      When ``None`` (default) a new ``SoftmaxAnalizer`` is created for this trainer.

    temp_dir: str
      The output directory

    verbosity_level:
      Verbosity passed to ``bob.core.log.set_verbosity_level`` for the ``bob.learn`` logger.
    """

    def __init__(self,
                 train_data_shuffler,

                 ###### training options ##########
                 iterations=5000,
                 snapshot=500,
                 validation_snapshot=100,

                 ## Analizer
                 analizer=None,

                 # Temporary dir
                 temp_dir="siamese_cnn",

                 verbosity_level=2
                 ):
        # NOTE(review): Trainer.__init__ is not called; every attribute used here is
        # set explicitly below -- confirm the base class needs nothing else.
        self.train_data_shuffler = train_data_shuffler
        self.temp_dir = temp_dir

        self.iterations = iterations
        self.snapshot = snapshot
        self.validation_snapshot = validation_snapshot

        # Training variables used in the fit
        self.summaries_train = None
        self.train_summary_writter = None

        # Validation data
        self.validation_summary_writter = None

        # Analizer: built per trainer instead of one instance created at import time
        # and shared by every trainer through the default argument.
        self.analizer = SoftmaxAnalizer() if analizer is None else analizer
        self.global_step = None

        self.graph = None
        self.loss = None
        self.predictor = None
        self.optimizer_class = None
        self.learning_rate = None
        # Training variables used in the fit
        self.optimizer = None
        self.data_ph = None
        self.label_ph = None
        self.saver = None

        bob.core.log.set_verbosity_level(logger, verbosity_level)

        # Creating the session
        self.session = Session.instance(new=True).session
        self.from_scratch = True

    @staticmethod
    def _inject_learning_rate(optimizer, learning_rate):
        """
        Overwrite the learning rate stored inside an already constructed optimizer.

        ``tf.train`` optimizers keep their rate in a private attribute whose name differs
        between classes: most use ``_learning_rate``, while ``AdamOptimizer`` (the default
        here) uses ``_lr`` in the TF 1.x releases this code targets -- re-verify when
        upgrading TensorFlow. Setting only ``_learning_rate`` left Adam silently running
        with its constructor rate.
        """
        optimizer._learning_rate = learning_rate
        if hasattr(optimizer, "_lr"):
            optimizer._lr = learning_rate

    def create_network_from_scratch(self,
                                    graph,
                                    optimizer=None,
                                    loss=None,

                                    # Learning rate
                                    learning_rate=None,
                                    ):
        """
        Build the siamese training graph and initialize all its variables.

        **Parameters**

        graph: dict
          The two network branches, under the keys ``left`` and ``right``.

        optimizer:
          A ``tf.train`` optimizer instance. When ``None`` (default) a new
          ``tf.train.AdamOptimizer()`` is created for this call.

        loss:
          Callable invoked as ``loss(label_ph, graph["left"], graph["right"])``; it must
          return a dict with the keys ``loss``, ``between_class`` and ``within_class``.

        learning_rate:
          Learning-rate tensor/variable; it is fed to the optimizer, fetched during
          ``fit`` and summarized as ``lr``.

        **Raises**

        ValueError:
          If ``graph`` lacks ``left`` or ``right``, or ``learning_rate`` is ``None``.

        TypeError:
          If ``loss`` is not callable.
        """
        # Validate everything before any op is added to the TensorFlow graph.
        # (The previous check `"left" and "right" not in graph` only tested "right".)
        if not all(key in graph for key in ("left", "right")):
            raise ValueError("`graph` should be a dictionary with two elements (`left`and `right`)")
        if not callable(loss):
            raise TypeError("`loss` should be a callable returning a dictionary with "
                            "`loss`, `between_class` and `within_class`")
        if learning_rate is None:
            # Previously this failed later, inside tf.summary.scalar('lr', None).
            raise ValueError("`learning_rate` should be provided (a tensor or variable)")
        if optimizer is None:
            # Built per call: an instance created in the signature would be shared by
            # every trainer and mutated by _inject_learning_rate.
            optimizer = tf.train.AdamOptimizer()

        self.data_ph = self.train_data_shuffler("data")
        self.label_ph = self.train_data_shuffler("label")

        self.graph = graph
        self.loss = loss
        self.predictor = self.loss(self.label_ph,
                                   self.graph["left"],
                                   self.graph["right"])
        self.optimizer_class = optimizer
        self.learning_rate = learning_rate

        # TODO: find an elegant way to provide this as a parameter of the trainer
        self.global_step = tf.Variable(0, trainable=False, name="global_step")

        # Saving the pointers to the graph
        tf.add_to_collection("graph_left", self.graph['left'])
        tf.add_to_collection("graph_right", self.graph['right'])

        # Saving pointers to the loss
        tf.add_to_collection("predictor_loss", self.predictor['loss'])
        tf.add_to_collection("predictor_between_class_loss", self.predictor['between_class'])
        tf.add_to_collection("predictor_within_class_loss", self.predictor['within_class'])

        # Saving the pointers to the placeholders
        tf.add_to_collection("data_ph_left", self.data_ph['left'])
        tf.add_to_collection("data_ph_right", self.data_ph['right'])
        tf.add_to_collection("label_ph", self.label_ph)

        # Preparing the optimizer
        self._inject_learning_rate(self.optimizer_class, self.learning_rate)
        self.optimizer = self.optimizer_class.minimize(self.predictor['loss'], global_step=self.global_step)
        tf.add_to_collection("optimizer", self.optimizer)
        tf.add_to_collection("learning_rate", self.learning_rate)

        # Saving all the variables. The saver snapshots tf.global_variables() when it is
        # built, so it must come after `minimize`; otherwise the optimizer's own variables
        # (e.g. Adam moment slots) are never checkpointed nor restored.
        self.saver = tf.train.Saver(var_list=tf.global_variables())
        tf.add_to_collection("global_step", self.global_step)

        self.summaries_train = self.create_general_summary()
        tf.add_to_collection("summaries_train", self.summaries_train)

        # Creating the variables
        tf.global_variables_initializer().run(session=self.session)

    def create_network_from_file(self, model_from_file, clear_devices=True):
        """
        Restore a trainer whose graph was built by ``create_network_from_scratch``.

        Imports the meta graph ``model_from_file + ".meta"``, restores the variables from
        ``model_from_file`` into ``self.session`` and recovers every tensor/op used for
        training from the graph collections written at creation time.

        **Parameters**

        model_from_file: str
          Checkpoint path prefix (without the ``.meta`` suffix).

        clear_devices: bool
          Forwarded to ``tf.train.import_meta_graph``.
        """
        self.saver = tf.train.import_meta_graph(model_from_file + ".meta", clear_devices=clear_devices)
        self.saver.restore(self.session, model_from_file)

        def from_collection(name):
            # Each collection was filled with exactly one element at creation time.
            return tf.get_collection(name)[0]

        # Loading the graph from the graph pointers
        self.graph = {
            'left': from_collection("graph_left"),
            'right': from_collection("graph_right"),
        }

        # Loading the placeholders from the pointers
        self.data_ph = {
            'left': from_collection("data_ph_left"),
            'right': from_collection("data_ph_right"),
        }
        self.label_ph = from_collection("label_ph")

        # Loading loss from the pointers
        self.predictor = {
            'loss': from_collection("predictor_loss"),
            'between_class': from_collection("predictor_between_class_loss"),
            'within_class': from_collection("predictor_within_class_loss"),
        }

        # Loading other elements
        self.optimizer = from_collection("optimizer")
        self.learning_rate = from_collection("learning_rate")
        self.summaries_train = from_collection("summaries_train")
        self.global_step = from_collection("global_step")
        self.from_scratch = False

    def get_feed_dict(self, data_shuffler):
        """
        Draw one batch from ``data_shuffler`` and map it onto the placeholders.

        ``data_shuffler.get_batch()`` must return ``[batch_left, batch_right, labels]``.
        """
        [batch_left, batch_right, labels] = data_shuffler.get_batch()

        return {self.data_ph['left']: batch_left,
                self.data_ph['right']: batch_right,
                self.label_ph: labels}

    def fit(self, step):
        """
        Run one optimization step on a batch from the training shuffler.

        Logs the training loss and writes the merged training summary for ``step`` to
        ``self.train_summary_writter``, which must already exist (presumably created by
        the base ``Trainer`` -- it is ``None`` after ``__init__``).
        """
        feed_dict = self.get_feed_dict(self.train_data_shuffler)
        _, loss_value, summary = self.session.run(
            [self.optimizer, self.predictor['loss'], self.summaries_train],
            feed_dict=feed_dict)

        logger.info("Loss training set step=%s = %s", step, loss_value)
        self.train_summary_writter.add_summary(summary, step)

    def create_general_summary(self):
        """
        Create scalar summaries for the losses and the learning rate.

        Returns ``tf.summary.merge_all()``, i.e. every summary currently registered in
        the default graph, not only the four created here.
        """
        # Train summary
        tf.summary.scalar('loss', self.predictor['loss'])
        tf.summary.scalar('between_class_loss', self.predictor['between_class'])
        tf.summary.scalar('within_class_loss', self.predictor['within_class'])
        tf.summary.scalar('lr', self.learning_rate)
        return tf.summary.merge_all()