Skip to content
Snippets Groups Projects
Commit 3552d034 authored by Tiago de Freitas Pereira's avatar Tiago de Freitas Pereira
Browse files

Tests with PLDATrainer

parent d9093bee
No related branches found
No related tags found
No related merge requests found
......@@ -17,6 +17,7 @@ from .__MAP_gmm_trainer__ import *
from .__jfa_trainer__ import *
from .__isv_trainer__ import *
from .__ivector_trainer__ import *
from .__plda_trainer__ import *
def ztnorm_same_value(vect_a, vect_b):
......
#!/usr/bin/env python
# vim: set fileencoding=utf-8 :
# Tiago de Freitas Pereira <tiago.pereira@idiap.ch>
# Mon Fev 02 21:40:10 2015 +0200
#
# Copyright (C) 2011-2015 Idiap Research Institute, Martigny, Switzerland
from ._library import _PLDATrainer
import numpy
# define the class
class PLDATrainer (_PLDATrainer):
def __init__(self, max_iterations=10, use_sum_second_order=True):
"""
:py:class:`bob.learn.misc.PLDATrainer` constructor
Keyword Parameters:
max_iterations
Number of maximum iterations
"""
_PLDATrainer.__init__(self, use_sum_second_order)
self._max_iterations = max_iterations
def train(self, plda_base, data):
"""
Train the :py:class:`bob.learn.misc.PLDABase` using data
Keyword Parameters:
jfa_base
The `:py:class:bob.learn.misc.PLDABase` class
data
The data to be trained
"""
#Initialization
self.initialize(plda_base, data);
for i in range(self._max_iterations):
#eStep
self.eStep(plda_base, data);
#mStep
self.mStep(plda_base);
self.finalize(plda_base);
# copy the documentation from the base class
__doc__ = _PLDATrainer.__doc__
......@@ -81,7 +81,8 @@ static PyObject* create_module (void) {
if (!init_BobLearnMiscIVectorTrainer(module)) return 0;
if (!init_BobLearnMiscPLDABase(module)) return 0;
if (!init_BobLearnMiscPLDAMachine(module)) return 0;
if (!init_BobLearnMiscPLDAMachine(module)) return 0;
if (!init_BobLearnMiscPLDATrainer(module)) return 0;
if (!init_BobLearnMiscEMPCATrainer(module)) return 0;
......
......@@ -33,11 +33,25 @@ int list_as_vector(PyObject* list, std::vector<blitz::Array<double,N> >& vec)
}
template <int N>
static PyObject* vector_as_list(const std::vector<blitz::Array<double,N> >& vec)
{
PyObject* list = PyList_New(vec.size());
for(size_t i=0; i<vec.size(); i++){
blitz::Array<double,N> numpy_array = vec[i];
PyObject* numpy_py_object = PyBlitzArrayCxx_AsNumpy(numpy_array);
PyList_SET_ITEM(list, i, numpy_py_object);
}
return list;
}
static auto PLDATrainer_doc = bob::extension::ClassDoc(
BOB_EXT_MODULE_PREFIX ".PLDATrainer",
"This class can be used to train the :math:`$F$`, :math:`$G$ and "
" :math:`$\\Sigma$` matrices and the mean vector :math:`$\\mu$` of a PLDA model.",
" :math:`$\\Sigma$` matrices and the mean vector :math:`$\\mu$` of a PLDA model."
"References: [ElShafey2014,PrinceElder2007,LiFu2012]",
""
).add_constructor(
bob::extension::FunctionDoc(
"__init__",
......@@ -89,35 +103,29 @@ static int PyBobLearnMiscPLDATrainer_init(PyBobLearnMiscPLDATrainerObject* self,
// get the number of command line arguments
int nargs = (args?PyTuple_Size(args):0) + (kwargs?PyDict_Size(kwargs):0);
switch(nargs){
case 0:{
self->cxx.reset(new bob::learn::misc::PLDATrainer());
return 0;
if(nargs==1){
//Reading the input argument
PyObject* arg = 0;
if (PyTuple_Size(args))
arg = PyTuple_GET_ITEM(args, 0);
else {
PyObject* tmp = PyDict_Values(kwargs);
auto tmp_ = make_safe(tmp);
arg = PyList_GET_ITEM(tmp, 0);
}
case 1:{
//Reading the input argument
PyObject* arg = 0;
if (PyTuple_Size(args))
arg = PyTuple_GET_ITEM(args, 0);
else {
PyObject* tmp = PyDict_Values(kwargs);
auto tmp_ = make_safe(tmp);
arg = PyList_GET_ITEM(tmp, 0);
}
if(PyBobLearnMiscPLDATrainer_Check(arg))
// If the constructor input is PLDATrainer object
return PyBobLearnMiscPLDATrainer_init_copy(self, args, kwargs);
else
return PyBobLearnMiscPLDATrainer_init_bool(self, args, kwargs);
}
default:{
PyErr_Format(PyExc_RuntimeError, "number of arguments mismatch - %s requires only 0 or 1 argument, but you provided %d (see help)", Py_TYPE(self)->tp_name, nargs);
PLDATrainer_doc.print_usage();
return -1;
}
if(PyBobLearnMiscPLDATrainer_Check(arg))
// If the constructor input is PLDATrainer object
return PyBobLearnMiscPLDATrainer_init_copy(self, args, kwargs);
else
return PyBobLearnMiscPLDATrainer_init_bool(self, args, kwargs);
}
else{
PyErr_Format(PyExc_RuntimeError, "number of arguments mismatch - %s requires only 0 or 1 argument, but you provided %d (see help)", Py_TYPE(self)->tp_name, nargs);
PLDATrainer_doc.print_usage();
return -1;
}
BOB_CATCH_MEMBER("cannot create PLDATrainer", 0)
return 0;
}
......@@ -167,7 +175,8 @@ static auto z_second_order = bob::extension::VariableDoc(
);
PyObject* PyBobLearnMiscPLDATrainer_get_z_second_order(PyBobLearnMiscPLDATrainerObject* self, void*){
BOB_TRY
return PyBlitzArrayCxx_AsConstNumpy(self->cxx->getZSecondOrder());
//return PyBlitzArrayCxx_AsConstNumpy(self->cxx->getZSecondOrder());
return vector_as_list(self->cxx->getZSecondOrder());
BOB_CATCH_MEMBER("z_second_order could not be read", 0)
}
......@@ -193,7 +202,8 @@ static auto z_first_order = bob::extension::VariableDoc(
);
PyObject* PyBobLearnMiscPLDATrainer_get_z_first_order(PyBobLearnMiscPLDATrainerObject* self, void*){
BOB_TRY
return PyBlitzArrayCxx_AsConstNumpy(self->cxx->getZFirstOrder());
//return PyBlitzArrayCxx_AsConstNumpy(self->cxx->getZFirstOrder());
return vector_as_list(self->cxx->getZFirstOrder());
BOB_CATCH_MEMBER("z_first_order could not be read", 0)
}
......@@ -255,7 +265,7 @@ static PyObject* PyBobLearnMiscPLDATrainer_initialize(PyBobLearnMiscPLDATrainerO
std::vector<blitz::Array<double,2> > data_vector;
if(list_as_vector(data ,data_vector)==0)
self->cxx->initialize(*plda_machine->cxx, data_vector);
self->cxx->initialize(*plda_base->cxx, data_vector);
BOB_CATCH_MEMBER("cannot perform the initialize method", 0)
......@@ -287,7 +297,7 @@ static PyObject* PyBobLearnMiscPLDATrainer_e_step(PyBobLearnMiscPLDATrainerObjec
std::vector<blitz::Array<double,2> > data_vector;
if(list_as_vector(data ,data_vector)==0)
self->cxx->e_step(*plda_machine->cxx, data_vector);
self->cxx->eStep(*plda_base->cxx, data_vector);
BOB_CATCH_MEMBER("cannot perform the e_step method", 0)
......@@ -319,7 +329,7 @@ static PyObject* PyBobLearnMiscPLDATrainer_m_step(PyBobLearnMiscPLDATrainerObjec
std::vector<blitz::Array<double,2> > data_vector;
if(list_as_vector(data ,data_vector)==0)
self->cxx->m_step(*plda_machine->cxx, data_vector);
self->cxx->mStep(*plda_base->cxx, data_vector);
BOB_CATCH_MEMBER("cannot perform the m_step method", 0)
......@@ -351,7 +361,7 @@ static PyObject* PyBobLearnMiscPLDATrainer_finalize(PyBobLearnMiscPLDATrainerObj
std::vector<blitz::Array<double,2> > data_vector;
if(list_as_vector(data ,data_vector)==0)
self->cxx->finalize(*plda_machine->cxx, data_vector);
self->cxx->finalize(*plda_base->cxx, data_vector);
BOB_CATCH_MEMBER("cannot perform the finalize method", 0)
......@@ -370,7 +380,7 @@ static auto enrol = bob::extension::FunctionDoc(
.add_prototype("plda_machine,data")
.add_parameter("plda_machine", ":py:class:`bob.learn.misc.PLDAMachine`", "PLDAMachine Object")
.add_parameter("data", "list", "");
static PyObject* PyBobLearnMiscPLDATrainer_finalize(PyBobLearnMiscPLDATrainerObject* self, PyObject* args, PyObject* kwargs) {
static PyObject* PyBobLearnMiscPLDATrainer_enrol(PyBobLearnMiscPLDATrainerObject* self, PyObject* args, PyObject* kwargs) {
BOB_TRY
/* Parses input arguments in a single shot */
......
......@@ -77,8 +77,8 @@ class PythonPLDATrainer():
def __init_f__(self, machine, data):
n_ids = len(data)
S = numpy.zeros(shape=(machine.dim_d, n_ids), dtype=numpy.float64)
Si_sum = numpy.zeros(shape=(machine.dim_d,), dtype=numpy.float64)
S = numpy.zeros(shape=(machine.shape[0], n_ids), dtype=numpy.float64)
Si_sum = numpy.zeros(shape=(machine.shape[0],), dtype=numpy.float64)
for i in range(n_ids):
Si = S[:,i]
data_i = data[i]
......@@ -88,7 +88,7 @@ class PythonPLDATrainer():
Si_sum += Si
Si_sum /= n_ids
S = S - numpy.tile(Si_sum.reshape([machine.dim_d,1]), [1,n_ids])
S = S - numpy.tile(Si_sum.reshape([machine.shape[0],1]), [1,n_ids])
U, sigma, S_ = numpy.linalg.svd(S, full_matrices=False)
U_slice = U[:,0:self.m_dim_f]
sigma_slice = sigma[0:self.m_dim_f]
......@@ -99,9 +99,9 @@ class PythonPLDATrainer():
n_samples = 0
for v in data:
n_samples += v.shape[0]
S = numpy.zeros(shape=(machine.dim_d, n_samples), dtype=numpy.float64)
Si_sum = numpy.zeros(shape=(machine.dim_d,), dtype=numpy.float64)
cache = numpy.zeros(shape=(machine.dim_d,), dtype=numpy.float64)
S = numpy.zeros(shape=(machine.shape[0], n_samples), dtype=numpy.float64)
Si_sum = numpy.zeros(shape=(machine.shape[0],), dtype=numpy.float64)
cache = numpy.zeros(shape=(machine.shape[0],), dtype=numpy.float64)
c = 0
for i in range(len(data)):
cache = 0
......@@ -115,7 +115,7 @@ class PythonPLDATrainer():
c += 1
Si_sum /= n_samples
S = S - numpy.tile(Si_sum.reshape([machine.dim_d,1]), [1,n_samples])
S = S - numpy.tile(Si_sum.reshape([machine.shape[0],1]), [1,n_samples])
U, sigma, S_ = numpy.linalg.svd(S, full_matrices=False)
U_slice = U[:,0:self.m_dim_g]
sigma_slice_sqrt = numpy.sqrt(sigma[0:self.m_dim_g])
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment