plda_trainer.cpp 23.6 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
/**
 * @author Tiago de Freitas Pereira <tiago.pereira@idiap.ch>
 * @date Wed 04 Feb 14:15:00 2015
 *
 * @brief Python API for bob::learn::em
 *
 * Copyright (C) 2011-2014 Idiap Research Institute, Martigny, Switzerland
 */

#include "main.h"
#include <boost/make_shared.hpp>
12
#include <boost/assign.hpp>
13
14

//Defining maps for each initializatio method
15
static const std::map<std::string, bob::learn::em::PLDATrainer::InitFMethod> FMethod = boost::assign::map_list_of
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
  ("RANDOM_F",  bob::learn::em::PLDATrainer::RANDOM_F)
  ("BETWEEN_SCATTER", bob::learn::em::PLDATrainer::BETWEEN_SCATTER)
  ;

static const std::map<std::string, bob::learn::em::PLDATrainer::InitGMethod> GMethod = boost::assign::map_list_of
  ("RANDOM_G",  bob::learn::em::PLDATrainer::RANDOM_G)
  ("WITHIN_SCATTER", bob::learn::em::PLDATrainer::WITHIN_SCATTER)
  ;

static const std::map<std::string, bob::learn::em::PLDATrainer::InitSigmaMethod> SigmaMethod = boost::assign::map_list_of
  ("RANDOM_SIGMA",  bob::learn::em::PLDATrainer::RANDOM_SIGMA)
  ("VARIANCE_G", bob::learn::em::PLDATrainer::VARIANCE_G)
  ("CONSTANT", bob::learn::em::PLDATrainer::CONSTANT)
  ("VARIANCE_DATA", bob::learn::em::PLDATrainer::VARIANCE_DATA)
  ;
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74

// String to type: translate a Python-visible name into its enumerator.
// Unknown names raise std::runtime_error listing the accepted values.
static inline bob::learn::em::PLDATrainer::InitFMethod string2FMethod(const std::string& o){
  const auto found = FMethod.find(o);
  if (found != FMethod.end())
    return found->second;
  throw std::runtime_error("The given FMethod '" + o + "' is not known; choose one of ('RANDOM_F','BETWEEN_SCATTER')");
}

static inline bob::learn::em::PLDATrainer::InitGMethod string2GMethod(const std::string& o){
  const auto found = GMethod.find(o);
  if (found != GMethod.end())
    return found->second;
  throw std::runtime_error("The given GMethod '" + o + "' is not known; choose one of ('RANDOM_G','WITHIN_SCATTER')");
}

static inline bob::learn::em::PLDATrainer::InitSigmaMethod string2SigmaMethod(const std::string& o){
  const auto found = SigmaMethod.find(o);
  if (found != SigmaMethod.end())
    return found->second;
  throw std::runtime_error("The given SigmaMethod '" + o + "' is not known; choose one of ('RANDOM_SIGMA','VARIANCE_G', 'CONSTANT', 'VARIANCE_DATA')");
}

// Type to string: reverse lookup of an enumerator in its table.
// The returned reference points into a static map, so it stays valid.
// An enumerator missing from the table raises std::runtime_error.
static inline const std::string& FMethod2string(bob::learn::em::PLDATrainer::InitFMethod o){
  for (const auto& entry : FMethod) {
    if (entry.second == o) return entry.first;
  }
  throw std::runtime_error("The given FMethod type is not known");
}

static inline const std::string& GMethod2string(bob::learn::em::PLDATrainer::InitGMethod o){
  for (const auto& entry : GMethod) {
    if (entry.second == o) return entry.first;
  }
  throw std::runtime_error("The given GMethod type is not known");
}

static inline const std::string& SigmaMethod2string(bob::learn::em::PLDATrainer::InitSigmaMethod o){
  for (const auto& entry : SigmaMethod) {
    if (entry.second == o) return entry.first;
  }
  throw std::runtime_error("The given SigmaMethod type is not known");
}


/* Converts a PyObject to a C++ bool.
   NULL yields false, and so does an object whose truth test is false or fails
   (PyObject_IsTrue returning 0 or -1). */
static inline bool f(PyObject* o){
  if (!o) return false;
  return PyObject_IsTrue(o) > 0;
}

/**
 * Appends every element of the Python list `list` to `vec` as a
 * blitz::Array<double,N>.
 *
 * @param list  Python list object. The callers in this file enforce the type
 *              with an `O!`/&PyList_Type parse.
 * @param vec   Output vector; the converted arrays are appended to it.
 * @return 0 on success, or -1 with a Python exception set.
 *
 * Each element must be accepted by PyBlitzArray_Converter and must be a
 * 64-bit float array with exactly N dimensions. Other dtypes or ranks
 * (e.g. the int64 array produced from a list of Python ints) are rejected
 * rather than reinterpreted as the wrong blitz type.
 */
template <int N>
int list_as_vector(PyObject* list, std::vector<blitz::Array<double,N> >& vec)
{
  const Py_ssize_t size = PyList_GET_SIZE(list);
  vec.reserve(vec.size() + size);

  for (Py_ssize_t i = 0; i < size; ++i)
  {
    PyBlitzArrayObject* blitz_object;
    if (!PyArg_Parse(PyList_GetItem(list, i), "O&", &PyBlitzArray_Converter, &blitz_object)){
      PyErr_Format(PyExc_RuntimeError, "Expected numpy array object");
      return -1;
    }
    auto blitz_object_ = make_safe(blitz_object);

    // NOTE(review): the single-argument PyBlitzArrayCxx_AsBlitz is assumed to
    // cast without validating dtype/rank, so check them explicitly here.
    if (blitz_object->type_num != NPY_FLOAT64 || blitz_object->ndim != N){
      PyErr_Format(PyExc_TypeError, "list element %zd must be a %d-dimensional array of 64-bit floats", i, N);
      return -1;
    }

    vec.push_back(*PyBlitzArrayCxx_AsBlitz<double,N>(blitz_object));
  }
  return 0;
}


/**
 * Converts a vector of blitz arrays into a new Python list of numpy arrays.
 *
 * @param vec  Arrays to convert, in order.
 * @return A new reference to the list, or NULL with a Python exception set
 *         if the list or any element could not be created.
 */
template <int N>
static PyObject* vector_as_list(const std::vector<blitz::Array<double,N> >& vec)
{
  PyObject* list = PyList_New(vec.size());
  if (!list) return 0;

  for (size_t i = 0; i < vec.size(); ++i){
    blitz::Array<double,N> numpy_array = vec[i];
    PyObject* numpy_py_object = PyBlitzArrayCxx_AsNumpy(numpy_array);
    if (!numpy_py_object){
      // Never hand back a list that still contains NULL slots.
      Py_DECREF(list);
      return 0;
    }
    // PyList_SET_ITEM steals the reference to numpy_py_object.
    PyList_SET_ITEM(list, i, numpy_py_object);
  }
  return list;
}


/******************************************************************/
/************ Constructor Section *********************************/
/******************************************************************/


// Class documentation for PLDATrainer.
// The order of the constructor prototypes matters: kwlist(0) is used by the
// boolean constructor and kwlist(1) by the copy constructor.
static auto PLDATrainer_doc = bob::extension::ClassDoc(
  BOB_EXT_MODULE_PREFIX ".PLDATrainer",
  "This class can be used to train the :math:`F`, :math:`G` and "
  ":math:`\\Sigma` matrices and the mean vector :math:`\\mu` of a PLDA model.\n\n"
  "References: [ElShafey2014]_ [PrinceElder2007]_ [LiFu2012]_ ",
  ""
).add_constructor(
  bob::extension::FunctionDoc(
    "__init__",
    "Default constructor.\n Initializes a new PLDA trainer. The "
    "training stage will place the resulting components in the "
    "PLDABase.",
    "",
    true
  )
  .add_prototype("use_sum_second_order","")
  .add_prototype("other","")
  .add_prototype("","")

  .add_parameter("other", ":py:class:`bob.learn.em.PLDATrainer`", "A PLDATrainer object to be copied.")
  .add_parameter("use_sum_second_order", "bool", "Tells whether the second order statistics are stored during the training procedure, or only their sum (default: ``False``).")
);

/* Copy constructor: builds the wrapped C++ trainer as a copy of the one held
   by the PLDATrainer passed as "other".
   Returns 0 on success, or -1 (after printing usage) on bad arguments. */
static int PyBobLearnEMPLDATrainer_init_copy(PyBobLearnEMPLDATrainerObject* self, PyObject* args, PyObject* kwargs) {
  char** keywords = PLDATrainer_doc.kwlist(1);
  PyBobLearnEMPLDATrainerObject* source;

  const bool parsed = PyArg_ParseTupleAndKeywords(args, kwargs, "O!", keywords,
                                                  &PyBobLearnEMPLDATrainer_Type, &source);
  if (!parsed){
    PLDATrainer_doc.print_usage();
    return -1;
  }

  self->cxx.reset(new bob::learn::em::PLDATrainer(*source->cxx));
  return 0;
}


/* Constructs a new trainer from the optional bool "use_sum_second_order"
   (default: False).
   Returns 0 on success, or -1 with a Python exception set; usage is printed
   on bad arguments, as the copy constructor does. */
static int PyBobLearnEMPLDATrainer_init_bool(PyBobLearnEMPLDATrainerObject* self, PyObject* args, PyObject* kwargs) {

  char** kwlist = PLDATrainer_doc.kwlist(0);
  PyObject* use_sum_second_order = Py_False;

  //Parsing the input arguments
  if (!PyArg_ParseTupleAndKeywords(args, kwargs, "|O!", kwlist, &PyBool_Type, &use_sum_second_order)){
    PLDATrainer_doc.print_usage();
    return -1;
  }

  self->cxx.reset(new bob::learn::em::PLDATrainer(f(use_sum_second_order)));
  return 0;
}


/* tp_init: dispatches to the boolean or copy constructor depending on the
   single (optional) argument. More than one argument is rejected with a
   RuntimeError and a usage message. */
static int PyBobLearnEMPLDATrainer_init(PyBobLearnEMPLDATrainerObject* self, PyObject* args, PyObject* kwargs) {
  BOB_TRY

  // Positional and keyword arguments are counted together.
  const int n_positional = args ? PyTuple_Size(args) : 0;
  const int n_keyword = kwargs ? PyDict_Size(kwargs) : 0;
  const int nargs = n_positional + n_keyword;

  if (nargs == 0)
    return PyBobLearnEMPLDATrainer_init_bool(self, args, kwargs);

  if (nargs != 1){
    PyErr_Format(PyExc_RuntimeError, "number of arguments mismatch - %s requires only 0 or 1 argument, but you provided %d (see help)", Py_TYPE(self)->tp_name, nargs);
    PLDATrainer_doc.print_usage();
    return -1;
  }

  // Exactly one argument: fetch it, whether it came positionally or by keyword.
  PyObject* arg = 0;
  if (n_positional > 0) {
    arg = PyTuple_GET_ITEM(args, 0);
  } else {
    PyObject* values = PyDict_Values(kwargs);
    auto values_ = make_safe(values);
    arg = PyList_GET_ITEM(values, 0);
  }

  // A PLDATrainer instance selects the copy constructor. Anything else goes
  // to the boolean constructor, which reports its own parse errors.
  return PyBobLearnEMPLDATrainer_Check(arg)
    ? PyBobLearnEMPLDATrainer_init_copy(self, args, kwargs)
    : PyBobLearnEMPLDATrainer_init_bool(self, args, kwargs);

  BOB_CATCH_MEMBER("cannot create PLDATrainer", -1)
  return 0;
}


/* tp_dealloc: releases the wrapped C++ trainer, then frees the Python object. */
static void PyBobLearnEMPLDATrainer_delete(PyBobLearnEMPLDATrainerObject* self) {
  self->cxx.reset();
  PyTypeObject* type = Py_TYPE(self);
  type->tp_free(reinterpret_cast<PyObject*>(self));
}


/* Returns PyObject_IsInstance(o, PLDATrainer type) unchanged:
   1 for instances (including subclasses), 0 otherwise, -1 on error. */
int PyBobLearnEMPLDATrainer_Check(PyObject* o) {
  PyObject* trainer_type = reinterpret_cast<PyObject*>(&PyBobLearnEMPLDATrainer_Type);
  return PyObject_IsInstance(o, trainer_type);
}


/* tp_richcompare: supports == and != by delegating to the C++ operator==.
   Other operators return NotImplemented; comparing against a
   non-PLDATrainer raises TypeError. */
static PyObject* PyBobLearnEMPLDATrainer_RichCompare(PyBobLearnEMPLDATrainerObject* self, PyObject* other, int op) {
  BOB_TRY

  if (!PyBobLearnEMPLDATrainer_Check(other)) {
    PyErr_Format(PyExc_TypeError, "cannot compare `%s' with `%s'", Py_TYPE(self)->tp_name, Py_TYPE(other)->tp_name);
    return 0;
  }

  if (op != Py_EQ && op != Py_NE) {
    Py_INCREF(Py_NotImplemented);
    return Py_NotImplemented;
  }

  auto rhs = reinterpret_cast<PyBobLearnEMPLDATrainerObject*>(other);
  const bool equal = (*self->cxx == *rhs->cxx);
  const bool asked_for_equal = (op == Py_EQ);
  if (equal == asked_for_equal) Py_RETURN_TRUE;
  Py_RETURN_FALSE;

  BOB_CATCH_MEMBER("cannot compare PLDATrainer objects", 0)
}


