lda.cpp 20.7 KB
Newer Older
1
2
3
4
5
6
7
8
9
/**
 * @author Andre Anjos <andre.anjos@idiap.ch>
 * @date Thu 16 Jan 2014 17:09:04 CET
 *
 * @brief Python bindings to LDA trainers
 *
 * Copyright (C) 2011-2014 Idiap Research Institute, Martigny, Switzerland
 */

André Anjos's avatar
André Anjos committed
10
11
12
13
#define BOB_LEARN_LINEAR_MODULE
#include <bob.blitz/cppapi.h>
#include <bob.blitz/cleanup.h>
#include <bob.learn.linear/api.h>
14
#include <structmember.h>
15
#include <bob.extension/documentation.h>
16
17
18
19
20

/*************************************************
 * Implementation of FisherLDATrainer base class *
 *************************************************/

21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
// Sphinx-style class documentation for FisherLDATrainer.  Besides the rendered
// docstring, this object also supplies the keyword lists (via .kwlist()) that
// the constructor variants below feed to PyArg_ParseTupleAndKeywords.
static auto LDA_doc = bob::extension::ClassDoc(
  BOB_EXT_MODULE_PREFIX ".FisherLDATrainer",
  "Trains a :py:class:`Machine` to perform Fisher's Linear Discriminant Analysis (LDA).",
  "LDA finds the projection matrix W that allows us to linearly project the data matrix X to another (sub) space in which the between-class and within-class variances are jointly optimized: the between-class variance is maximized while the with-class is minimized. "
  "The (inverse) cost function for this criteria can be posed as the following:\n\n"
  ".. math::\n\n"
  "   J(W) = \\frac{W^T S_b W}{W^T S_w W}\n\n"
  "where:\n\n"
  ":math:`W`\n\n  the transformation matrix that converts X into the LD space\n\n"
  ":math:`S_b`\n\n  the between-class scatter; it has dimensions (X.shape[0], X.shape[0]) and is defined as :math:`S_b = \\sum_{k=1}^K N_k (m_k-m)(m_k-m)^T`, with :math:`K` equal to the number of classes.\n\n"
  ":math:`S_w`\n\n  the within-class scatter; it also has dimensions (X.shape[0], X.shape[0]) and is defined as :math:`S_w = \\sum_{k=1}^K \\sum_{n \\in C_k} (x_n-m_k)(x_n-m_k)^T`, with :math:`K` equal to the number of classes and :math:`C_k` a set representing all samples for class :math:`k`.\n\n"
  ":math:`m_k`\n\n  the class *k* empirical mean, defined as :math:`m_k = \\frac{1}{N_k}\\sum_{n \\in C_k} x_n`\n\n"
  ":math:`m`\n\n  the overall set empirical mean, defined as :math:`m = \\frac{1}{N}\\sum_{n=1}^N x_n = \\frac{1}{N}\\sum_{k=1}^K N_k m_k`\n\n"
  ".. note::  A scatter matrix equals the covariance matrix if we remove the division factor.\n\n"
  "Because this cost function is convex, you can just find its maximum by solving :math:`dJ/dW = 0`. "
  "This problem can be re-formulated as finding the eigen-values (:math:`\\lambda_i`) that solve the following condition:\n\n"
  ".. math::\n\n"
  "   S_b &= \\lambda_i Sw \\text{ or} \\\\\n"
  "  (Sb - \\lambda_i Sw) &= 0\n\n"
  "The respective eigen-vectors that correspond to the eigen-values :math:`\\lambda_i` form W."
).add_constructor(bob::extension::FunctionDoc(
  "FisherLDATrainer",
  "Constructs a new FisherLDATrainer",
  "Objects of this class can be initialized in two ways. "
  "In the first variant, the user creates a new trainer from discrete flags indicating a couple of optional parameters. "
  "If ``use_pinv`` is set to ``True``, use the pseudo-inverse to calculate :math:`S_w^{-1} S_b` and then perform eigen value decomposition (using LAPACK's ``dgeev``) instead of using (the more numerically stable) LAPACK's ``dsyvgd`` to solve the generalized symmetric-definite eigen-problem of the form :math:`S_b v=(\\lambda) S_w v`.\n\n"
  ".. note::\n\n"
  "   Using the pseudo-inverse for LDA is only recommended if you cannot make it work using the default method (via ``dsyvg``).\n"
  "   It is slower and requires more machine memory to store partial values of the pseudo-inverse and the dot product :math:`S_w^{-1} S_b`.\n\n"
  "``strip_to_rank`` specifies how to calculate the final size of the to-be-trained :py:class:`bob.learn.linear.Machine`. "
  "The default setting (``True``), makes the trainer return only the K-1 eigen-values/vectors limiting the output to the rank of :math:`S_w^{-1} S_b`. "
  "If you set this value to ``False``, the it returns all eigen-values/vectors of :math:`S_w^{-1} Sb`, including the ones that are supposed to be zero.\n\n"
  "The second initialization variant allows the user to deep copy an object of the same type creating a new identical object."
)
.add_prototype("[use_pinv, strip_to_rank]", "")
.add_prototype("other", "")
.add_parameter("use_pinv", "bool", "[Default: ``False``] use the pseudo-inverse to calculate :math:`S_w^{-1} S_b`?")
.add_parameter("strip_to_rank", "bool", "[Default: ``True``] return only the non-zero eigen-values/vectors")
.add_parameter("other", ":py:class:`FisherLDATrainer`", "The trainer to copy-construct")
);
61
62
63

/**
 * Constructor variant taking the two optional boolean flags
 * (use_pinv, strip_to_rank).  Returns 0 on success, -1 with a Python
 * exception set on failure.
 */
static int PyBobLearnLinearFisherLDATrainer_init_bools
(PyBobLearnLinearFisherLDATrainerObject* self, PyObject* args, PyObject* kwds) {
BOB_TRY
  /* Parse both optional flags in a single shot */
  char** kwlist = LDA_doc.kwlist(0);

  PyObject* pinv_obj = Py_False;   // default: do not use the pseudo-inverse
  PyObject* strip_obj = Py_True;   // default: strip output to matrix rank

  if (!PyArg_ParseTupleAndKeywords(args, kwds, "|OO", kwlist,
        &pinv_obj, &strip_obj)) return -1;

  // PyObject_IsTrue applies standard Python truthiness; -1 signals an error
  const int pinv_flag = PyObject_IsTrue(pinv_obj);
  if (pinv_flag < 0) return -1;

  const int strip_flag = PyObject_IsTrue(strip_obj);
  if (strip_flag < 0) return -1;

  self->cxx = new bob::learn::linear::FisherLDATrainer(pinv_flag, strip_flag);
  return 0;
BOB_CATCH_MEMBER("constructor", -1)
}

/**
 * Constructor variant taking another FisherLDATrainer to deep-copy.
 * Returns 0 on success, -1 with a Python exception set on failure.
 */
static int PyBobLearnLinearFisherLDATrainer_init_copy
(PyBobLearnLinearFisherLDATrainerObject* self, PyObject* args, PyObject* kwds) {
BOB_TRY
  /* Expect exactly one argument, type-checked to be a FisherLDATrainer */
  char** kwlist = LDA_doc.kwlist(1);

  PyObject* source = 0;

  if (!PyArg_ParseTupleAndKeywords(args, kwds, "O!", kwlist,
        &PyBobLearnLinearFisherLDATrainer_Type, &source)) return -1;

  auto wrapped = reinterpret_cast<PyBobLearnLinearFisherLDATrainerObject*>(source);

  // delegate to the C++ copy constructor of the wrapped trainer
  self->cxx = new bob::learn::linear::FisherLDATrainer(*wrapped->cxx);

  return 0;
BOB_CATCH_MEMBER("constructor", -1)
}

