From 0be70fafd110ef15573d3e72e3c0d5ebc176053f Mon Sep 17 00:00:00 2001
From: Tiago Freitas Pereira <tiagofrepereira@gmail.com>
Date: Thu, 5 Feb 2015 19:30:26 +0100
Subject: [PATCH] Binding linear scoring

---
 bob/learn/misc/__plda_trainer__.py            |   8 +-
 .../misc/include/bob.learn.misc/PLDATrainer.h |  13 +
 bob/learn/misc/kmeans_trainer.cpp             |   2 +-
 bob/learn/misc/linear_scoring.cpp             | 179 +++++++++++++
 bob/learn/misc/main.cpp                       |   3 +
 bob/learn/misc/main.h                         |   2 -
 bob/learn/misc/plda_machine.cpp               |   2 +-
 bob/learn/misc/plda_trainer.cpp               | 246 +++++++++++++++++-
 bob/learn/misc/test_plda_trainer.py           |  32 ++-
 bob/learn/misc/ztnorm.cpp                     |   5 -
 setup.py                                      |   2 +
 11 files changed, 464 insertions(+), 30 deletions(-)
 create mode 100644 bob/learn/misc/linear_scoring.cpp

diff --git a/bob/learn/misc/__plda_trainer__.py b/bob/learn/misc/__plda_trainer__.py
index a48e92c..0fc0ee3 100644
--- a/bob/learn/misc/__plda_trainer__.py
+++ b/bob/learn/misc/__plda_trainer__.py
@@ -11,7 +11,7 @@ import numpy
 # define the class
 class PLDATrainer (_PLDATrainer):
 