/******************************************************************/
/************ Variables Section ***********************************/
/******************************************************************/

/***** z_second_order (read-only) *****/
static auto z_second_order = bob::extension::VariableDoc(
  "z_second_order",
  "[array_like <float, 3D>]",
  "Second order statistics held by the trainer, returned as a list of arrays",
  ""
);
PyObject* PyBobLearnEMPLDATrainer_get_z_second_order(PyBobLearnEMPLDATrainerObject* self, void*){
  BOB_TRY
  // The C++ side stores a std::vector of arrays, exposed here as a Python list.
  return vector_as_list(self->cxx->getZSecondOrder());
  BOB_CATCH_MEMBER("z_second_order could not be read", 0)
}


/***** z_second_order_sum (read-only) *****/
static auto z_second_order_sum = bob::extension::VariableDoc(
  "z_second_order_sum",
  "array_like <float, 2D>",
  "Sum of the second order statistics held by the trainer",
  ""
);
PyObject* PyBobLearnEMPLDATrainer_get_z_second_order_sum(PyBobLearnEMPLDATrainerObject* self, void*){
  BOB_TRY
  return PyBlitzArrayCxx_AsConstNumpy(self->cxx->getZSecondOrderSum());
  BOB_CATCH_MEMBER("z_second_order_sum could not be read", 0)
}


/***** z_first_order (read-only) *****/
static auto z_first_order = bob::extension::VariableDoc(
  "z_first_order",
  "[array_like <float, 2D>]",
  "First order statistics held by the trainer, returned as a list of arrays",
  ""
);
PyObject* PyBobLearnEMPLDATrainer_get_z_first_order(PyBobLearnEMPLDATrainerObject* self, void*){
  BOB_TRY
  // The C++ side stores a std::vector of arrays, exposed here as a Python list.
  return vector_as_list(self->cxx->getZFirstOrder());
  BOB_CATCH_MEMBER("z_first_order could not be read", 0)
}


/***** init_f_method *****/
static auto init_f_method = bob::extension::VariableDoc(
  "init_f_method",
  "str",
  "The method used for the initialization of :math:`F`.",
  "Possible values are: ('RANDOM_F', 'BETWEEN_SCATTER')"
);
PyObject* PyBobLearnEMPLDATrainer_getFMethod(PyBobLearnEMPLDATrainerObject* self, void*) {
  BOB_TRY
  return Py_BuildValue("s", FMethod2string(self->cxx->getInitFMethod()).c_str());
  BOB_CATCH_MEMBER("init_f_method method could not be read", 0)
}
/* Setter: accepts one of the names listed in the doc above.
   Returns 0 on success, -1 with a Python exception set otherwise. */
int PyBobLearnEMPLDATrainer_setFMethod(PyBobLearnEMPLDATrainerObject* self, PyObject* value, void*) {
  BOB_TRY

  // Python calls setters with value == NULL for `del obj.attr`.
  if (!value){
    PyErr_Format(PyExc_TypeError, "%s %s cannot be deleted", Py_TYPE(self)->tp_name, init_f_method.name());
    return -1;
  }
  if (!PyString_Check(value)){
    PyErr_Format(PyExc_RuntimeError, "%s %s expects an str", Py_TYPE(self)->tp_name, init_f_method.name());
    return -1;
  }
  // string2FMethod throws std::runtime_error for unknown names.
  self->cxx->setInitFMethod(string2FMethod(PyString_AS_STRING(value)));

  return 0;
  BOB_CATCH_MEMBER("init_f_method method could not be set", -1)
}


/***** init_g_method *****/
static auto init_g_method = bob::extension::VariableDoc(
  "init_g_method",
  "str",
  "The method used for the initialization of :math:`G`.",
  "Possible values are: ('RANDOM_G', 'WITHIN_SCATTER')"
);
PyObject* PyBobLearnEMPLDATrainer_getGMethod(PyBobLearnEMPLDATrainerObject* self, void*) {
  BOB_TRY
  return Py_BuildValue("s", GMethod2string(self->cxx->getInitGMethod()).c_str());
  BOB_CATCH_MEMBER("init_g_method method could not be read", 0)
}
/* Setter: accepts one of the names listed in the doc above.
   Returns 0 on success, -1 with a Python exception set otherwise. */
int PyBobLearnEMPLDATrainer_setGMethod(PyBobLearnEMPLDATrainerObject* self, PyObject* value, void*) {
  BOB_TRY

  // Python calls setters with value == NULL for `del obj.attr`.
  if (!value){
    PyErr_Format(PyExc_TypeError, "%s %s cannot be deleted", Py_TYPE(self)->tp_name, init_g_method.name());
    return -1;
  }
  if (!PyString_Check(value)){
    PyErr_Format(PyExc_RuntimeError, "%s %s expects an str", Py_TYPE(self)->tp_name, init_g_method.name());
    return -1;
  }
  // string2GMethod throws std::runtime_error for unknown names.
  self->cxx->setInitGMethod(string2GMethod(PyString_AS_STRING(value)));

  return 0;
  BOB_CATCH_MEMBER("init_g_method method could not be set", -1)
}

/***** init_sigma_method *****/
static auto init_sigma_method = bob::extension::VariableDoc(
  "init_sigma_method",
  "str",
  "The method used for the initialization of :math:`\\Sigma`.",
  "Possible values are: ('RANDOM_SIGMA', 'VARIANCE_G', 'CONSTANT', 'VARIANCE_DATA')"
);
PyObject* PyBobLearnEMPLDATrainer_getSigmaMethod(PyBobLearnEMPLDATrainerObject* self, void*) {
  BOB_TRY
  return Py_BuildValue("s", SigmaMethod2string(self->cxx->getInitSigmaMethod()).c_str());
  BOB_CATCH_MEMBER("init_sigma_method method could not be read", 0)
}
/* Setter: accepts one of the names listed in the doc above.
   Returns 0 on success, -1 with a Python exception set otherwise. */
int PyBobLearnEMPLDATrainer_setSigmaMethod(PyBobLearnEMPLDATrainerObject* self, PyObject* value, void*) {
  BOB_TRY

  // Python calls setters with value == NULL for `del obj.attr`.
  if (!value){
    PyErr_Format(PyExc_TypeError, "%s %s cannot be deleted", Py_TYPE(self)->tp_name, init_sigma_method.name());
    return -1;
  }
  if (!PyString_Check(value)){
    PyErr_Format(PyExc_RuntimeError, "%s %s expects an str", Py_TYPE(self)->tp_name, init_sigma_method.name());
    return -1;
  }
  // string2SigmaMethod throws std::runtime_error for unknown names.
  self->cxx->setInitSigmaMethod(string2SigmaMethod(PyString_AS_STRING(value)));

  return 0;
  BOB_CATCH_MEMBER("init_sigma_method method could not be set", -1)
}


346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
static auto use_sum_second_order = bob::extension::VariableDoc(
  "use_sum_second_order",
  "bool",
  "Tells whether the second order statistics are stored during the training procedure, or only their sum.",
  ""
);
PyObject* PyBobLearnEMPLDATrainer_getUseSumSecondOrder(PyBobLearnEMPLDATrainerObject* self, void*){
  BOB_TRY
  return Py_BuildValue("O",self->cxx->getUseSumSecondOrder()?Py_True:Py_False);
  BOB_CATCH_MEMBER("use_sum_second_order could not be read", 0)
}
int PyBobLearnEMPLDATrainer_setUseSumSecondOrder(PyBobLearnEMPLDATrainerObject* self, PyObject* value, void*) {
  BOB_TRY

  if (!PyBool_Check(value)){
    PyErr_Format(PyExc_RuntimeError, "%s %s expects an str", Py_TYPE(self)->tp_name, use_sum_second_order.name());
    return -1;
  }
  self->cxx->setUseSumSecondOrder(f(value));

  return 0;
Tiago de Freitas Pereira's avatar
Tiago de Freitas Pereira committed
367
  BOB_CATCH_MEMBER("use_sum_second_order method could not be set", -1)
368
369
370
371
}



372
static PyGetSetDef PyBobLearnEMPLDATrainer_getseters[] = {
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
  {
   z_first_order.name(),
   (getter)PyBobLearnEMPLDATrainer_get_z_first_order,
   0,
   z_first_order.doc(),
   0
  },
  {
   z_second_order_sum.name(),
   (getter)PyBobLearnEMPLDATrainer_get_z_second_order_sum,
   0,
   z_second_order_sum.doc(),
   0
  },
  {
   z_second_order.name(),
   (getter)PyBobLearnEMPLDATrainer_get_z_second_order,
   0,
   z_second_order.doc(),
   0
  },
  {
   init_f_method.name(),
   (getter)PyBobLearnEMPLDATrainer_getFMethod,
   (setter)PyBobLearnEMPLDATrainer_setFMethod,
   init_f_method.doc(),
   0
  },
  {
   init_g_method.name(),
   (getter)PyBobLearnEMPLDATrainer_getGMethod,
   (setter)PyBobLearnEMPLDATrainer_setGMethod,
   init_g_method.doc(),
   0
  },
  {
   init_sigma_method.name(),
   (getter)PyBobLearnEMPLDATrainer_getSigmaMethod,
   (setter)PyBobLearnEMPLDATrainer_setSigmaMethod,
   init_sigma_method.doc(),
   0
414
415
416
417
418
419
420
421
  },
  {
   use_sum_second_order.name(),
   (getter)PyBobLearnEMPLDATrainer_getUseSumSecondOrder,
   (setter)PyBobLearnEMPLDATrainer_setUseSumSecondOrder,
   use_sum_second_order.doc(),
   0
  },
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
  {0}  // Sentinel
};


/******************************************************************/
/************ Functions Section ***********************************/
/******************************************************************/

/*** initialize ***/
static auto initialize = bob::extension::FunctionDoc(
  "initialize",
  "Initialization before the EM steps",
  "",
  true
)
437
.add_prototype("plda_base, data, [rng]")
438
.add_parameter("plda_base", ":py:class:`bob.learn.em.PLDABase`", "PLDAMachine Object")
439
440
.add_parameter("data", "list", "")
.add_parameter("rng", ":py:class:`bob.core.random.mt19937`", "The Mersenne Twister mt19937 random generator used for the initialization of subspaces/arrays before the EM loop.");
441
442
443
444
445
446
447
448
static PyObject* PyBobLearnEMPLDATrainer_initialize(PyBobLearnEMPLDATrainerObject* self, PyObject* args, PyObject* kwargs) {
  BOB_TRY

  /* Parses input arguments in a single shot */
  char** kwlist = initialize.kwlist(0);

  PyBobLearnEMPLDABaseObject* plda_base = 0;
  PyObject* data = 0;
449
  PyBoostMt19937Object* rng = 0;
450

451
452
453
  if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O!O!|O!", kwlist, &PyBobLearnEMPLDABase_Type, &plda_base,
                                                                 &PyList_Type, &data,
                                                                 &PyBoostMt19937_Type, &rng)) return 0;
454
455

  std::vector<blitz::Array<double,2> > data_vector;
456
457
  if(list_as_vector(data ,data_vector)==0){
    if(rng){
458
      self->cxx->setRng(rng->rng);
459
460
    }

461
    self->cxx->initialize(*plda_base->cxx, data_vector);
462
  }
463
464
  else
    return 0;
465
466
467
468
469
470
471
472
473

  BOB_CATCH_MEMBER("cannot perform the initialize method", 0)

  Py_RETURN_NONE;
}


/*** e_step ***/
static auto e_step = bob::extension::FunctionDoc(
  "e_step",
  "Expectation step of the EM algorithm",
  "",
  true
)
.add_prototype("plda_base,data")
.add_parameter("plda_base", ":py:class:`bob.learn.em.PLDABase`", "PLDABase Object")
.add_parameter("data", "list", "The training data, given as a list of 2D float64 arrays");
/* Returns None on success, or NULL with a Python exception set. */
static PyObject* PyBobLearnEMPLDATrainer_e_step(PyBobLearnEMPLDATrainerObject* self, PyObject* args, PyObject* kwargs) {
  BOB_TRY

  /* Parses input arguments in a single shot */
  char** kwlist = e_step.kwlist(0);

  PyBobLearnEMPLDABaseObject* plda_base = 0;
  PyObject* data = 0;

  if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O!O!", kwlist, &PyBobLearnEMPLDABase_Type, &plda_base,
                                                                 &PyList_Type, &data)) return 0;

  // On conversion failure list_as_vector has already set a Python exception.
  std::vector<blitz::Array<double,2> > data_vector;
  if (list_as_vector(data, data_vector) != 0) return 0;

  self->cxx->eStep(*plda_base->cxx, data_vector);

  BOB_CATCH_MEMBER("cannot perform the e_step method", 0)

  Py_RETURN_NONE;
}


/*** m_step ***/
static auto m_step = bob::extension::FunctionDoc(
  "m_step",
  "Maximization step of the EM algorithm",
  "",
  true
)
.add_prototype("plda_base,data")
.add_parameter("plda_base", ":py:class:`bob.learn.em.PLDABase`", "PLDABase Object")
.add_parameter("data", "list", "The training data, given as a list of 2D float64 arrays");
/* Returns None on success, or NULL with a Python exception set. */
static PyObject* PyBobLearnEMPLDATrainer_m_step(PyBobLearnEMPLDATrainerObject* self, PyObject* args, PyObject* kwargs) {
  BOB_TRY

  /* Parses input arguments in a single shot */
  char** kwlist = m_step.kwlist(0);

  PyBobLearnEMPLDABaseObject* plda_base = 0;
  PyObject* data = 0;

  if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O!O!", kwlist, &PyBobLearnEMPLDABase_Type, &plda_base,
                                                                 &PyList_Type, &data)) return 0;

  // On conversion failure list_as_vector has already set a Python exception.
  std::vector<blitz::Array<double,2> > data_vector;
  if (list_as_vector(data, data_vector) != 0) return 0;

  self->cxx->mStep(*plda_base->cxx, data_vector);

  BOB_CATCH_MEMBER("cannot perform the m_step method", 0)

  Py_RETURN_NONE;
}


/*** finalize ***/
static auto finalize = bob::extension::FunctionDoc(
  "finalize",
  "Finalization step of the training procedure",
  "",
  true
)
.add_prototype("plda_base,data")
.add_parameter("plda_base", ":py:class:`bob.learn.em.PLDABase`", "PLDABase Object")
.add_parameter("data", "list", "The training data, given as a list of 2D float64 arrays");
/* Returns None on success, or NULL with a Python exception set. */
static PyObject* PyBobLearnEMPLDATrainer_finalize(PyBobLearnEMPLDATrainerObject* self, PyObject* args, PyObject* kwargs) {
  BOB_TRY

  /* Parses input arguments in a single shot */
  char** kwlist = finalize.kwlist(0);

  PyBobLearnEMPLDABaseObject* plda_base = 0;
  PyObject* data = 0;

  if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O!O!", kwlist, &PyBobLearnEMPLDABase_Type, &plda_base,
                                                                 &PyList_Type, &data)) return 0;

  // On conversion failure list_as_vector has already set a Python exception.
  std::vector<blitz::Array<double,2> > data_vector;
  if (list_as_vector(data, data_vector) != 0) return 0;

  self->cxx->finalize(*plda_base->cxx, data_vector);

  BOB_CATCH_MEMBER("cannot perform the finalize method", 0)

  Py_RETURN_NONE;
}



575
576
577
/*** enroll ***/
static auto enroll = bob::extension::FunctionDoc(
  "enroll",
578
579
580
581
582
583
584
  "Main procedure for enrolling a PLDAMachine",
  "",
  true
)
.add_prototype("plda_machine,data")
.add_parameter("plda_machine", ":py:class:`bob.learn.em.PLDAMachine`", "PLDAMachine Object")
.add_parameter("data", "list", "");
585
static PyObject* PyBobLearnEMPLDATrainer_enroll(PyBobLearnEMPLDATrainerObject* self, PyObject* args, PyObject* kwargs) {
586
587
588
  BOB_TRY

  /* Parses input arguments in a single shot */
589
  char** kwlist = enroll.kwlist(0);
590
591
592
593
594

  PyBobLearnEMPLDAMachineObject* plda_machine = 0;
  PyBlitzArrayObject* data = 0;

  if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O!O&", kwlist, &PyBobLearnEMPLDAMachine_Type, &plda_machine,
595
                                                                 &PyBlitzArray_Converter, &data)) return 0;
596
597

  auto data_ = make_safe(data);
598
  self->cxx->enroll(*plda_machine->cxx, *PyBlitzArrayCxx_AsBlitz<double,2>(data));
599

600
  BOB_CATCH_MEMBER("cannot perform the enroll method", 0)
601
602
603
604
605
606
607
608

  Py_RETURN_NONE;
}


/*** is_similar_to ***/
static auto is_similar_to = bob::extension::FunctionDoc(
  "is_similar_to",
  "Compares this PLDATrainer with the ``other`` one to be approximately the same.",
  "The optional values ``r_epsilon`` and ``a_epsilon`` refer to the "
  "relative and absolute precision used when comparing the values "
  "internal to this trainer."
)
.add_prototype("other, [r_epsilon], [a_epsilon]","output")
.add_parameter("other", ":py:class:`bob.learn.em.PLDATrainer`", "A PLDATrainer object to be compared.")
.add_parameter("r_epsilon", "float", "Relative precision.")
.add_parameter("a_epsilon", "float", "Absolute precision.")
.add_return("output","bool","True if it is similar, otherwise false.");
/* Returns True/False, or NULL (after printing usage) on bad arguments.
   C++ exceptions are caught like in every other method of this type. */