/** Returns non-zero if `o` is an instance of (a subclass of) FisherLDATrainer. */
int PyBobLearnLinearFisherLDATrainer_Check(PyObject* o) {
  PyObject* type = reinterpret_cast<PyObject*>(&PyBobLearnLinearFisherLDATrainer_Type);
  return PyObject_IsInstance(o, type);
}

/**
 * Main __init__ dispatcher: selects between the flag-based and the
 * copy-construction variants based on argument count and type.
 * Returns 0 on success, -1 with a Python exception set on failure.
 *
 * Fix: removed the PyErr_Format call that followed the if/else in the
 * one-argument case -- both branches return, so it was unreachable dead code.
 */
static int PyBobLearnLinearFisherLDATrainer_init
(PyBobLearnLinearFisherLDATrainerObject* self, PyObject* args, PyObject* kwds) {

  // total number of arguments, positional and keyword combined
  Py_ssize_t nargs = (args?PyTuple_Size(args):0) + (kwds?PyDict_Size(kwds):0);

  switch (nargs) {

    case 0: //default initializer
    case 2: //two bools
      return PyBobLearnLinearFisherLDATrainer_init_bools(self, args, kwds);

    case 1:
      {
        PyObject* arg = 0; ///< borrowed (don't delete)
        if (PyTuple_Size(args)) arg = PyTuple_GET_ITEM(args, 0);
        else {
          // borrow the single keyword value; it stays alive through `kwds`
          // even after the temporary values list is released
          PyObject* tmp = PyDict_Values(kwds);
          auto tmp_ = make_safe(tmp);
          arg = PyList_GET_ITEM(tmp, 0);
        }

        // copy-construct when given another trainer; otherwise treat the
        // single argument as the `use_pinv` flag
        if (PyBobLearnLinearFisherLDATrainer_Check(arg)) {
          return PyBobLearnLinearFisherLDATrainer_init_copy(self, args, kwds);
        }
        return PyBobLearnLinearFisherLDATrainer_init_bools(self, args, kwds);
      }

    default:
      PyErr_Format(PyExc_RuntimeError, "number of arguments mismatch - `%s' requires 0 to 2 arguments, but you provided %" PY_FORMAT_SIZE_T "d (see help)", Py_TYPE(self)->tp_name, nargs);

  } // switch

  return -1;
}

/** Deallocator: disposes of the wrapped C++ trainer, then frees the object. */
static void PyBobLearnLinearFisherLDATrainer_delete
(PyBobLearnLinearFisherLDATrainerObject* self) {

  delete self->cxx;
  Py_TYPE(self)->tp_free(reinterpret_cast<PyObject*>(self));

}

/**
 * Rich comparison: only == and != are supported, delegated to the C++
 * trainer's comparison operators; other operators yield NotImplemented.
 */
static PyObject* PyBobLearnLinearFisherLDATrainer_RichCompare
(PyBobLearnLinearFisherLDATrainerObject* self, PyObject* other, int op) {

  // objects of a different type cannot be compared
  if (!PyBobLearnLinearFisherLDATrainer_Check(other)) {
    PyErr_Format(PyExc_TypeError, "cannot compare `%s' with `%s'",
        Py_TYPE(self)->tp_name, Py_TYPE(other)->tp_name);
    return 0;
  }

  auto wrapped = reinterpret_cast<PyBobLearnLinearFisherLDATrainerObject*>(other);

  if (op == Py_EQ) {
    if (*self->cxx == *wrapped->cxx) Py_RETURN_TRUE;
    Py_RETURN_FALSE;
  }

  if (op == Py_NE) {
    if (*self->cxx != *wrapped->cxx) Py_RETURN_TRUE;
    Py_RETURN_FALSE;
  }

  // ordering comparisons are not defined for trainers
  Py_INCREF(Py_NotImplemented);
  return Py_NotImplemented;
}

178
179
180
181
182
183
184
185
186
187
188
189
190
// Documentation for the `train` method; the trailing `true` marks the
// function as supporting keyword arguments (kwlist() is consumed by the
// argument parser in PyBobLearnLinearFisherLDATrainer_Train below).
static auto train = bob::extension::FunctionDoc(
  "train",
  "Trains a given machine to perform Fisher/LDA discrimination",
  "After this method has been called, an input ``machine`` (or one allocated internally) will have the eigen-vectors of the :math:`S_w^{-1} S_b` product, arranged by decreasing energy. "
  "Each input data set represents data from a given input class. "
  "This method also returns the eigen-values allowing you to implement your own compression scheme.\n\n"
  "The user may provide or not an object of type :py:class:`bob.learn.linear.Machine` that will be set by this method. "
  "If provided, machine should have the correct number of inputs and outputs matching, respectively, the number of columns in the input data arrays ``X`` and the output of the method :py:meth:`output_size`.\n\n"
  "The value of ``X`` should be a sequence over as many 2D 64-bit floating point number arrays as classes in the problem. "
  "All arrays will be checked for conformance (identical number of columns). "
  "To accomplish this, either prepare a list with all your class observations organized in 2D arrays or pass a 3D array in which the first dimension (depth) contains as many elements as classes you want to discriminate.\n\n"
  ".. note::\n\n"
  "   We set at most :py:meth:`output_size` eigen-values and vectors on the passed machine.\n"
  "   You can compress the machine output further using :py:meth:`Machine.resize` if necessary.",
  true
)
.add_prototype("X, [machine]", "machine, eigen_values")
.add_parameter("X", "[array_like(2D, floats)] or array_like(3D, floats)", "The input data, separated to contain the training data per class in the first dimension")
.add_parameter("machine", ":py:class:`bob.learn.linear.Machine`", "The machine to be trained; this machine will be returned by this function")
.add_return("machine", ":py:class:`bob.learn.linear.Machine`", "The machine that has been trained; if given, identical to the ``machine`` parameter")
.add_return("eigen_values", "array_like(1D, floats)", "The eigen-values of the LDA projection.")
;
200
201
202
203
/**
 * Binding for FisherLDATrainer.train: converts the per-class input arrays,
 * allocates (or reuses) a linear machine, runs the C++ LDA training and
 * returns the (machine, eigen_values) pair.  Returns 0 with a Python
 * exception set on failure.
 */
static PyObject* PyBobLearnLinearFisherLDATrainer_Train
(PyBobLearnLinearFisherLDATrainerObject* self,
 PyObject* args, PyObject* kwds) {

BOB_TRY
  /* Parses input arguments in a single shot */
  char** kwlist = train.kwlist();

  PyObject* X = 0;
  PyObject* machine = 0;

  // X is mandatory; machine, when given, must already be a Machine instance
  if (!PyArg_ParseTupleAndKeywords(args, kwds, "O|O!", kwlist,
        &X, &PyBobLearnLinearMachine_Type, &machine)) return 0;

  /**
  // Note: strangely, if you pass dict.values(), this check does not work
  if (!PyIter_Check(X)) {
    PyErr_Format(PyExc_TypeError, "`%s' requires an iterable for parameter `X', but you passed `%s' which does not implement the iterator protocol", Py_TYPE(self)->tp_name, Py_TYPE(X)->tp_name);
    return 0;
  }
  **/

  /* Checks and converts all entries: Xseq holds lightweight blitz views,
   * while Xseq_ keeps the backing PyBlitzArrayObject references alive for
   * the duration of this call */
  std::vector<blitz::Array<double,2> > Xseq;
  std::vector<boost::shared_ptr<PyBlitzArrayObject>> Xseq_;

  PyObject* iterator = PyObject_GetIter(X);
  if (!iterator) return 0;
  auto iterator_ = make_safe(iterator);

  while (PyObject* item = PyIter_Next(iterator)) {
    auto item_ = make_safe(item);

    PyBlitzArrayObject* bz = 0;

    if (!PyBlitzArray_Converter(item, &bz)) {
      PyErr_Format(PyExc_TypeError, "`%s' could not convert object of type `%s' at position %" PY_FORMAT_SIZE_T "d of input sequence `X' into an array - check your input", Py_TYPE(self)->tp_name, Py_TYPE(item)->tp_name, Xseq.size());
      return 0;
    }

    // every class entry must be a 2D float64 array
    if (bz->ndim != 2 || bz->type_num != NPY_FLOAT64) {
      PyErr_Format(PyExc_TypeError, "`%s' only supports 2D 64-bit float arrays for input sequence `X' (or any other object coercible to that), but at position %" PY_FORMAT_SIZE_T "d I have found an object with %" PY_FORMAT_SIZE_T "d dimensions and with type `%s' which is not compatible - check your input", Py_TYPE(self)->tp_name, Xseq.size(), bz->ndim, PyBlitzArray_TypenumAsString(bz->type_num));
      Py_DECREF(bz);
      return 0;
    }

    Xseq_.push_back(make_safe(bz)); ///< prevents data deletion
    Xseq.push_back(*PyBlitzArrayCxx_AsBlitz<double,2>(bz)); ///< only a view!
  }

  // PyIter_Next returns 0 both on exhaustion and on error; disambiguate here
  if (PyErr_Occurred()) return 0;

  // LDA needs at least two classes to discriminate between
  if (Xseq.size() < 2) {
    PyErr_Format(PyExc_RuntimeError, "`%s' requires an iterable for parameter `X' leading to, at least, two entries (representing two classes), but you have passed something that has only %" PY_FORMAT_SIZE_T "d entries", Py_TYPE(self)->tp_name, Xseq.size());
    return 0;
  }

  // evaluates the expected rank for the output, allocate eigens value array
  Py_ssize_t rank = self->cxx->output_size(Xseq);
  auto eigval = reinterpret_cast<PyBlitzArrayObject*>(PyBlitzArray_SimpleNew(NPY_FLOAT64, 1, &rank));
  auto eigval_ = make_safe(eigval); ///< auto-delete in case of problems

  // allocates a new machine if that was not given by the user
  boost::shared_ptr<PyObject> machine_;
  if (!machine) {
    machine = PyBobLearnLinearMachine_NewFromSize(Xseq[0].extent(1), rank);
    machine_ = make_safe(machine); ///< auto-delete in case of problems
  }

  auto pymac = reinterpret_cast<PyBobLearnLinearMachineObject*>(machine);

  // run the actual C++ training, filling the machine and the eigen-values
  auto eigval_bz = PyBlitzArrayCxx_AsBlitz<double,1>(eigval);
  self->cxx->train(*pymac->cxx, *eigval_bz, Xseq);

  // all went fine, pack machine and eigen-values to return
  // "O" increfs machine; "N" steals the reference to the new numpy array
  return Py_BuildValue("ON", machine, PyBlitzArray_AsNumpyArray(eigval, 0));
BOB_CATCH_FUNCTION("train", 0)
}

279
280
281
282
283
284
285
// Documentation for the `output_size` method; the trailing `true` marks the
// function as supporting keyword arguments (kwlist() is consumed by the
// argument parser in PyBobLearnLinearFisherLDATrainer_OutputSize below).
static auto output_size = bob::extension::FunctionDoc(
  "output_size",
  "Returns the expected size of the output (or the number of eigen-values returned) given the data",
  "This number could be either :math:`K-1` (where :math:`K` is number of classes) or the number of columns (features) in ``X``, depending on the setting of :py:attr:`strip_to_rank`. "
  "This method should be used to setup linear machines and input vectors prior to feeding them into this trainer.\n\n"
  "The value of ``X`` should be a sequence over as many 2D 64-bit floating point number arrays as classes in the problem. "
  "All arrays will be checked for conformance (identical number of columns). "
  "To accomplish this, either prepare a list with all your class observations organized in 2D arrays or pass a 3D array in which the first dimension (depth) contains as many elements as classes you want to discriminate.",
  true
)
.add_prototype("X","size")
.add_parameter("X", "[array_like(2D, floats)] or array_like(3D, floats)", "The input data, separated to contain the training data per class in the first dimension")
.add_return("size", "int", "The number of eigen-vectors/values that will be created in a call to :py:meth:`train`, given the same input data ``X``")
;
293
294
295
296
/**
 * Binding for FisherLDATrainer.output_size: converts the per-class input
 * arrays and returns the expected number of eigen-values/vectors as a
 * Python int.  Returns 0 with a Python exception set on failure.
 *
 * Fix: the original applied PySequence_Fast_GET_SIZE/GET_ITEM directly to
 * `X`, but those macros are only valid on objects returned by
 * PySequence_Fast() (i.e. real lists/tuples); for any other sequence type
 * passing PySequence_Check that is undefined behavior.  We now materialize
 * the sequence with PySequence_Fast() first.
 */
static PyObject* PyBobLearnLinearFisherLDATrainer_OutputSize
(PyBobLearnLinearFisherLDATrainerObject* self,
 PyObject* args, PyObject* kwds) {

BOB_TRY
  /* Parses input arguments in a single shot */
  char** kwlist = output_size.kwlist();

  PyObject* X = 0;

  if (!PyArg_ParseTupleAndKeywords(args, kwds, "O", kwlist, &X)) return 0;

  if (!PySequence_Check(X)) {
    PyErr_Format(PyExc_TypeError, "`%s' requires an input sequence for parameter `X', but you passed a `%s' which does not implement the sequence protocol", Py_TYPE(self)->tp_name, Py_TYPE(X)->tp_name);
    return 0;
  }

  // materialize as list/tuple so the PySequence_Fast_* macros below are valid
  PyObject* fast = PySequence_Fast(X, "parameter `X' cannot be converted into a sequence");
  if (!fast) return 0;
  auto fast_ = make_safe(fast);

  /* Checks and converts all entries: Xseq holds lightweight blitz views,
   * while Xseq_ keeps the backing PyBlitzArrayObject references alive */
  std::vector<blitz::Array<double,2> > Xseq;
  std::vector<boost::shared_ptr<PyBlitzArrayObject>> Xseq_;

  Py_ssize_t size = PySequence_Fast_GET_SIZE(fast);

  // LDA needs at least two classes to discriminate between
  if (size < 2) {
    PyErr_Format(PyExc_RuntimeError, "`%s' requires an input sequence for parameter `X' with at least two entries (representing two classes), but you have passed something that has only %" PY_FORMAT_SIZE_T "d entries", Py_TYPE(self)->tp_name, size);
    return 0;
  }

  Xseq.reserve(size);
  Xseq_.reserve(size);

  for (Py_ssize_t k=0; k<size; ++k) {

    PyBlitzArrayObject* bz = 0;
    PyObject* borrowed = PySequence_Fast_GET_ITEM(fast, k);

    if (!PyBlitzArray_Converter(borrowed, &bz)) {
      PyErr_Format(PyExc_TypeError, "`%s' could not convert object of type `%s' at position %" PY_FORMAT_SIZE_T "d of input sequence `X' into an array - check your input", Py_TYPE(self)->tp_name, Py_TYPE(borrowed)->tp_name, k);
      return 0;
    }

    // every class entry must be a 2D float64 array
    if (bz->ndim != 2 || bz->type_num != NPY_FLOAT64) {
      PyErr_Format(PyExc_TypeError, "`%s' only supports 2D 64-bit float arrays for input sequence `X' (or any other object coercible to that), but at position %" PY_FORMAT_SIZE_T "d I have found an object with %" PY_FORMAT_SIZE_T "d dimensions and with type `%s' which is not compatible - check your input", Py_TYPE(self)->tp_name, k, bz->ndim, PyBlitzArray_TypenumAsString(bz->type_num));
      Py_DECREF(bz);
      return 0;
    }

    Xseq_.push_back(make_safe(bz)); ///< prevents data deletion
    Xseq.push_back(*PyBlitzArrayCxx_AsBlitz<double,2>(bz)); ///< only a view!

  }

  return Py_BuildValue("n", self->cxx->output_size(Xseq));
BOB_CATCH_MEMBER("output_size", 0)
}

