/**
 * @date Wed Jul 6 17:32:35 2011 +0200
 * @author Andre Anjos <andre.anjos@idiap.ch>
 * @author Laurent El Shafey <Laurent.El-Shafey@idiap.ch>
 *
 * @brief An MLP trainer based on resilient back-propagation: "A Direct
 * Adaptive Method for Faster Backpropagation Learning: The RPROP Algorithm",
 * by Martin Riedmiller and Heinrich Braun, in Proceedings of the IEEE
 * International Conference on Neural Networks, pp. 586--591, 1993.
 *
 * Copyright (C) 2011-2014 Idiap Research Institute, Martigny, Switzerland
 */

#ifndef BOB_LEARN_MLP_RPROP_H
#define BOB_LEARN_MLP_RPROP_H

#include <vector>
#include <boost/function.hpp>

#include <bob.learn.mlp/machine.h>
#include <bob.learn.mlp/trainer.h>

namespace bob { namespace learn { namespace mlp {

  /**
   * @brief Sets an MLP to perform discrimination based on RProp: "A Direct
   * Adaptive Method for Faster Backpropagation Learning: The RPROP
   * Algorithm", by Martin Riedmiller and Heinrich Braun, in Proceedings of
   * the IEEE International Conference on Neural Networks, pp. 586--591, 1993.
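   *
   * In a nutshell (following the 1993 paper): every weight w_ij keeps its own
   * update value Delta_ij, which is adapted from the sign of the partial
   * derivative of the cost E w.r.t. that weight:
   *
   *   Delta_ij(t) = eta_plus  * Delta_ij(t-1)   if dE/dw_ij(t-1) * dE/dw_ij(t) > 0
   *   Delta_ij(t) = eta_minus * Delta_ij(t-1)   if dE/dw_ij(t-1) * dE/dw_ij(t) < 0
   *   Delta_ij(t) = Delta_ij(t-1)               otherwise
   *
   * Delta_ij is clipped to [delta_min, delta_max] and the weight is moved by
   * -sign(dE/dw_ij) * Delta_ij.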
   */
  class RProp: public Trainer {

    public: //api

      /**
       * @brief Initializes a new RProp trainer according to a given
       * training batch size.
       *
       * @param batch_size The number of examples passed at each iteration.
       * This should be a large number (tens of samples): Resilient
       * Back-propagation is a <b>batch</b> algorithm and requires large
       * sample sizes.
       *
       * @param cost This is the cost function to use for the current training.
       *
       * @note Good values for batch sizes are tens of samples. This may affect
       * the convergence.
       */
      RProp(size_t batch_size,
          boost::shared_ptr<Cost> cost);

      /**
       * @brief Initializes a new RProp trainer according to given machine
       * settings and a training batch size.
       *
       * @param batch_size The number of examples passed at each iteration.
       * This should be a large number (tens of samples): Resilient
       * Back-propagation is a <b>batch</b> algorithm and requires large
       * sample sizes.
       *
       * @param cost This is the cost function to use for the current training.
       *
       * @param machine The trainer clones this machine's weights and prepares
       * its internal buffers to mirror the machine's properties.
       *
       * @note Good values for batch sizes are tens of samples. This may affect
       * the convergence.
       */
      RProp(size_t batch_size,
          boost::shared_ptr<Cost> cost,
          const Machine& machine);

      /**
       * @brief Initializes a new RProp trainer according to given machine
       * settings and a training batch size.
       *
       * @param batch_size The number of examples passed at each iteration.
       * This should be a large number (tens of samples): Resilient
       * Back-propagation is a <b>batch</b> algorithm and requires large
       * sample sizes.
       *
       * @param cost This is the cost function to use for the current training.
       *
       * @param machine The trainer clones this machine's weights and prepares
       * its internal buffers to mirror the machine's properties.
       *
       * @param train_biases A boolean indicating whether the biases should be
       * trained as well.
       *
       * @note Good values for batch sizes are tens of samples. This may also
       * affect the convergence.
       *
       * You can also change the default values of the algorithm parameters
       * (the eta and delta factors) through the setters below.
       */
      RProp(size_t batch_size, boost::shared_ptr<Cost> cost,
          const Machine& machine, bool train_biases);
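
      // A minimal construction sketch. `SquareError` is one of the Cost
      // implementations shipped with this package; the exact setup below is
      // an assumption, not the only possible one:
      //
      //   boost::shared_ptr<bob::learn::mlp::Cost> cost(
      //       new bob::learn::mlp::SquareError(machine.getOutputActivation()));
      //   bob::learn::mlp::RProp trainer(32, cost, machine, false);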

      /**
       * @brief Virtual destructor
       */
      virtual ~RProp();

      /**
       * @brief Copy construction.
       */
      RProp(const RProp& other);

      /**
       * @brief Copy operator
       */
      RProp& operator=(const RProp& other);

      /**
       * @brief Re-initializes the whole training apparatus to start training
       * a new machine. This will effectively reset all Delta matrices to
       * their initial values and set the previous derivatives to zero, as
       * described in section II.C of the RProp paper.
       */
      void reset();

      /**
       * @brief Initializes the internal buffers for the current machine
       */
      virtual void initialize(const Machine& machine);

      /**
       * @brief Trains the MLP to perform discrimination. The training is
       * executed outside the machine context, but uses all the current machine
       * layout. The given machine is updated with new weights and biases at
       * the end of the training, which is performed a single time. Iterate as
       * many times as you want to refine the training.
       *
       * The machine given as input is checked for compatibility with the
       * currently initialized settings. If the two are not compatible, an
       * exception is thrown.
       *
       * @note In RProp, training is done in batches. The number of rows in
       * the input (and target) determines the batch size. If the batch size
       * currently set is incompatible with the given data, an exception is
       * raised.
       *
       * @note The machine is not initialized randomly at each train() call.
       * It is your task to call MLP::randomize() once on the machine you want
       * to train and then call train() as many times as you think are
       * necessary. This design allows a training criterion to be encoded
       * outside the scope of this trainer, so this type focuses only on
       * applying the training to the given input and target when requested.
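       *
       * A minimal usage sketch (`data` and `labels` are hypothetical
       * placeholders for your training arrays, one sample per row):
       *
       * @code
       * machine.randomize(); // randomize the MLP weights once, up front
       * for (size_t i = 0; i < max_iterations; ++i) {
       *   trainer.train(machine, data, labels); // one batch update per call
       * }
       * @endcode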
       */
      void train(Machine& machine,
          const blitz::Array<double,2>& input,
          const blitz::Array<double,2>& target);

      /**
       * @brief This is a version of the train() method above which performs
       * no compatibility checks on the input machine.
       */
      void train_(Machine& machine,
          const blitz::Array<double,2>& input,
          const blitz::Array<double,2>& target);

      /**
       * Accessors for algorithm parameters
       */

      /**
       * @brief Gets the decrease factor, eta minus (default is 0.5)
       */
      double getEtaMinus() const { return m_eta_minus; }

      /**
       * @brief Sets the decrease factor, eta minus (default is 0.5)
       */
      void setEtaMinus(double v) { m_eta_minus = v;    }

      /**
       * @brief Gets the increase factor, eta plus (default is 1.2)
       */
      double getEtaPlus() const { return m_eta_plus; }

      /**
       * @brief Sets the increase factor, eta plus (default is 1.2)
       */
      void setEtaPlus(double v) { m_eta_plus = v;    }

      /**
       * @brief Gets the initial weight update (default is 0.1)
       */
      double getDeltaZero() const { return m_delta_zero; }

      /**
       * @brief Sets the initial weight update (default is 0.1)
       */
      void setDeltaZero(double v) { m_delta_zero = v;    }

      /**
       * @brief Gets the minimal weight update (default is 1e-6)
       */
      double getDeltaMin() const { return m_delta_min; }

      /**
       * @brief Sets the minimal weight update (default is 1e-6)
       */
      void setDeltaMin(double v) { m_delta_min = v;    }

      /**
       * @brief Gets the maximal weight update (default is 50.0)
       */
      double getDeltaMax() const { return m_delta_max; }

      /**
       * @brief Sets the maximal weight update (default is 50.0)
       */
      void setDeltaMax(double v) { m_delta_max = v;    }
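
      // Example: setting the update parameters explicitly. The values below
      // are the paper-recommended defaults, shown only to illustrate the
      // setters:
      //
      //   trainer.setEtaMinus(0.5);
      //   trainer.setEtaPlus(1.2);
      //   trainer.setDeltaZero(0.1);
      //   trainer.setDeltaMin(1e-6);
      //   trainer.setDeltaMax(50.0);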

      /**
       * @brief Returns the deltas
       */
      const std::vector<blitz::Array<double,2> >& getDeltas() const { return m_delta; }

      /**
       * @brief Returns the bias deltas
       */
      const std::vector<blitz::Array<double,1> >& getBiasDeltas() const { return m_delta_bias; }

      /**
       * @brief Sets the deltas
       */
      void setDeltas(const std::vector<blitz::Array<double,2> >& v);

      /**
       * @brief Sets the deltas for a given index
       */
      void setDelta(const blitz::Array<double,2>& v, const size_t index);

      /**
       * @brief Sets the bias deltas
       */
      void setBiasDeltas(const std::vector<blitz::Array<double,1> >& v);

      /**
       * @brief Sets the bias deltas for a given index
       */
      void setBiasDelta(const blitz::Array<double,1>& v, const size_t index);

      /**
       * @brief Returns the previous derivatives of the cost w.r.t. the weights
       */
      const std::vector<blitz::Array<double,2> >& getPreviousDerivatives() const { return m_prev_deriv; }

      /**
       * @brief Returns the previous derivatives of the cost w.r.t. the biases
       */
      const std::vector<blitz::Array<double,1> >& getPreviousBiasDerivatives() const { return m_prev_deriv_bias; }

      /**
       * @brief Sets the previous derivatives of the cost
       */
      void setPreviousDerivatives(const std::vector<blitz::Array<double,2> >& v);

      /**
       * @brief Sets the previous derivatives of the cost at a given index
       */
      void setPreviousDerivative(const blitz::Array<double,2>& v, const size_t index);

      /**
       * @brief Sets the previous derivatives of the cost (biases)
       */
      void setPreviousBiasDerivatives(const std::vector<blitz::Array<double,1> >& v);

      /**
       * @brief Sets the previous derivatives of the cost (biases) at a given
       * index
       */
      void setPreviousBiasDerivative(const blitz::Array<double,1>& v, const size_t index);
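
      // Example: snapshotting the adaptive state, e.g. to checkpoint a long
      // run. Note that blitz::Array assignment shares storage, so a deep
      // copy via copy() is required (a sketch, not part of the API contract):
      //
      //   std::vector<blitz::Array<double,2> > saved;
      //   for (size_t k = 0; k < trainer.getDeltas().size(); ++k)
      //     saved.push_back(trainer.getDeltas()[k].copy());
      //   // ... train some more, then roll back:
      //   trainer.setDeltas(saved);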

    private: //representation

      /**
       * Weight update -- calculates the weight update using the derivatives,
       * as explained in Bishop's formula 5.53, page 243.
       *
       * Note: for RProp, specifically, we only care about the sign of the
       * derivative, both the current and the previous one. This is the place
       * where standard backprop and RProp diverge.
       *
       * For extra insight, double-check the Technical Report entitled "Rprop -
       * Description and Implementation Details" by Martin Riedmiller, 1994.
       * Just browse the internet for it. Keep it under your pillow ;-)
       */
      void rprop_weight_update(Machine& machine,
        const blitz::Array<double,2>& input);
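
      // A simplified per-element sketch of the rule implemented here (the
      // variant that zeroes the stored derivative on a sign change, per
      // section II.C of the paper; the real code operates on whole blitz
      // arrays in the corresponding .cpp file):
      //
      //   double prod = prev_deriv * deriv;
      //   if (prod > 0) delta = std::min(delta * m_eta_plus, m_delta_max);
      //   else if (prod < 0) {
      //     delta = std::max(delta * m_eta_minus, m_delta_min);
      //     deriv = 0.; // forbid another adaptation on the next iteration
      //   }
      //   if (deriv > 0) weight -= delta;
      //   else if (deriv < 0) weight += delta;
      //   prev_deriv = deriv;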

      double m_eta_minus; ///< decrease factor, eta minus (0.5)
      double m_eta_plus;  ///< increase factor, eta plus (1.2)
      double m_delta_zero;///< initial value for the weight change (0.1)
      double m_delta_min; ///< minimum value for the weight change (1e-6)
      double m_delta_max; ///< maximum value for the weight change (50.0)

      std::vector<blitz::Array<double,2> > m_delta; ///< RProp weight deltas
      std::vector<blitz::Array<double,1> > m_delta_bias; ///< RProp bias deltas

      std::vector<blitz::Array<double,2> > m_prev_deriv; ///< previous weight derivatives
      std::vector<blitz::Array<double,1> > m_prev_deriv_bias; ///< previous bias derivatives
  };

}}}

#endif /* BOB_LEARN_MLP_RPROP_H */