static PyObject* PyBobLearnEMPLDATrainer_IsSimilarTo(PyBobLearnEMPLDATrainerObject* self, PyObject* args, PyObject* kwds) {
  BOB_TRY

  /* Parses input arguments in a single shot */
  char** kwlist = is_similar_to.kwlist(0);

  PyBobLearnEMPLDATrainerObject* other = 0;
  double r_epsilon = 1.e-5;
  double a_epsilon = 1.e-8;

  if (!PyArg_ParseTupleAndKeywords(args, kwds, "O!|dd", kwlist,
        &PyBobLearnEMPLDATrainer_Type, &other,
        &r_epsilon, &a_epsilon)){
    is_similar_to.print_usage();
    return 0;
  }

  if (self->cxx->is_similar_to(*other->cxx, r_epsilon, a_epsilon))
    Py_RETURN_TRUE;
  else
    Py_RETURN_FALSE;

  BOB_CATCH_MEMBER("cannot perform the is_similar_to method", 0)
}



/* Method table for PLDATrainer.
   Every entry accepts positional and keyword arguments; names and docstrings
   come from the FunctionDoc objects defined above. */
static PyMethodDef PyBobLearnEMPLDATrainer_methods[] = {
  {initialize.name(),    (PyCFunction)PyBobLearnEMPLDATrainer_initialize,  METH_VARARGS|METH_KEYWORDS, initialize.doc()},
  {e_step.name(),        (PyCFunction)PyBobLearnEMPLDATrainer_e_step,      METH_VARARGS|METH_KEYWORDS, e_step.doc()},
  {m_step.name(),        (PyCFunction)PyBobLearnEMPLDATrainer_m_step,      METH_VARARGS|METH_KEYWORDS, m_step.doc()},
  {finalize.name(),      (PyCFunction)PyBobLearnEMPLDATrainer_finalize,    METH_VARARGS|METH_KEYWORDS, finalize.doc()},
  {enroll.name(),        (PyCFunction)PyBobLearnEMPLDATrainer_enroll,      METH_VARARGS|METH_KEYWORDS, enroll.doc()},
  {is_similar_to.name(), (PyCFunction)PyBobLearnEMPLDATrainer_IsSimilarTo, METH_VARARGS|METH_KEYWORDS, is_similar_to.doc()},
  {0} /* Sentinel */
};


/******************************************************************/
/************ Module Section **************************************/
/******************************************************************/

// Define the PLDATrainer type struct; its fields are filled in by
// init_BobLearnEMPLDATrainer() below.
PyTypeObject PyBobLearnEMPLDATrainer_Type = {
  PyVarObject_HEAD_INIT(0,0)
  0
};

/**
 * Fills in PyBobLearnEMPLDATrainer_Type, readies it, and registers it in
 * `module` under the name "PLDATrainer".
 *
 * @return true on success; false with a Python exception set on failure.
 */
bool init_BobLearnEMPLDATrainer(PyObject* module)
{
  // initialize the type struct
  PyBobLearnEMPLDATrainer_Type.tp_name      = PLDATrainer_doc.name();
  PyBobLearnEMPLDATrainer_Type.tp_basicsize = sizeof(PyBobLearnEMPLDATrainerObject);
  PyBobLearnEMPLDATrainer_Type.tp_flags     = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE;//Enable the class inheritance;
  PyBobLearnEMPLDATrainer_Type.tp_doc       = PLDATrainer_doc.doc();

  // set the functions
  PyBobLearnEMPLDATrainer_Type.tp_new         = PyType_GenericNew;
  PyBobLearnEMPLDATrainer_Type.tp_init        = reinterpret_cast<initproc>(PyBobLearnEMPLDATrainer_init);
  PyBobLearnEMPLDATrainer_Type.tp_dealloc     = reinterpret_cast<destructor>(PyBobLearnEMPLDATrainer_delete);
  PyBobLearnEMPLDATrainer_Type.tp_richcompare = reinterpret_cast<richcmpfunc>(PyBobLearnEMPLDATrainer_RichCompare);
  PyBobLearnEMPLDATrainer_Type.tp_methods     = PyBobLearnEMPLDATrainer_methods;
  PyBobLearnEMPLDATrainer_Type.tp_getset      = PyBobLearnEMPLDATrainer_getseters;

  // check that everything is fine
  if (PyType_Ready(&PyBobLearnEMPLDATrainer_Type) < 0) return false;

  // add the type to the module; PyModule_AddObject steals the reference only
  // on success, so the INCREF has to be undone by hand on failure.
  Py_INCREF(&PyBobLearnEMPLDATrainer_Type);
  if (PyModule_AddObject(module, "PLDATrainer", (PyObject*)&PyBobLearnEMPLDATrainer_Type) < 0){
    Py_DECREF(&PyBobLearnEMPLDATrainer_Type);
    return false;
  }
  return true;
}