// Method table: exposes `train` and `output_size` on FisherLDATrainer
// instances; both accept positional and keyword arguments.
static PyMethodDef PyBobLearnLinearFisherLDATrainer_methods[] = {
  {
    train.name(),
    (PyCFunction)PyBobLearnLinearFisherLDATrainer_Train,
    METH_VARARGS|METH_KEYWORDS,
    train.doc()
  },
  {
    output_size.name(),
    (PyCFunction)PyBobLearnLinearFisherLDATrainer_OutputSize,
    METH_VARARGS|METH_KEYWORDS,
    output_size.doc()
  },
  {0} /* Sentinel */
};

364
365
366
367
368
369
// Documentation for the `use_pinv` read/write property (see getter/setter
// below).
static auto use_pinv = bob::extension::VariableDoc(
  "use_pinv",
  "bool",
  "Use the pseudo-inverse?",
  "If ``True``, use the pseudo-inverse to calculate :math:`S_w^{-1} S_b` and then perform the eigen value decomposition (using LAPACK's ``dgeev``) instead of using (the more numerically stable) LAPACK's ``dsyvgd`` to solve the generalized symmetric-definite eigen-problem of the form :math:`S_b v=(\\lambda) S_w v`."
);
370
371
/** Getter for the `use_pinv` property: returns a Python bool. */
static PyObject* PyBobLearnLinearFisherLDATrainer_getUsePinv
(PyBobLearnLinearFisherLDATrainerObject* self, void* /*closure*/) {
BOB_TRY
  return PyBool_FromLong(self->cxx->getUsePseudoInverse());
BOB_CATCH_MEMBER("use_pinv", 0)
}

/** Setter for the `use_pinv` property: accepts any truthy/falsy object. */
static int PyBobLearnLinearFisherLDATrainer_setUsePinv
(PyBobLearnLinearFisherLDATrainerObject* self, PyObject* o, void* /*closure*/) {
BOB_TRY
  const int value = PyObject_IsTrue(o);
  if (value < 0) return -1;  // truthiness evaluation raised

  self->cxx->setUsePseudoInverse(value);
  return 0;
BOB_CATCH_MEMBER("use_pinv", -1)
}

391
392
393
394
395
396
397
// Documentation for the `strip_to_rank` read/write property (see
// getter/setter below).
static auto strip_to_rank = bob::extension::VariableDoc(
  "strip_to_rank",
  "bool",
  "Only return the non-zero eigen-values/vectors?",
  "If ``True``, strip the resulting LDA projection matrix to keep only the eigen-vectors with non-zero eigenvalues. "
  "Otherwise the full projection matrix is returned."
);
398
399
/** Getter for the `strip_to_rank` property: returns a Python bool. */
static PyObject* PyBobLearnLinearFisherLDATrainer_getStripToRank
(PyBobLearnLinearFisherLDATrainerObject* self, void* /*closure*/) {
BOB_TRY
  return PyBool_FromLong(self->cxx->getStripToRank());
BOB_CATCH_MEMBER("strip_to_rank", 0)
}

