Commit 0be70faf authored by Tiago de Freitas Pereira's avatar Tiago de Freitas Pereira
Browse files

Binding linear scoring

parent 3552d034
......@@ -11,7 +11,7 @@ import numpy
# define the class
class PLDATrainer (_PLDATrainer):
def __init__(self, max_iterations=10, use_sum_second_order=True):
def __init__(self, max_iterations=10, use_sum_second_order=False):
"""
:py:class:`bob.learn.misc.PLDATrainer` constructor
......@@ -39,10 +39,10 @@ class PLDATrainer (_PLDATrainer):
for i in range(self._max_iterations):
#eStep
self.eStep(plda_base, data);
self.e_step(plda_base, data);
#mStep
self.mStep(plda_base);
self.finalize(plda_base);
self.m_step(plda_base, data);
self.finalize(plda_base, data);
......
......@@ -217,6 +217,19 @@ class PLDATrainer
*/
void enrol(bob::learn::misc::PLDAMachine& plda_machine,
const blitz::Array<double,2>& ar) const;
/**
* @brief Sets the Random Number Generator
*/
void setRng(const boost::shared_ptr<boost::mt19937> rng)
{ m_rng = rng; }
/**
* @brief Gets the Random Number Generator
*/
const boost::shared_ptr<boost::mt19937> getRng() const
{ return m_rng; }
private:
......
......@@ -282,7 +282,7 @@ int PyBobLearnMiscKMeansTrainer_setRng(PyBobLearnMiscKMeansTrainerObject* self,
BOB_TRY
if (!PyBoostMt19937_Check(value)){
PyErr_Format(PyExc_RuntimeError, "%s %s expects an PyBoostMt19937_Check", Py_TYPE(self)->tp_name, average_min_distance.name());
PyErr_Format(PyExc_RuntimeError, "%s %s expects an PyBoostMt19937_Check", Py_TYPE(self)->tp_name, rng.name());
return -1;
}
......
/**
* @author Tiago de Freitas Pereira <tiago.pereira@idiap.ch>
* @date Wed 05 Feb 16:10:48 2015
*
* @brief Python API for bob::learn::em
*
* Copyright (C) 2011-2014 Idiap Research Institute, Martigny, Switzerland
*/
#include "main.h"
/*Convert a PyObject to a a list of GMMStats*/
//template<class R, class P1, class P2>
static int extract_gmmstats_list(PyObject *list,
std::vector<boost::shared_ptr<const bob::learn::misc::GMMStats> >& training_data)
{
for (int i=0; i<PyList_GET_SIZE(list); i++){
PyBobLearnMiscGMMStatsObject* stats;
if (!PyArg_Parse(PyList_GetItem(list, i), "O!", &PyBobLearnMiscGMMStats_Type, &stats)){
PyErr_Format(PyExc_RuntimeError, "Expected GMMStats objects");
return -1;
}
training_data.push_back(stats->cxx);
}
return 0;
}
static int extract_gmmmachine_list(PyObject *list,
std::vector<boost::shared_ptr<const bob::learn::misc::GMMMachine> >& training_data)
{
for (int i=0; i<PyList_GET_SIZE(list); i++){
PyBobLearnMiscGMMMachineObject* stats;
if (!PyArg_Parse(PyList_GetItem(list, i), "O!", &PyBobLearnMiscGMMMachine_Type, &stats)){
PyErr_Format(PyExc_RuntimeError, "Expected GMMMachine objects");
return -1;
}
training_data.push_back(stats->cxx);
}
return 0;
}
/*Convert a PyObject to a list of blitz Array*/
template <int N>
int extract_array_list(PyObject* list, std::vector<blitz::Array<double,N> >& vec)
{
for (int i=0; i<PyList_GET_SIZE(list); i++)
{
PyBlitzArrayObject* blitz_object;
if (!PyArg_Parse(PyList_GetItem(list, i), "O&", &PyBlitzArray_Converter, &blitz_object)){
PyErr_Format(PyExc_RuntimeError, "Expected numpy array object");
return -1;
}
auto blitz_object_ = make_safe(blitz_object);
vec.push_back(*PyBlitzArrayCxx_AsBlitz<double,N>(blitz_object));
}
return 0;
}
/* converts PyObject to bool and returns false if object is NULL */
static inline bool f(PyObject* o){return o != 0 && PyObject_IsTrue(o) > 0;}
/*** linear_scoring ***/
static auto linear_scoring = bob::extension::FunctionDoc(
"linear_scoring",
"",
0,
true
)
.add_prototype("models, ubm_mean, ubm_variance, test_stats, test_channelOffset, frame_length_normalisation", "output")
.add_parameter("models", "", "")
.add_parameter("ubm", "", "")
.add_parameter("test_stats", "", "")
.add_parameter("test_channelOffset", "", "")
.add_parameter("frame_length_normalisation", "bool", "")
.add_return("output","array_like<float,2>","Score");
static PyObject* PyBobLearnMisc_linear_scoring(PyObject*, PyObject* args, PyObject* kwargs) {
char** kwlist = linear_scoring.kwlist(0);
//Cheking the number of arguments
int nargs = (args?PyTuple_Size(args):0) + (kwargs?PyDict_Size(kwargs):0);
switch(nargs){
//Read a list of GMM
case 5:{
PyObject* gmm_list_o = 0;
PyBobLearnMiscGMMMachineObject* ubm = 0;
PyObject* stats_list_o = 0;
PyObject* channel_offset_list_o = 0;
PyObject* frame_length_normalisation = 0;
if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O!O!O!O!O!", kwlist, &PyList_Type, &gmm_list_o,
&PyBobLearnMiscGMMMachine_Type, &ubm,
&PyList_Type, &stats_list_o,
&PyList_Type, &channel_offset_list_o,
&PyBool_Type, frame_length_normalisation)){
linear_scoring.print_usage();
Py_RETURN_NONE;
}
std::vector<boost::shared_ptr<const bob::learn::misc::GMMStats> > stats_list;
if(extract_gmmstats_list(stats_list_o ,stats_list)!=0)
Py_RETURN_NONE;
std::vector<boost::shared_ptr<const bob::learn::misc::GMMMachine> > gmm_list;
if(extract_gmmmachine_list(gmm_list_o ,gmm_list)!=0)
Py_RETURN_NONE;
std::vector<blitz::Array<double,2> > channel_offset_list;
if(extract_array_list(channel_offset_list_o ,channel_offset_list)!=0)
Py_RETURN_NONE;
blitz::Array<double, 2> scores = blitz::Array<double, 2>(gmm_list.size(), stats_list.size());
bob::learn::misc::linearScoring(gmm_list, *ubm->cxx, stats_list, channel_offset_list, f(frame_length_normalisation),scores);
return PyBlitzArrayCxx_AsConstNumpy(scores);
}
default:{
PyErr_Format(PyExc_RuntimeError, "number of arguments mismatch - linear_scoring requires 5 or 6 arguments, but you provided %d (see help)", nargs);
linear_scoring.print_usage();
Py_RETURN_NONE;
}
}
/*
PyBlitzArrayObject *rawscores_probes_vs_models_o, *rawscores_zprobes_vs_models_o, *rawscores_probes_vs_tmodels_o,
*rawscores_zprobes_vs_tmodels_o, *mask_zprobes_vs_tmodels_istruetrial_o;
if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O&O&O&O&|O&", kwlist, &PyBlitzArray_Converter, &rawscores_probes_vs_models_o,
&PyBlitzArray_Converter, &rawscores_zprobes_vs_models_o,
&PyBlitzArray_Converter, &rawscores_probes_vs_tmodels_o,
&PyBlitzArray_Converter, &rawscores_zprobes_vs_tmodels_o,
&PyBlitzArray_Converter, &mask_zprobes_vs_tmodels_istruetrial_o)){
zt_norm.print_usage();
Py_RETURN_NONE;
}
// get the number of command line arguments
auto rawscores_probes_vs_models_ = make_safe(rawscores_probes_vs_models_o);
auto rawscores_zprobes_vs_models_ = make_safe(rawscores_zprobes_vs_models_o);
auto rawscores_probes_vs_tmodels_ = make_safe(rawscores_probes_vs_tmodels_o);
auto rawscores_zprobes_vs_tmodels_ = make_safe(rawscores_zprobes_vs_tmodels_o);
//auto mask_zprobes_vs_tmodels_istruetrial_ = make_safe(mask_zprobes_vs_tmodels_istruetrial_o);
blitz::Array<double,2> rawscores_probes_vs_models = *PyBlitzArrayCxx_AsBlitz<double,2>(rawscores_probes_vs_models_o);
blitz::Array<double,2> normalized_scores = blitz::Array<double,2>(rawscores_probes_vs_models.extent(0), rawscores_probes_vs_models.extent(1));
int nargs = (args?PyTuple_Size(args):0) + (kwargs?PyDict_Size(kwargs):0);
if(nargs==4)
bob::learn::misc::ztNorm(*PyBlitzArrayCxx_AsBlitz<double,2>(rawscores_probes_vs_models_o),
*PyBlitzArrayCxx_AsBlitz<double,2>(rawscores_zprobes_vs_models_o),
*PyBlitzArrayCxx_AsBlitz<double,2>(rawscores_probes_vs_tmodels_o),
*PyBlitzArrayCxx_AsBlitz<double,2>(rawscores_zprobes_vs_tmodels_o),
normalized_scores);
else
bob::learn::misc::ztNorm(*PyBlitzArrayCxx_AsBlitz<double,2>(rawscores_probes_vs_models_o),
*PyBlitzArrayCxx_AsBlitz<double,2>(rawscores_zprobes_vs_models_o),
*PyBlitzArrayCxx_AsBlitz<double,2>(rawscores_probes_vs_tmodels_o),
*PyBlitzArrayCxx_AsBlitz<double,2>(rawscores_zprobes_vs_tmodels_o),
*PyBlitzArrayCxx_AsBlitz<bool,2>(mask_zprobes_vs_tmodels_istruetrial_o),
normalized_scores);
return PyBlitzArrayCxx_AsConstNumpy(normalized_scores);
*/
}
......@@ -9,6 +9,9 @@
#undef NO_IMPORT_ARRAY
#endif
#include "main.h"
#include "ztnorm.cpp"
#include "linear_scoring.cpp"
static PyMethodDef module_methods[] = {
{
......
......@@ -50,8 +50,6 @@
#include <bob.learn.misc/ZTNorm.h>
#include "ztnorm.cpp"
#if PY_VERSION_HEX >= 0x03000000
#define PyInt_Check PyLong_Check
......
......@@ -643,7 +643,7 @@ static PyObject* PyBobLearnMiscPLDAMachine_computeLogLikelihood(PyBobLearnMiscPL
char** kwlist = compute_log_likelihood.kwlist(0);
PyBlitzArrayObject* samples;
PyObject* with_enrolled_samples = 0;
PyObject* with_enrolled_samples = Py_True;
if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O&|O!", kwlist, &PyBlitzArray_Converter, &samples,
&PyBool_Type, &with_enrolled_samples)) Py_RETURN_NONE;
......
......@@ -10,9 +10,50 @@
#include "main.h"
#include <boost/make_shared.hpp>
/******************************************************************/
/************ Constructor Section *********************************/
/******************************************************************/
//Defining maps for each initializatio method
static const std::map<std::string, bob::learn::misc::PLDATrainer::InitFMethod> FMethod = {{"RANDOM_F", bob::learn::misc::PLDATrainer::RANDOM_F}, {"BETWEEN_SCATTER", bob::learn::misc::PLDATrainer::BETWEEN_SCATTER}};
static const std::map<std::string, bob::learn::misc::PLDATrainer::InitGMethod> GMethod = {{"RANDOM_G", bob::learn::misc::PLDATrainer::RANDOM_G}, {"WITHIN_SCATTER", bob::learn::misc::PLDATrainer::WITHIN_SCATTER}};
static const std::map<std::string, bob::learn::misc::PLDATrainer::InitSigmaMethod> SigmaMethod = {{"RANDOM_SIGMA", bob::learn::misc::PLDATrainer::RANDOM_SIGMA}, {"VARIANCE_G", bob::learn::misc::PLDATrainer::VARIANCE_G}, {"CONSTANT", bob::learn::misc::PLDATrainer::CONSTANT}, {"VARIANCE_DATA", bob::learn::misc::PLDATrainer::VARIANCE_DATA}};
//String to type
static inline bob::learn::misc::PLDATrainer::InitFMethod string2FMethod(const std::string& o){
auto it = FMethod.find(o);
if (it == FMethod.end()) throw std::runtime_error("The given FMethod '" + o + "' is not known; choose one of ('RANDOM_F','BETWEEN_SCATTER')");
else return it->second;
}
static inline bob::learn::misc::PLDATrainer::InitGMethod string2GMethod(const std::string& o){
auto it = GMethod.find(o);
if (it == GMethod.end()) throw std::runtime_error("The given GMethod '" + o + "' is not known; choose one of ('RANDOM_G','WITHIN_SCATTER')");
else return it->second;
}
static inline bob::learn::misc::PLDATrainer::InitSigmaMethod string2SigmaMethod(const std::string& o){
auto it = SigmaMethod.find(o);
if (it == SigmaMethod.end()) throw std::runtime_error("The given SigmaMethod '" + o + "' is not known; choose one of ('RANDOM_SIGMA','VARIANCE_G', 'CONSTANT', 'VARIANCE_DATA')");
else return it->second;
}
//Type to string
static inline const std::string& FMethod2string(bob::learn::misc::PLDATrainer::InitFMethod o){
for (auto it = FMethod.begin(); it != FMethod.end(); ++it) if (it->second == o) return it->first;
throw std::runtime_error("The given FMethod type is not known");
}
static inline const std::string& GMethod2string(bob::learn::misc::PLDATrainer::InitGMethod o){
for (auto it = GMethod.begin(); it != GMethod.end(); ++it) if (it->second == o) return it->first;
throw std::runtime_error("The given GMethod type is not known");
}
static inline const std::string& SigmaMethod2string(bob::learn::misc::PLDATrainer::InitSigmaMethod o){
for (auto it = SigmaMethod.begin(); it != SigmaMethod.end(); ++it) if (it->second == o) return it->first;
throw std::runtime_error("The given SigmaMethod type is not known");
}
static inline bool f(PyObject* o){return o != 0 && PyObject_IsTrue(o) > 0;} /* converts PyObject to bool and returns false if object is NULL */
......@@ -46,6 +87,11 @@ static PyObject* vector_as_list(const std::vector<blitz::Array<double,N> >& vec)
}
/******************************************************************/
/************ Constructor Section *********************************/
/******************************************************************/
static auto PLDATrainer_doc = bob::extension::ClassDoc(
BOB_EXT_MODULE_PREFIX ".PLDATrainer",
"This class can be used to train the :math:`$F$`, :math:`$G$ and "
......@@ -208,6 +254,117 @@ PyObject* PyBobLearnMiscPLDATrainer_get_z_first_order(PyBobLearnMiscPLDATrainerO
}
/***** rng *****/
static auto rng = bob::extension::VariableDoc(
"rng",
"str",
"The Mersenne Twister mt19937 random generator used for the initialization of subspaces/arrays before the EM loop.",
""
);
PyObject* PyBobLearnMiscPLDATrainer_getRng(PyBobLearnMiscPLDATrainerObject* self, void*) {
BOB_TRY
//Allocating the correspondent python object
PyBoostMt19937Object* retval =
(PyBoostMt19937Object*)PyBoostMt19937_Type.tp_alloc(&PyBoostMt19937_Type, 0);
retval->rng = self->cxx->getRng().get();
return Py_BuildValue("O", retval);
BOB_CATCH_MEMBER("Rng method could not be read", 0)
}
int PyBobLearnMiscPLDATrainer_setRng(PyBobLearnMiscPLDATrainerObject* self, PyObject* value, void*) {
BOB_TRY
if (!PyBoostMt19937_Check(value)){
PyErr_Format(PyExc_RuntimeError, "%s %s expects an PyBoostMt19937_Check", Py_TYPE(self)->tp_name, rng.name());
return -1;
}
PyBoostMt19937Object* rng_object = 0;
PyArg_Parse(value, "O!", &PyBoostMt19937_Type, &rng_object);
self->cxx->setRng((boost::shared_ptr<boost::mt19937>)rng_object->rng);
return 0;
BOB_CATCH_MEMBER("Rng could not be set", 0)
}
/***** init_f_method *****/
static auto init_f_method = bob::extension::VariableDoc(
"init_f_method",
"str",
"The method used for the initialization of :math:`$F$`.",
""
);
PyObject* PyBobLearnMiscPLDATrainer_getFMethod(PyBobLearnMiscPLDATrainerObject* self, void*) {
BOB_TRY
return Py_BuildValue("s", FMethod2string(self->cxx->getInitFMethod()).c_str());
BOB_CATCH_MEMBER("init_f_method method could not be read", 0)
}
int PyBobLearnMiscPLDATrainer_setFMethod(PyBobLearnMiscPLDATrainerObject* self, PyObject* value, void*) {
BOB_TRY
if (!PyString_Check(value)){
PyErr_Format(PyExc_RuntimeError, "%s %s expects an str", Py_TYPE(self)->tp_name, init_f_method.name());
return -1;
}
self->cxx->setInitFMethod(string2FMethod(PyString_AS_STRING(value)));
return 0;
BOB_CATCH_MEMBER("init_f_method method could not be set", 0)
}
/***** init_g_method *****/
static auto init_g_method = bob::extension::VariableDoc(
"init_g_method",
"str",
"The method used for the initialization of :math:`$G$`.",
""
);
PyObject* PyBobLearnMiscPLDATrainer_getGMethod(PyBobLearnMiscPLDATrainerObject* self, void*) {
BOB_TRY
return Py_BuildValue("s", GMethod2string(self->cxx->getInitGMethod()).c_str());
BOB_CATCH_MEMBER("init_g_method method could not be read", 0)
}
int PyBobLearnMiscPLDATrainer_setGMethod(PyBobLearnMiscPLDATrainerObject* self, PyObject* value, void*) {
BOB_TRY
if (!PyString_Check(value)){
PyErr_Format(PyExc_RuntimeError, "%s %s expects an str", Py_TYPE(self)->tp_name, init_g_method.name());
return -1;
}
self->cxx->setInitGMethod(string2GMethod(PyString_AS_STRING(value)));
return 0;
BOB_CATCH_MEMBER("init_g_method method could not be set", 0)
}
/***** init_sigma_method *****/
static auto init_sigma_method = bob::extension::VariableDoc(
"init_sigma_method",
"str",
"The method used for the initialization of :math:`$\\Sigma$`.",
""
);
PyObject* PyBobLearnMiscPLDATrainer_getSigmaMethod(PyBobLearnMiscPLDATrainerObject* self, void*) {
BOB_TRY
return Py_BuildValue("s", SigmaMethod2string(self->cxx->getInitSigmaMethod()).c_str());
BOB_CATCH_MEMBER("init_sigma_method method could not be read", 0)
}
int PyBobLearnMiscPLDATrainer_setSigmaMethod(PyBobLearnMiscPLDATrainerObject* self, PyObject* value, void*) {
BOB_TRY
if (!PyString_Check(value)){
PyErr_Format(PyExc_RuntimeError, "%s %s expects an str", Py_TYPE(self)->tp_name, init_sigma_method.name());
return -1;
}
self->cxx->setInitSigmaMethod(string2SigmaMethod(PyString_AS_STRING(value)));
return 0;
BOB_CATCH_MEMBER("init_sigma_method method could not be set", 0)
}
static PyGetSetDef PyBobLearnMiscPLDATrainer_getseters[] = {
......@@ -232,7 +389,34 @@ static PyGetSetDef PyBobLearnMiscPLDATrainer_getseters[] = {
z_second_order.doc(),
0
},
{
rng.name(),
(getter)PyBobLearnMiscPLDATrainer_getRng,
(setter)PyBobLearnMiscPLDATrainer_setRng,
rng.doc(),
0
},
{
init_f_method.name(),
(getter)PyBobLearnMiscPLDATrainer_getFMethod,
(setter)PyBobLearnMiscPLDATrainer_setFMethod,
init_f_method.doc(),
0
},
{
init_g_method.name(),
(getter)PyBobLearnMiscPLDATrainer_getGMethod,
(setter)PyBobLearnMiscPLDATrainer_setGMethod,
init_g_method.doc(),
0
},
{
init_sigma_method.name(),
(getter)PyBobLearnMiscPLDATrainer_getSigmaMethod,
(setter)PyBobLearnMiscPLDATrainer_setSigmaMethod,
init_sigma_method.doc(),
0
},
{0} // Sentinel
};
......@@ -384,7 +568,7 @@ static PyObject* PyBobLearnMiscPLDATrainer_enrol(PyBobLearnMiscPLDATrainerObject
BOB_TRY
/* Parses input arguments in a single shot */
char** kwlist = finalize.kwlist(0);
char** kwlist = enrol.kwlist(0);
PyBobLearnMiscPLDAMachineObject* plda_machine = 0;
PyBlitzArrayObject* data = 0;
......@@ -401,6 +585,46 @@ static PyObject* PyBobLearnMiscPLDATrainer_enrol(PyBobLearnMiscPLDATrainerObject
}
/*** is_similar_to ***/
static auto is_similar_to = bob::extension::FunctionDoc(
"is_similar_to",
"Compares this PLDATrainer with the ``other`` one to be approximately the same.",
"The optional values ``r_epsilon`` and ``a_epsilon`` refer to the "
"relative and absolute precision for the ``weights``, ``biases`` "
"and any other values internal to this machine."
)
.add_prototype("other, [r_epsilon], [a_epsilon]","output")
.add_parameter("other", ":py:class:`bob.learn.misc.PLDAMachine`", "A PLDAMachine object to be compared.")
.add_parameter("r_epsilon", "float", "Relative precision.")
.add_parameter("a_epsilon", "float", "Absolute precision.")
.add_return("output","bool","True if it is similar, otherwise false.");
static PyObject* PyBobLearnMiscPLDATrainer_IsSimilarTo(PyBobLearnMiscPLDATrainerObject* self, PyObject* args, PyObject* kwds) {
/* Parses input arguments in a single shot */
char** kwlist = is_similar_to.kwlist(0);
//PyObject* other = 0;
PyBobLearnMiscPLDATrainerObject* other = 0;
double r_epsilon = 1.e-5;
double a_epsilon = 1.e-8;
if (!PyArg_ParseTupleAndKeywords(args, kwds, "O!|dd", kwlist,
&PyBobLearnMiscPLDATrainer_Type, &other,
&r_epsilon, &a_epsilon)){
is_similar_to.print_usage();
return 0;
}
if (self->cxx->is_similar_to(*other->cxx, r_epsilon, a_epsilon))
Py_RETURN_TRUE;
else
Py_RETURN_FALSE;
}
static PyMethodDef PyBobLearnMiscPLDATrainer_methods[] = {
{
initialize.name(),
......@@ -420,12 +644,24 @@ static PyMethodDef PyBobLearnMiscPLDATrainer_methods[] = {
METH_VARARGS|METH_KEYWORDS,
m_step.doc()
},
{
finalize.name(),
(PyCFunction)PyBobLearnMiscPLDATrainer_finalize,
METH_VARARGS|METH_KEYWORDS,
finalize.doc()
},
{
enrol.name(),
(PyCFunction)PyBobLearnMiscPLDATrainer_enrol,
METH_VARARGS|METH_KEYWORDS,
enrol.doc()
},
{
is_similar_to.name(),
(PyCFunction)PyBobLearnMiscPLDATrainer_IsSimilarTo,
METH_VARARGS|METH_KEYWORDS,
is_similar_to.doc()
},
{0} /* Sentinel */
};
......
......@@ -123,8 +123,8 @@ class PythonPLDATrainer():
def __init_sigma__(self, machine, data, factor = 1.):
"""As a variance of the data"""
cache1 = numpy.zeros(shape=(machine.dim_d,), dtype=numpy.float64)
cache2 = numpy.zeros(shape=(machine.dim_d,), dtype=numpy.float64)
cache1 = numpy.zeros(shape=(machine.shape[0],), dtype=numpy.float64)
cache2 = numpy.zeros(shape=(machine.shape[0],), dtype=numpy.float64)
n_samples = 0
for v in data:
for j in range(v.shape[0]):
......@@ -145,10 +145,10 @@ class PythonPLDATrainer():
def initialize(self, machine, data):
self.__check_training_data__(data)
n_features = data[0].shape[1]
if(machine.dim_d != n_features):
if(machine.shape[0] != n_features):
raise RuntimeError("Inconsistent feature dimensionality between the machine and the training data set")
self.m_dim_f = machine.dim_f
self.m_dim_g = machine.dim_g
self.m_dim_f = machine.shape[1]
self.m_dim_g = machine.shape[2]
self.__init_members__(data)
# Warning: Default initialization of mu, F, G, sigma using scatters
self.__init_mu_f_g_sigma__(machine, data)
......@@ -237,7 +237,7 @@ class PythonPLDATrainer():
def __update_f_and_g__(self, machine, data):
### Initialise the numerator and the denominator.
dim_d = machine.dim_d
dim_d = machine.shape[0]
accumulated_B_numerator = numpy.zeros((dim_d,self.m_dim_f+self.m_dim_g))
accumulated_B_denominator = numpy.linalg.inv(self.m_sum_z_second_order)
mu = machine.mu
......@@ -263,7 +263,7 @@ class PythonPLDATrainer():
def __update_sigma__(self, machine, data):
### Initialise the accumulated Sigma
dim_d = machine.dim_d
dim_d = machine.shape[0]
mu = machine.mu
accumulated_sigma = numpy.zeros(dim_d) # An array (dim_d)