-  def __init__(self, max_iterations=10, use_sum_second_order=True):
+  def __init__(self, max_iterations=10, use_sum_second_order=False):
     """
     :py:class:`bob.learn.misc.PLDATrainer` constructor
 
@@ -39,10 +39,10 @@ class PLDATrainer (_PLDATrainer):
       
     for i in range(self._max_iterations):
       #eStep
-      self.eStep(plda_base, data);
+      self.e_step(plda_base, data);
       #mStep
-      self.mStep(plda_base);
-    self.finalize(plda_base);
+      self.m_step(plda_base, data);
+    self.finalize(plda_base, data);
 
 
 
diff --git a/bob/learn/misc/include/bob.learn.misc/PLDATrainer.h b/bob/learn/misc/include/bob.learn.misc/PLDATrainer.h
index 3439490..3323083 100644
--- a/bob/learn/misc/include/bob.learn.misc/PLDATrainer.h
+++ b/bob/learn/misc/include/bob.learn.misc/PLDATrainer.h
@@ -217,6 +217,19 @@ class PLDATrainer
      */
     void enrol(bob::learn::misc::PLDAMachine& plda_machine,
       const blitz::Array<double,2>& ar) const;
+      
+      
+    /**
+     * @brief Sets the Random Number Generator
+     */
+    void setRng(const boost::shared_ptr<boost::mt19937> rng)
+    { m_rng = rng; }
+
+    /**
+     * @brief Gets the Random Number Generator
+     */
+    const boost::shared_ptr<boost::mt19937> getRng() const
+    { return m_rng; }      
 
   private:
   
diff --git a/bob/learn/misc/kmeans_trainer.cpp b/bob/learn/misc/kmeans_trainer.cpp
index cbaac80..c31e3df 100644
--- a/bob/learn/misc/kmeans_trainer.cpp
+++ b/bob/learn/misc/kmeans_trainer.cpp
@@ -282,7 +282,7 @@ int PyBobLearnMiscKMeansTrainer_setRng(PyBobLearnMiscKMeansTrainerObject* self,
   BOB_TRY
 
   if (!PyBoostMt19937_Check(value)){
-    PyErr_Format(PyExc_RuntimeError, "%s %s expects an PyBoostMt19937_Check", Py_TYPE(self)->tp_name, average_min_distance.name());
+    PyErr_Format(PyExc_RuntimeError, "%s %s expects an PyBoostMt19937_Check", Py_TYPE(self)->tp_name, rng.name());
     return -1;
   }
 
diff --git a/bob/learn/misc/linear_scoring.cpp b/bob/learn/misc/linear_scoring.cpp
new file mode 100644
index 0000000..270dde8
--- /dev/null
+++ b/bob/learn/misc/linear_scoring.cpp
@@ -0,0 +1,179 @@
+/**
+ * @author Tiago de Freitas Pereira <tiago.pereira@idiap.ch>
+ * @date Wed 05 Feb 16:10:48 2015
+ *
+ * @brief Python API for bob::learn::em
+ *
+ * Copyright (C) 2011-2014 Idiap Research Institute, Martigny, Switzerland
+ */
+
+#include "main.h"
+
+/*Convert a PyObject to a list of GMMStats*/
+//template<class R, class P1, class P2>
+static int extract_gmmstats_list(PyObject *list,
+                             std::vector<boost::shared_ptr<const bob::learn::misc::GMMStats> >& training_data)
+{
+  for (int i=0; i<PyList_GET_SIZE(list); i++){
+  
+    PyBobLearnMiscGMMStatsObject* stats;
+    if (!PyArg_Parse(PyList_GetItem(list, i), "O!", &PyBobLearnMiscGMMStats_Type, &stats)){
+      PyErr_Format(PyExc_RuntimeError, "Expected GMMStats objects");
+      return -1;
+    }
+    training_data.push_back(stats->cxx);
+  }
+  return 0;
+}
+
+static int extract_gmmmachine_list(PyObject *list,
+                             std::vector<boost::shared_ptr<const bob::learn::misc::GMMMachine> >& training_data)
+{
+  for (int i=0; i<PyList_GET_SIZE(list); i++){
+  
+    PyBobLearnMiscGMMMachineObject* stats;
+    if (!PyArg_Parse(PyList_GetItem(list, i), "O!", &PyBobLearnMiscGMMMachine_Type, &stats)){
+      PyErr_Format(PyExc_RuntimeError, "Expected GMMMachine objects");
+      return -1;
+    }
+    training_data.push_back(stats->cxx);
+  }
+  return 0;
+}
+
+
+
+/*Convert a PyObject to a list of blitz Array*/
+template <int N>
+int extract_array_list(PyObject* list, std::vector<blitz::Array<double,N> >& vec)
+{
+  for (int i=0; i<PyList_GET_SIZE(list); i++)
+  {
+    PyBlitzArrayObject* blitz_object; 
+    if (!PyArg_Parse(PyList_GetItem(list, i), "O&", &PyBlitzArray_Converter, &blitz_object)){
+      PyErr_Format(PyExc_RuntimeError, "Expected numpy array object");
+      return -1;
+    }
+    auto blitz_object_ = make_safe(blitz_object);
+    vec.push_back(*PyBlitzArrayCxx_AsBlitz<double,N>(blitz_object));
+  }
+  return 0;
+}
+
+/* converts PyObject to bool and returns false if object is NULL */
+static inline bool f(PyObject* o){return o != 0 && PyObject_IsTrue(o) > 0;}
+
+
+/*** linear_scoring ***/
+static auto linear_scoring = bob::extension::FunctionDoc(
+  "linear_scoring",
+  "",
+  0,
+  true
+)
+.add_prototype("models, ubm, test_stats, test_channelOffset, frame_length_normalisation", "output")
+.add_parameter("models", "", "")
+.add_parameter("ubm", "", "")
+.add_parameter("test_stats", "", "")
+.add_parameter("test_channelOffset", "", "")
+.add_parameter("frame_length_normalisation", "bool", "")
+.add_return("output","array_like<float,2>","Score");
+static PyObject* PyBobLearnMisc_linear_scoring(PyObject*, PyObject* args, PyObject* kwargs) {
+
+  char** kwlist = linear_scoring.kwlist(0);
+    
+  //Cheking the number of arguments
+  int nargs = (args?PyTuple_Size(args):0) + (kwargs?PyDict_Size(kwargs):0);
+
+  switch(nargs){
+  
+    //Read a list of GMM
+    case 5:{
+
+      PyObject* gmm_list_o                 = 0;
+      PyBobLearnMiscGMMMachineObject* ubm  = 0;
+      PyObject* stats_list_o               = 0;
+      PyObject* channel_offset_list_o      = 0;
+      PyObject* frame_length_normalisation = 0;
+
+      if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O!O!O!O!O!", kwlist, &PyList_Type, &gmm_list_o,
+                                                                       &PyBobLearnMiscGMMMachine_Type, &ubm,
+                                                                       &PyList_Type, &stats_list_o,
+                                                                       &PyList_Type, &channel_offset_list_o,
+                                                                       &PyBool_Type, &frame_length_normalisation)){
+       linear_scoring.print_usage();
+       Py_RETURN_NONE;
+      }
+
+      std::vector<boost::shared_ptr<const bob::learn::misc::GMMStats> > stats_list;
+      if(extract_gmmstats_list(stats_list_o ,stats_list)!=0)
+        Py_RETURN_NONE;
+
+      std::vector<boost::shared_ptr<const bob::learn::misc::GMMMachine> > gmm_list;
+      if(extract_gmmmachine_list(gmm_list_o ,gmm_list)!=0)
+        Py_RETURN_NONE;
+
+      std::vector<blitz::Array<double,2> > channel_offset_list;
+      if(extract_array_list(channel_offset_list_o ,channel_offset_list)!=0)
+        Py_RETURN_NONE;
+
+
+      blitz::Array<double, 2> scores = blitz::Array<double, 2>(gmm_list.size(), stats_list.size());      
+      bob::learn::misc::linearScoring(gmm_list, *ubm->cxx, stats_list, channel_offset_list, f(frame_length_normalisation),scores);
+
+      
+      return PyBlitzArrayCxx_AsConstNumpy(scores);
+    
+    }
+    default:{
+      PyErr_Format(PyExc_RuntimeError, "number of arguments mismatch - linear_scoring requires 5 arguments, but you provided %d (see help)", nargs);
+      linear_scoring.print_usage();
+      Py_RETURN_NONE;
+    }  
+  }
+  /*
+  
+  
+  PyBlitzArrayObject *rawscores_probes_vs_models_o, *rawscores_zprobes_vs_models_o, *rawscores_probes_vs_tmodels_o, 
+  *rawscores_zprobes_vs_tmodels_o, *mask_zprobes_vs_tmodels_istruetrial_o;
+
+  if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O&O&O&O&|O&", kwlist, &PyBlitzArray_Converter, &rawscores_probes_vs_models_o,
+                                                                       &PyBlitzArray_Converter, &rawscores_zprobes_vs_models_o,
+                                                                       &PyBlitzArray_Converter, &rawscores_probes_vs_tmodels_o,
+                                                                       &PyBlitzArray_Converter, &rawscores_zprobes_vs_tmodels_o,
+                                                                       &PyBlitzArray_Converter, &mask_zprobes_vs_tmodels_istruetrial_o)){
+    zt_norm.print_usage();
+    Py_RETURN_NONE;
+  }
+
+  // get the number of command line arguments
+  auto rawscores_probes_vs_models_          = make_safe(rawscores_probes_vs_models_o);
+  auto rawscores_zprobes_vs_models_         = make_safe(rawscores_zprobes_vs_models_o);
+  auto rawscores_probes_vs_tmodels_         = make_safe(rawscores_probes_vs_tmodels_o);
+  auto rawscores_zprobes_vs_tmodels_        = make_safe(rawscores_zprobes_vs_tmodels_o);
+  //auto mask_zprobes_vs_tmodels_istruetrial_ = make_safe(mask_zprobes_vs_tmodels_istruetrial_o);
+
+  blitz::Array<double,2>  rawscores_probes_vs_models = *PyBlitzArrayCxx_AsBlitz<double,2>(rawscores_probes_vs_models_o);
+  blitz::Array<double,2> normalized_scores = blitz::Array<double,2>(rawscores_probes_vs_models.extent(0), rawscores_probes_vs_models.extent(1));
+
+  int nargs = (args?PyTuple_Size(args):0) + (kwargs?PyDict_Size(kwargs):0);
+
+  if(nargs==4)
+    bob::learn::misc::ztNorm(*PyBlitzArrayCxx_AsBlitz<double,2>(rawscores_probes_vs_models_o),
+                             *PyBlitzArrayCxx_AsBlitz<double,2>(rawscores_zprobes_vs_models_o),
+                             *PyBlitzArrayCxx_AsBlitz<double,2>(rawscores_probes_vs_tmodels_o),
+                             *PyBlitzArrayCxx_AsBlitz<double,2>(rawscores_zprobes_vs_tmodels_o),
+                             normalized_scores);
+  else
+    bob::learn::misc::ztNorm(*PyBlitzArrayCxx_AsBlitz<double,2>(rawscores_probes_vs_models_o), 
+                             *PyBlitzArrayCxx_AsBlitz<double,2>(rawscores_zprobes_vs_models_o), 
+                             *PyBlitzArrayCxx_AsBlitz<double,2>(rawscores_probes_vs_tmodels_o), 
+                             *PyBlitzArrayCxx_AsBlitz<double,2>(rawscores_zprobes_vs_tmodels_o), 
+                             *PyBlitzArrayCxx_AsBlitz<bool,2>(mask_zprobes_vs_tmodels_istruetrial_o),
+                             normalized_scores);
+
+  return PyBlitzArrayCxx_AsConstNumpy(normalized_scores);
+  */
+
+}
+
diff --git a/bob/learn/misc/main.cpp b/bob/learn/misc/main.cpp
index c16df1e..254c8ab 100644
--- a/bob/learn/misc/main.cpp
+++ b/bob/learn/misc/main.cpp
@@ -9,6 +9,9 @@
 #undef NO_IMPORT_ARRAY
 #endif
 #include "main.h"
+#include "ztnorm.cpp"
+#include "linear_scoring.cpp"
+
 
 static PyMethodDef module_methods[] = {
   {
diff --git a/bob/learn/misc/main.h b/bob/learn/misc/main.h
index 36b2eed..42f5f12 100644
--- a/bob/learn/misc/main.h
+++ b/bob/learn/misc/main.h
@@ -50,8 +50,6 @@
 #include <bob.learn.misc/ZTNorm.h>
 
 
-#include "ztnorm.cpp"
-
 
 #if PY_VERSION_HEX >= 0x03000000
 #define PyInt_Check PyLong_Check
diff --git a/bob/learn/misc/plda_machine.cpp b/bob/learn/misc/plda_machine.cpp
index 6bfdf0b..d5cd1c8 100644
--- a/bob/learn/misc/plda_machine.cpp
+++ b/bob/learn/misc/plda_machine.cpp
@@ -643,7 +643,7 @@ static PyObject* PyBobLearnMiscPLDAMachine_computeLogLikelihood(PyBobLearnMiscPL
   char** kwlist = compute_log_likelihood.kwlist(0);
 
   PyBlitzArrayObject* samples;
-  PyObject* with_enrolled_samples = 0;
+  PyObject* with_enrolled_samples = Py_True;
   
   if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O&|O!", kwlist, &PyBlitzArray_Converter, &samples,
                                                                  &PyBool_Type, &with_enrolled_samples)) Py_RETURN_NONE;
diff --git a/bob/learn/misc/plda_trainer.cpp b/bob/learn/misc/plda_trainer.cpp
index 902d428..2f76dc3 100644
--- a/bob/learn/misc/plda_trainer.cpp
+++ b/bob/learn/misc/plda_trainer.cpp
@@ -10,9 +10,50 @@
 #include "main.h"
 #include <boost/make_shared.hpp>
 
-/******************************************************************/
-/************ Constructor Section *********************************/
-/******************************************************************/
+//Defining maps for each initialization method
+static const std::map<std::string, bob::learn::misc::PLDATrainer::InitFMethod> FMethod = {{"RANDOM_F",  bob::learn::misc::PLDATrainer::RANDOM_F}, {"BETWEEN_SCATTER", bob::learn::misc::PLDATrainer::BETWEEN_SCATTER}};
+
+static const std::map<std::string, bob::learn::misc::PLDATrainer::InitGMethod> GMethod = {{"RANDOM_G",  bob::learn::misc::PLDATrainer::RANDOM_G}, {"WITHIN_SCATTER", bob::learn::misc::PLDATrainer::WITHIN_SCATTER}};
+
+static const std::map<std::string, bob::learn::misc::PLDATrainer::InitSigmaMethod> SigmaMethod = {{"RANDOM_SIGMA",  bob::learn::misc::PLDATrainer::RANDOM_SIGMA}, {"VARIANCE_G", bob::learn::misc::PLDATrainer::VARIANCE_G}, {"CONSTANT", bob::learn::misc::PLDATrainer::CONSTANT}, {"VARIANCE_DATA", bob::learn::misc::PLDATrainer::VARIANCE_DATA}};
+
+
+
+//String to type
+static inline bob::learn::misc::PLDATrainer::InitFMethod string2FMethod(const std::string& o){
+  auto it = FMethod.find(o);
+  if (it == FMethod.end()) throw std::runtime_error("The given FMethod '" + o + "' is not known; choose one of ('RANDOM_F','BETWEEN_SCATTER')");
+  else return it->second;
+}
+
+static inline bob::learn::misc::PLDATrainer::InitGMethod string2GMethod(const std::string& o){
+  auto it = GMethod.find(o);
+  if (it == GMethod.end()) throw std::runtime_error("The given GMethod '" + o + "' is not known; choose one of ('RANDOM_G','WITHIN_SCATTER')");
+  else return it->second;
+}
+
+static inline bob::learn::misc::PLDATrainer::InitSigmaMethod string2SigmaMethod(const std::string& o){
+  auto it = SigmaMethod.find(o);
+  if (it == SigmaMethod.end()) throw std::runtime_error("The given SigmaMethod '" + o + "' is not known; choose one of ('RANDOM_SIGMA','VARIANCE_G', 'CONSTANT', 'VARIANCE_DATA')");
+  else return it->second;
+}
+
+//Type to string
+static inline const std::string& FMethod2string(bob::learn::misc::PLDATrainer::InitFMethod o){
+  for (auto it = FMethod.begin(); it != FMethod.end(); ++it) if (it->second == o) return it->first;
+  throw std::runtime_error("The given FMethod type is not known");
+}
+
+static inline const std::string& GMethod2string(bob::learn::misc::PLDATrainer::InitGMethod o){
+  for (auto it = GMethod.begin(); it != GMethod.end(); ++it) if (it->second == o) return it->first;
+  throw std::runtime_error("The given GMethod type is not known");
+}
+
+static inline const std::string& SigmaMethod2string(bob::learn::misc::PLDATrainer::InitSigmaMethod o){
+  for (auto it = SigmaMethod.begin(); it != SigmaMethod.end(); ++it) if (it->second == o) return it->first;
+  throw std::runtime_error("The given SigmaMethod type is not known");
+}
+
 
 static inline bool f(PyObject* o){return o != 0 && PyObject_IsTrue(o) > 0;}  /* converts PyObject to bool and returns false if object is NULL */
 
@@ -46,6 +87,11 @@ static PyObject* vector_as_list(const std::vector<blitz::Array<double,N> >& vec)
 }
 
 
+/******************************************************************/
+/************ Constructor Section *********************************/
+/******************************************************************/
+
+
 static auto PLDATrainer_doc = bob::extension::ClassDoc(
   BOB_EXT_MODULE_PREFIX ".PLDATrainer",
   "This class can be used to train the :math:`$F$`, :math:`$G$ and "
@@ -208,6 +254,117 @@ PyObject* PyBobLearnMiscPLDATrainer_get_z_first_order(PyBobLearnMiscPLDATrainerO
 }
 
 
+/***** rng *****/
+static auto rng = bob::extension::VariableDoc(
+  "rng",
+  "str",
+  "The Mersenne Twister mt19937 random generator used for the initialization of subspaces/arrays before the EM loop.",
+  ""
+);
+PyObject* PyBobLearnMiscPLDATrainer_getRng(PyBobLearnMiscPLDATrainerObject* self, void*) {
+  BOB_TRY
+  //Allocating the correspondent python object
+  
+  PyBoostMt19937Object* retval =
+    (PyBoostMt19937Object*)PyBoostMt19937_Type.tp_alloc(&PyBoostMt19937_Type, 0);
+
+  retval->rng = self->cxx->getRng().get();
+  return Py_BuildValue("O", retval);
+  BOB_CATCH_MEMBER("Rng method could not be read", 0)
+}
+int PyBobLearnMiscPLDATrainer_setRng(PyBobLearnMiscPLDATrainerObject* self, PyObject* value, void*) {
+  BOB_TRY
+
+  if (!PyBoostMt19937_Check(value)){
+    PyErr_Format(PyExc_RuntimeError, "%s %s expects an PyBoostMt19937_Check", Py_TYPE(self)->tp_name, rng.name());
+    return -1;
+  }
+
+  PyBoostMt19937Object* rng_object = 0;
+  PyArg_Parse(value, "O!", &PyBoostMt19937_Type, &rng_object);
+  self->cxx->setRng((boost::shared_ptr<boost::mt19937>)rng_object->rng);
+
+  return 0;
+  BOB_CATCH_MEMBER("Rng could not be set", 0)
+}
+
+
+/***** init_f_method *****/
+static auto init_f_method = bob::extension::VariableDoc(
+  "init_f_method",
+  "str",
+  "The method used for the initialization of :math:`$F$`.",
+  ""
+);
+PyObject* PyBobLearnMiscPLDATrainer_getFMethod(PyBobLearnMiscPLDATrainerObject* self, void*) {
+  BOB_TRY
+  return Py_BuildValue("s", FMethod2string(self->cxx->getInitFMethod()).c_str());
+  BOB_CATCH_MEMBER("init_f_method method could not be read", 0)
+}
+int PyBobLearnMiscPLDATrainer_setFMethod(PyBobLearnMiscPLDATrainerObject* self, PyObject* value, void*) {
+  BOB_TRY
+
+  if (!PyString_Check(value)){
+    PyErr_Format(PyExc_RuntimeError, "%s %s expects an str", Py_TYPE(self)->tp_name, init_f_method.name());
+    return -1;
+  }
+  self->cxx->setInitFMethod(string2FMethod(PyString_AS_STRING(value)));
+
+  return 0;
+  BOB_CATCH_MEMBER("init_f_method method could not be set", 0)
+}
+
+
+/***** init_g_method *****/
+static auto init_g_method = bob::extension::VariableDoc(
+  "init_g_method",
+  "str",
+  "The method used for the initialization of :math:`$G$`.",
+  ""
+);
+PyObject* PyBobLearnMiscPLDATrainer_getGMethod(PyBobLearnMiscPLDATrainerObject* self, void*) {
+  BOB_TRY
+  return Py_BuildValue("s", GMethod2string(self->cxx->getInitGMethod()).c_str());
+  BOB_CATCH_MEMBER("init_g_method method could not be read", 0)
+}
+int PyBobLearnMiscPLDATrainer_setGMethod(PyBobLearnMiscPLDATrainerObject* self, PyObject* value, void*) {
+  BOB_TRY
+
+  if (!PyString_Check(value)){
+    PyErr_Format(PyExc_RuntimeError, "%s %s expects an str", Py_TYPE(self)->tp_name, init_g_method.name());
+    return -1;
+  }
+  self->cxx->setInitGMethod(string2GMethod(PyString_AS_STRING(value)));
+
+  return 0;
+  BOB_CATCH_MEMBER("init_g_method method could not be set", 0)
+}
+
+
+/***** init_sigma_method *****/
+static auto init_sigma_method = bob::extension::VariableDoc(
+  "init_sigma_method",
+  "str",
+  "The method used for the initialization of :math:`$\\Sigma$`.",
+  ""
+);
+PyObject* PyBobLearnMiscPLDATrainer_getSigmaMethod(PyBobLearnMiscPLDATrainerObject* self, void*) {
+  BOB_TRY
+  return Py_BuildValue("s", SigmaMethod2string(self->cxx->getInitSigmaMethod()).c_str());
+  BOB_CATCH_MEMBER("init_sigma_method method could not be read", 0)
+}
+int PyBobLearnMiscPLDATrainer_setSigmaMethod(PyBobLearnMiscPLDATrainerObject* self, PyObject* value, void*) {
+  BOB_TRY
+
+  if (!PyString_Check(value)){
+    PyErr_Format(PyExc_RuntimeError, "%s %s expects an str", Py_TYPE(self)->tp_name, init_sigma_method.name());
+    return -1;
+  }
+  self->cxx->setInitSigmaMethod(string2SigmaMethod(PyString_AS_STRING(value)));
+
+  return 0;
+  BOB_CATCH_MEMBER("init_sigma_method method could not be set", 0)
+}
 
 
 static PyGetSetDef PyBobLearnMiscPLDATrainer_getseters[] = { 
@@ -232,7 +389,34 @@ static PyGetSetDef PyBobLearnMiscPLDATrainer_getseters[] = {
    z_second_order.doc(),
    0
   },
-
+  {
+   rng.name(),
+   (getter)PyBobLearnMiscPLDATrainer_getRng,
+   (setter)PyBobLearnMiscPLDATrainer_setRng,
+   rng.doc(),
+   0
+  },
+  {
+   init_f_method.name(),
+   (getter)PyBobLearnMiscPLDATrainer_getFMethod,
+   (setter)PyBobLearnMiscPLDATrainer_setFMethod,
+   init_f_method.doc(),
+   0
+  },
+  {
+   init_g_method.name(),
+   (getter)PyBobLearnMiscPLDATrainer_getGMethod,
+   (setter)PyBobLearnMiscPLDATrainer_setGMethod,
+   init_g_method.doc(),
+   0
+  },
+  {
+   init_sigma_method.name(),
+   (getter)PyBobLearnMiscPLDATrainer_getSigmaMethod,
+   (setter)PyBobLearnMiscPLDATrainer_setSigmaMethod,
+   init_sigma_method.doc(),
+   0
+  },  
   {0}  // Sentinel
 };
 
@@ -384,7 +568,7 @@ static PyObject* PyBobLearnMiscPLDATrainer_enrol(PyBobLearnMiscPLDATrainerObject
   BOB_TRY
 
   /* Parses input arguments in a single shot */
-  char** kwlist = finalize.kwlist(0);
+  char** kwlist = enrol.kwlist(0);
 
   PyBobLearnMiscPLDAMachineObject* plda_machine = 0;
   PyBlitzArrayObject* data = 0;
@@ -401,6 +585,46 @@ static PyObject* PyBobLearnMiscPLDATrainer_enrol(PyBobLearnMiscPLDATrainerObject
 }
 
 
+/*** is_similar_to ***/
+static auto is_similar_to = bob::extension::FunctionDoc(
+  "is_similar_to",
+  
+  "Compares this PLDATrainer with the ``other`` one to be approximately the same.",
+  "The optional values ``r_epsilon`` and ``a_epsilon`` refer to the "
+  "relative and absolute precision for the ``weights``, ``biases`` "
+  "and any other values internal to this machine."
+)
+.add_prototype("other, [r_epsilon], [a_epsilon]","output")
+.add_parameter("other", ":py:class:`bob.learn.misc.PLDATrainer`", "A PLDATrainer object to be compared.")
+.add_parameter("r_epsilon", "float", "Relative precision.")
+.add_parameter("a_epsilon", "float", "Absolute precision.")
+.add_return("output","bool","True if it is similar, otherwise false.");
+static PyObject* PyBobLearnMiscPLDATrainer_IsSimilarTo(PyBobLearnMiscPLDATrainerObject* self, PyObject* args, PyObject* kwds) {
+
+  /* Parses input arguments in a single shot */
+  char** kwlist = is_similar_to.kwlist(0);
+
+  //PyObject* other = 0;
+  PyBobLearnMiscPLDATrainerObject* other = 0;
+  double r_epsilon = 1.e-5;
+  double a_epsilon = 1.e-8;
+
+  if (!PyArg_ParseTupleAndKeywords(args, kwds, "O!|dd", kwlist,
+        &PyBobLearnMiscPLDATrainer_Type, &other,
+        &r_epsilon, &a_epsilon)){
+
+        is_similar_to.print_usage(); 
+        return 0;        
+  }
+
+  if (self->cxx->is_similar_to(*other->cxx, r_epsilon, a_epsilon))
+    Py_RETURN_TRUE;
+  else
+    Py_RETURN_FALSE;
+}
+
+
+
 static PyMethodDef PyBobLearnMiscPLDATrainer_methods[] = {
   {
     initialize.name(),
@@ -420,12 +644,24 @@ static PyMethodDef PyBobLearnMiscPLDATrainer_methods[] = {
     METH_VARARGS|METH_KEYWORDS,
     m_step.doc()
   },
+  {
+    finalize.name(),
+    (PyCFunction)PyBobLearnMiscPLDATrainer_finalize,
+    METH_VARARGS|METH_KEYWORDS,
+    finalize.doc()
+  },  
   {
     enrol.name(),
     (PyCFunction)PyBobLearnMiscPLDATrainer_enrol,
     METH_VARARGS|METH_KEYWORDS,
     enrol.doc()
   },
+  {
+    is_similar_to.name(),
+    (PyCFunction)PyBobLearnMiscPLDATrainer_IsSimilarTo,
+    METH_VARARGS|METH_KEYWORDS,
+    is_similar_to.doc()
+  },
   {0} /* Sentinel */
 };
 
diff --git a/bob/learn/misc/test_plda_trainer.py b/bob/learn/misc/test_plda_trainer.py
index f18fe33..553ff4f 100644
--- a/bob/learn/misc/test_plda_trainer.py
+++ b/bob/learn/misc/test_plda_trainer.py
@@ -123,8 +123,8 @@ class PythonPLDATrainer():
 
   def __init_sigma__(self, machine, data, factor = 1.):
     """As a variance of the data"""
-    cache1 = numpy.zeros(shape=(machine.dim_d,), dtype=numpy.float64)
-    cache2 = numpy.zeros(shape=(machine.dim_d,), dtype=numpy.float64)
+    cache1 = numpy.zeros(shape=(machine.shape[0],), dtype=numpy.float64)
+    cache2 = numpy.zeros(shape=(machine.shape[0],), dtype=numpy.float64)
     n_samples = 0
     for v in data:
       for j in range(v.shape[0]):
@@ -145,10 +145,10 @@ class PythonPLDATrainer():
   def initialize(self, machine, data):
     self.__check_training_data__(data)
     n_features = data[0].shape[1]
-    if(machine.dim_d != n_features):
+    if(machine.shape[0] != n_features):
       raise RuntimeError("Inconsistent feature dimensionality between the machine and the training data set")
-    self.m_dim_f = machine.dim_f
-    self.m_dim_g = machine.dim_g
+    self.m_dim_f = machine.shape[1]
+    self.m_dim_g = machine.shape[2]
     self.__init_members__(data)
     # Warning: Default initialization of mu, F, G, sigma using scatters
     self.__init_mu_f_g_sigma__(machine, data)
@@ -237,7 +237,7 @@ class PythonPLDATrainer():
 
   def __update_f_and_g__(self, machine, data):
     ### Initialise the numerator and the denominator.
-    dim_d                          = machine.dim_d
+    dim_d                          = machine.shape[0]
     accumulated_B_numerator        = numpy.zeros((dim_d,self.m_dim_f+self.m_dim_g))
     accumulated_B_denominator      = numpy.linalg.inv(self.m_sum_z_second_order)
     mu                             = machine.mu
@@ -263,7 +263,7 @@ class PythonPLDATrainer():
 
   def __update_sigma__(self, machine, data):
     ### Initialise the accumulated Sigma
-    dim_d                          = machine.dim_d
+    dim_d                          = machine.shape[0]
     mu                             = machine.mu
     accumulated_sigma              = numpy.zeros(dim_d)   # An array (dim_d)
     number_of_observations         = 0
@@ -368,9 +368,9 @@ def test_plda_EM_vs_Python():
   m_py = PLDABase(D,nf,ng)
 
   # Sets the same initialization methods
-  t.init_f_method = PLDATrainer.BETWEEN_SCATTER
-  t.init_g_method = PLDATrainer.WITHIN_SCATTER
-  t.init_sigma_method = PLDATrainer.VARIANCE_DATA
+  t.init_f_method = 'BETWEEN_SCATTER'
+  t.init_g_method = 'WITHIN_SCATTER'
+  t.init_sigma_method = 'VARIANCE_DATA'
 
   t.train(m, l)
   t_py.train(m_py, l)
@@ -379,6 +379,7 @@ def test_plda_EM_vs_Python():
   assert numpy.allclose(m.g, m_py.g)
   assert numpy.allclose(m.sigma, m_py.sigma)
 
+
 def test_plda_EM_vs_Prince():
   # Data used for performing the tests
   # Features and subspaces dimensionality
@@ -687,23 +688,28 @@ def test_plda_enrollment():
   t = PLDATrainer()
   t.enrol(m, a_enrol)
   ll = m.compute_log_likelihood(x3)
+  
   assert abs(ll - ll_ref) < 1e-10
 
   # reference obtained by computing the likelihood of [x1,x2,x3], [x1,x2]
   # and [x3] separately
   llr_ref = -4.43695386675
-  llr = m.forward(x3)
+  llr = m(x3)
   assert abs(llr - llr_ref) < 1e-10
   #
   llr_separate = m.compute_log_likelihood(numpy.array([x1,x2,x3]), False) - \
     (m.compute_log_likelihood(numpy.array([x1,x2]), False) + m.compute_log_likelihood(numpy.array([x3]), False))
   assert abs(llr - llr_separate) < 1e-10
 
+
+
 def test_plda_comparisons():
 
   t1 = PLDATrainer()
   t2 = PLDATrainer()
-  t2.rng = t1.rng
+
+  #t2.rng = t1.rng
+
   assert t1 == t2
   assert (t1 != t2 ) is False
   assert t1.is_similar_to(t2)
@@ -731,3 +737,5 @@ def test_plda_comparisons():
   assert (t1 == t2 ) is False
   assert t1 != t2
   assert (t1.is_similar_to(t2) ) is False
+
+  
diff --git a/bob/learn/misc/ztnorm.cpp b/bob/learn/misc/ztnorm.cpp
index fa8ec36..9e2c6ea 100644
--- a/bob/learn/misc/ztnorm.cpp
+++ b/bob/learn/misc/ztnorm.cpp
@@ -9,9 +9,6 @@
 
 #include "main.h"
 
-#ifndef BOB_LEARN_MISC_ZTNORM_BIND
-#define BOB_LEARN_MISC_ZTNORM_BIND
-
 /*** zt_norm ***/
 static auto zt_norm = bob::extension::FunctionDoc(
   "ztnorm",
@@ -147,5 +144,3 @@ static PyObject* PyBobLearnMisc_zNorm(PyObject*, PyObject* args, PyObject* kwarg
   return PyBlitzArrayCxx_AsConstNumpy(normalized_scores);
 }
 
-#endif
-
diff --git a/setup.py b/setup.py
index 51e44e4..aa58b9c 100644
--- a/setup.py
+++ b/setup.py
@@ -137,6 +137,8 @@ setup(
 
           "bob/learn/misc/ztnorm.cpp",
 
+          "bob/learn/misc/linear_scoring.cpp",
+
           "bob/learn/misc/main.cpp",
         ],
         bob_packages = bob_packages,
-- 
GitLab