/** Setter for the `strip_to_rank` property: accepts any truthy/falsy object. */
static int PyBobLearnLinearFisherLDATrainer_setStripToRank
(PyBobLearnLinearFisherLDATrainerObject* self, PyObject* o, void* /*closure*/) {
BOB_TRY
  const int value = PyObject_IsTrue(o);
  if (value < 0) return -1;  // truthiness evaluation raised

  self->cxx->setStripToRank(value);
  return 0;
BOB_CATCH_MEMBER("strip_to_rank", -1)
}

// Property table: exposes the trainer's configuration flags as Python
// attributes (`use_pinv`, `strip_to_rank`), each with a getter and a setter.
static PyGetSetDef PyBobLearnLinearFisherLDATrainer_getseters[] = {
    {
      use_pinv.name(),
      (getter)PyBobLearnLinearFisherLDATrainer_getUsePinv,
      (setter)PyBobLearnLinearFisherLDATrainer_setUsePinv,
      use_pinv.doc(),
      0
    },
    {
      strip_to_rank.name(),
      (getter)PyBobLearnLinearFisherLDATrainer_getStripToRank,
      (setter)PyBobLearnLinearFisherLDATrainer_setStripToRank,
      strip_to_rank.doc(),
      0
    },
    {0}  /* Sentinel */
};

438
// LDA Trainer
// Type object deliberately left zero-initialized here; all interesting slots
// (tp_name, tp_init, tp_methods, ...) are filled in at runtime inside
// init_BobLearnLinearLDA() before PyType_Ready() is called.
PyTypeObject PyBobLearnLinearFisherLDATrainer_Type = {
  PyVarObject_HEAD_INIT(0,0)
  0
};
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466

/**
 * Module initialization helper: fills in the FisherLDATrainer type slots,
 * readies the type and registers it on `module`.  Returns true on success.
 *
 * Fix: on PyModule_AddObject failure the reference taken by Py_INCREF was
 * leaked; we now release it on the failure path.
 */
bool init_BobLearnLinearLDA(PyObject* module)
{
  // LDA Trainer
  PyBobLearnLinearFisherLDATrainer_Type.tp_name = LDA_doc.name();
  PyBobLearnLinearFisherLDATrainer_Type.tp_basicsize = sizeof(PyBobLearnLinearFisherLDATrainerObject);
  PyBobLearnLinearFisherLDATrainer_Type.tp_flags = Py_TPFLAGS_DEFAULT;
  PyBobLearnLinearFisherLDATrainer_Type.tp_doc = LDA_doc.doc();

  // set the functions
  PyBobLearnLinearFisherLDATrainer_Type.tp_new = PyType_GenericNew;
  PyBobLearnLinearFisherLDATrainer_Type.tp_init = reinterpret_cast<initproc>(PyBobLearnLinearFisherLDATrainer_init);
  PyBobLearnLinearFisherLDATrainer_Type.tp_dealloc = reinterpret_cast<destructor>(PyBobLearnLinearFisherLDATrainer_delete);
  PyBobLearnLinearFisherLDATrainer_Type.tp_methods = PyBobLearnLinearFisherLDATrainer_methods;
  PyBobLearnLinearFisherLDATrainer_Type.tp_getset = PyBobLearnLinearFisherLDATrainer_getseters;
  PyBobLearnLinearFisherLDATrainer_Type.tp_richcompare = reinterpret_cast<richcmpfunc>(PyBobLearnLinearFisherLDATrainer_RichCompare);

  // check that everything is fine
  if (PyType_Ready(&PyBobLearnLinearFisherLDATrainer_Type) < 0) return false;

  // add the type to the module; PyModule_AddObject only steals the reference
  // on success, so undo the Py_INCREF when registration fails
  Py_INCREF(&PyBobLearnLinearFisherLDATrainer_Type);
  if (PyModule_AddObject(module, "FisherLDATrainer",
        (PyObject*)&PyBobLearnLinearFisherLDATrainer_Type) < 0) {
    Py_DECREF(&PyBobLearnLinearFisherLDATrainer_Type);
    return false;
  }
  return true;
}