Commit d9093bee authored by Tiago de Freitas Pereira's avatar Tiago de Freitas Pereira
Browse files

Binding PLDATrainer

parent 115b521b
#!/usr/bin/env python
# vim: set fileencoding=utf-8 :
# Tiago de Freitas Pereira <tiago.pereira@idiap.ch>
# Wed Fev 04 13:35:10 2015 +0200
#
# Copyright (C) 2011-2015 Idiap Research Institute, Martigny, Switzerland
from ._library import _EMPCATrainer
import numpy
# define the class
class EMPCATrainer (_EMPCATrainer):
def __init__(self, convergence_threshold=0.001, max_iterations=10, compute_likelihood=True):
"""
:py:class:`bob.learn.misc.EMPCATrainer` constructor
Keyword Parameters:
convergence_threshold
Convergence threshold
max_iterations
Number of maximum iterations
compute_likelihood
"""
_EMPCATrainer.__init__(self,convergence_threshold)
self._max_iterations = max_iterations
self._compute_likelihood = compute_likelihood
def train(self, linear_machine, data):
"""
Train the :py:class:bob.learn.misc.LinearMachine using data
Keyword Parameters:
linear_machine
The :py:class:bob.learn.misc.LinearMachine class
data
The data to be trained
"""
#Initialization
self.initialize(linear_machine, data);
#Do the Expectation-Maximization algorithm
average_output_previous = 0
average_output = -numpy.inf;
#eStep
self.eStep(linear_machine, data);
if(self._compute_likelihood):
average_output = self.compute_likelihood(linear_machine);
for i in range(self._max_iterations):
#saves average output from last iteration
average_output_previous = average_output;
#mStep
self.mStep(linear_machine);
#eStep
self.eStep(linear_machine, data);
#Computes log likelihood if required
if(self._compute_likelihood):
average_output = self.compute_likelihood(linear_machine);
#Terminates if converged (and likelihood computation is set)
if abs((average_output_previous - average_output)/average_output_previous) <= self._convergence_threshold:
break
# copy the documentation from the base class
__doc__ = _EMPCATrainer.__doc__
......@@ -7,20 +7,26 @@
* Copyright (C) Idiap Research Institute, Martigny, Switzerland
*/
#include <bob.learn.misc/PLDATrainer.h>
#include <bob.core/check.h>
#include <bob.core/array_copy.h>
#include <bob.core/array_random.h>
#include <bob.math/linear.h>
#include <bob.math/inv.h>
#include <bob.math/svd.h>
#include <bob.core/check.h>
#include <bob.core/array_repmat.h>
#include <algorithm>
#include <vector>
#include <limits>
#include <vector>
#include <bob.math/linear.h>
#include <bob.math/linsolve.h>
bob::learn::misc::PLDATrainer::PLDATrainer(const size_t max_iterations,
const bool use_sum_second_order):
EMTrainer<bob::learn::misc::PLDABase, std::vector<blitz::Array<double,2> > >
(0.001, max_iterations, false),
bob::learn::misc::PLDATrainer::PLDATrainer(const bool use_sum_second_order):
m_rng(new boost::mt19937()),
m_dim_d(0), m_dim_f(0), m_dim_g(0),
m_use_sum_second_order(use_sum_second_order),
m_initF_method(bob::learn::misc::PLDATrainer::RANDOM_F), m_initF_ratio(1.),
......@@ -38,9 +44,7 @@ bob::learn::misc::PLDATrainer::PLDATrainer(const size_t max_iterations,
}
bob::learn::misc::PLDATrainer::PLDATrainer(const bob::learn::misc::PLDATrainer& other):
EMTrainer<bob::learn::misc::PLDABase, std::vector<blitz::Array<double,2> > >
(other.m_convergence_threshold, other.m_max_iterations,
other.m_compute_likelihood),
m_rng(other.m_rng),
m_dim_d(other.m_dim_d), m_dim_f(other.m_dim_f), m_dim_g(other.m_dim_g),
m_use_sum_second_order(other.m_use_sum_second_order),
m_initF_method(other.m_initF_method), m_initF_ratio(other.m_initF_ratio),
......@@ -71,8 +75,7 @@ bob::learn::misc::PLDATrainer& bob::learn::misc::PLDATrainer::operator=
{
if(this != &other)
{
bob::learn::misc::EMTrainer<bob::learn::misc::PLDABase,
std::vector<blitz::Array<double,2> > >::operator=(other);
m_rng = m_rng,
m_dim_d = other.m_dim_d;
m_dim_f = other.m_dim_f;
m_dim_g = other.m_dim_g;
......@@ -102,8 +105,7 @@ bob::learn::misc::PLDATrainer& bob::learn::misc::PLDATrainer::operator=
bool bob::learn::misc::PLDATrainer::operator==
(const bob::learn::misc::PLDATrainer& other) const
{
return bob::learn::misc::EMTrainer<bob::learn::misc::PLDABase,
std::vector<blitz::Array<double,2> > >::operator==(other) &&
return m_rng == m_rng &&
m_dim_d == other.m_dim_d &&
m_dim_f == other.m_dim_f &&
m_dim_g == other.m_dim_g &&
......@@ -138,8 +140,7 @@ bool bob::learn::misc::PLDATrainer::is_similar_to
(const bob::learn::misc::PLDATrainer &other, const double r_epsilon,
const double a_epsilon) const
{
return bob::learn::misc::EMTrainer<bob::learn::misc::PLDABase,
std::vector<blitz::Array<double,2> > >::is_similar_to(other, r_epsilon, a_epsilon) &&
return m_rng == m_rng &&
m_dim_d == other.m_dim_d &&
m_dim_f == other.m_dim_f &&
m_dim_g == other.m_dim_g &&
......@@ -745,12 +746,6 @@ void bob::learn::misc::PLDATrainer::updateSigma(bob::learn::misc::PLDABase& mach
machine.applyVarianceThreshold();
}
double bob::learn::misc::PLDATrainer::computeLikelihood(bob::learn::misc::PLDABase& machine)
{
double llh = 0.;
// TODO: implement log likelihood computation
return llh;
}
void bob::learn::misc::PLDATrainer::enrol(bob::learn::misc::PLDAMachine& plda_machine,
const blitz::Array<double,2>& ar) const
......
......@@ -24,11 +24,12 @@ static auto EMPCATrainer_doc = bob::extension::ClassDoc(
"",
true
)
.add_prototype("compute_likelihood","")
.add_prototype("convergence_threshold","")
.add_prototype("other","")
.add_prototype("","")
.add_parameter("other", ":py:class:`bob.learn.misc.EMPCATrainer`", "A EMPCATrainer object to be copied.")
.add_parameter("convergence_threshold", "double", "")
);
......
......@@ -11,11 +11,13 @@
#ifndef BOB_LEARN_MISC_PLDA_TRAINER_H
#define BOB_LEARN_MISC_PLDA_TRAINER_H
#include <bob.learn.misc/EMTrainer.h>
#include <bob.learn.misc/PLDAMachine.h>
#include <blitz/array.h>
#include <map>
#include <boost/shared_ptr.hpp>
#include <vector>
#include <map>
#include <bob.core/array_copy.h>
#include <boost/random.hpp>
#include <boost/random/mersenne_twister.hpp>
namespace bob { namespace learn { namespace misc {
......@@ -31,8 +33,7 @@ namespace bob { namespace learn { namespace misc {
* 3. 'Probabilistic Models for Inference about Identity', Li, Fu, Mohammed,
* Elder and Prince, TPAMI'2012
*/
class PLDATrainer: public EMTrainer<bob::learn::misc::PLDABase,
std::vector<blitz::Array<double,2> > >
class PLDATrainer
{
public: //api
/**
......@@ -40,7 +41,7 @@ class PLDATrainer: public EMTrainer<bob::learn::misc::PLDABase,
* training stage will place the resulting components in the
* PLDABase.
*/
PLDATrainer(const size_t max_iterations=100, const bool use_sum_second_order=true);
PLDATrainer(const bool use_sum_second_order);
/**
* @brief Copy constructor
......@@ -70,18 +71,18 @@ class PLDATrainer: public EMTrainer<bob::learn::misc::PLDABase,
/**
* @brief Similarity operator
*/
virtual bool is_similar_to(const PLDATrainer& b,
bool is_similar_to(const PLDATrainer& b,
const double r_epsilon=1e-5, const double a_epsilon=1e-8) const;
/**
* @brief Performs some initialization before the E- and M-steps.
*/
virtual void initialize(bob::learn::misc::PLDABase& machine,
void initialize(bob::learn::misc::PLDABase& machine,
const std::vector<blitz::Array<double,2> >& v_ar);
/**
* @brief Performs some actions after the end of the E- and M-steps.
*/
virtual void finalize(bob::learn::misc::PLDABase& machine,
void finalize(bob::learn::misc::PLDABase& machine,
const std::vector<blitz::Array<double,2> >& v_ar);
/**
......@@ -89,21 +90,16 @@ class PLDATrainer: public EMTrainer<bob::learn::misc::PLDABase,
* these as m_z_{first,second}_order.
* The statistics will be used in the mStep() that follows.
*/
virtual void eStep(bob::learn::misc::PLDABase& machine,
void eStep(bob::learn::misc::PLDABase& machine,
const std::vector<blitz::Array<double,2> >& v_ar);
/**
* @brief Performs a maximization step to update the parameters of the
* PLDABase
*/
virtual void mStep(bob::learn::misc::PLDABase& machine,
void mStep(bob::learn::misc::PLDABase& machine,
const std::vector<blitz::Array<double,2> >& v_ar);
/**
* @brief Computes the average log likelihood using the current estimates
* of the latent variables.
*/
virtual double computeLikelihood(bob::learn::misc::PLDABase& machine);
/**
* @brief Sets whether the second order statistics are stored during the
......@@ -223,6 +219,9 @@ class PLDATrainer: public EMTrainer<bob::learn::misc::PLDABase,
const blitz::Array<double,2>& ar) const;
private:
boost::shared_ptr<boost::mt19937> m_rng;
//representation
size_t m_dim_d; ///< Dimensionality of the input features
size_t m_dim_f; ///< Size/rank of the \f$F\f$ subspace
......
......@@ -119,6 +119,7 @@ static PyObject* create_module (void) {
if (import_bob_blitz() < 0) return 0;
if (import_bob_core_random() < 0) return 0;
if (import_bob_io_base() < 0) return 0;
//if (import_bob_learn_linear() < 0) return 0;
Py_INCREF(module);
return module;
......
......@@ -12,7 +12,9 @@
#include <bob.blitz/cleanup.h>
#include <bob.core/random_api.h>
#include <bob.io.base/api.h>
#include <bob.learn.linear/api.h>
#include <bob.extension/documentation.h>
#define BOB_LEARN_EM_MODULE
......@@ -43,6 +45,8 @@
#include <bob.learn.misc/EMPCATrainer.h>
#include <bob.learn.misc/PLDAMachine.h>
#include <bob.learn.misc/PLDATrainer.h>
#include <bob.learn.misc/ZTNorm.h>
......@@ -282,6 +286,18 @@ bool init_BobLearnMiscPLDAMachine(PyObject* module);
int PyBobLearnMiscPLDAMachine_Check(PyObject* o);
// PLDATrainer
typedef struct {
PyObject_HEAD
boost::shared_ptr<bob::learn::misc::PLDATrainer> cxx;
} PyBobLearnMiscPLDATrainerObject;
extern PyTypeObject PyBobLearnMiscPLDATrainer_Type;
bool init_BobLearnMiscPLDATrainer(PyObject* module);
int PyBobLearnMiscPLDATrainer_Check(PyObject* o);
// EMPCATrainer
typedef struct {
PyObject_HEAD
......
/**
* @author Tiago de Freitas Pereira <tiago.pereira@idiap.ch>
* @date Wed 04 Feb 14:15:00 2015
*
* @brief Python API for bob::learn::em
*
* Copyright (C) 2011-2014 Idiap Research Institute, Martigny, Switzerland
*/
#include "main.h"
#include <boost/make_shared.hpp>
/******************************************************************/
/************ Constructor Section *********************************/
/******************************************************************/
static inline bool f(PyObject* o){return o != 0 && PyObject_IsTrue(o) > 0;} /* converts PyObject to bool and returns false if object is NULL */
template <int N>
int list_as_vector(PyObject* list, std::vector<blitz::Array<double,N> >& vec)
{
for (int i=0; i<PyList_GET_SIZE(list); i++)
{
PyBlitzArrayObject* blitz_object;
if (!PyArg_Parse(PyList_GetItem(list, i), "O&", &PyBlitzArray_Converter, &blitz_object)){
PyErr_Format(PyExc_RuntimeError, "Expected numpy array object");
return -1;
}
auto blitz_object_ = make_safe(blitz_object);
vec.push_back(*PyBlitzArrayCxx_AsBlitz<double,N>(blitz_object));
}
return 0;
}
static auto PLDATrainer_doc = bob::extension::ClassDoc(
BOB_EXT_MODULE_PREFIX ".PLDATrainer",
"This class can be used to train the :math:`$F$`, :math:`$G$ and "
" :math:`$\\Sigma$` matrices and the mean vector :math:`$\\mu$` of a PLDA model.",
"References: [ElShafey2014,PrinceElder2007,LiFu2012]",
).add_constructor(
bob::extension::FunctionDoc(
"__init__",
"Default constructor.\n Initializes a new PLDA trainer. The "
"training stage will place the resulting components in the "
"PLDABase.",
"",
true
)
.add_prototype("use_sum_second_order","")
.add_prototype("other","")
.add_prototype("","")
.add_parameter("other", ":py:class:`bob.learn.misc.PLDATrainer`", "A PLDATrainer object to be copied.")
.add_parameter("use_sum_second_order", "bool", "")
);
static int PyBobLearnMiscPLDATrainer_init_copy(PyBobLearnMiscPLDATrainerObject* self, PyObject* args, PyObject* kwargs) {
char** kwlist = PLDATrainer_doc.kwlist(1);
PyBobLearnMiscPLDATrainerObject* o;
if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O!", kwlist, &PyBobLearnMiscPLDATrainer_Type, &o)){
PLDATrainer_doc.print_usage();
return -1;
}
self->cxx.reset(new bob::learn::misc::PLDATrainer(*o->cxx));
return 0;
}
static int PyBobLearnMiscPLDATrainer_init_bool(PyBobLearnMiscPLDATrainerObject* self, PyObject* args, PyObject* kwargs) {
char** kwlist = PLDATrainer_doc.kwlist(0);
PyObject* use_sum_second_order;
//Parsing the input argments
if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O!", kwlist, &PyBool_Type, &use_sum_second_order))
return -1;
self->cxx.reset(new bob::learn::misc::PLDATrainer(f(use_sum_second_order)));
return 0;
}
static int PyBobLearnMiscPLDATrainer_init(PyBobLearnMiscPLDATrainerObject* self, PyObject* args, PyObject* kwargs) {
BOB_TRY
// get the number of command line arguments
int nargs = (args?PyTuple_Size(args):0) + (kwargs?PyDict_Size(kwargs):0);
switch(nargs){
case 0:{
self->cxx.reset(new bob::learn::misc::PLDATrainer());
return 0;
}
case 1:{
//Reading the input argument
PyObject* arg = 0;
if (PyTuple_Size(args))
arg = PyTuple_GET_ITEM(args, 0);
else {
PyObject* tmp = PyDict_Values(kwargs);
auto tmp_ = make_safe(tmp);
arg = PyList_GET_ITEM(tmp, 0);
}
if(PyBobLearnMiscPLDATrainer_Check(arg))
// If the constructor input is PLDATrainer object
return PyBobLearnMiscPLDATrainer_init_copy(self, args, kwargs);
else
return PyBobLearnMiscPLDATrainer_init_bool(self, args, kwargs);
}
default:{
PyErr_Format(PyExc_RuntimeError, "number of arguments mismatch - %s requires only 0 or 1 argument, but you provided %d (see help)", Py_TYPE(self)->tp_name, nargs);
PLDATrainer_doc.print_usage();
return -1;
}
}
BOB_CATCH_MEMBER("cannot create PLDATrainer", 0)
return 0;
}
static void PyBobLearnMiscPLDATrainer_delete(PyBobLearnMiscPLDATrainerObject* self) {
self->cxx.reset();
Py_TYPE(self)->tp_free((PyObject*)self);
}
int PyBobLearnMiscPLDATrainer_Check(PyObject* o) {
return PyObject_IsInstance(o, reinterpret_cast<PyObject*>(&PyBobLearnMiscPLDATrainer_Type));
}
static PyObject* PyBobLearnMiscPLDATrainer_RichCompare(PyBobLearnMiscPLDATrainerObject* self, PyObject* other, int op) {
BOB_TRY
if (!PyBobLearnMiscPLDATrainer_Check(other)) {
PyErr_Format(PyExc_TypeError, "cannot compare `%s' with `%s'", Py_TYPE(self)->tp_name, Py_TYPE(other)->tp_name);
return 0;
}
auto other_ = reinterpret_cast<PyBobLearnMiscPLDATrainerObject*>(other);
switch (op) {
case Py_EQ:
if (*self->cxx==*other_->cxx) Py_RETURN_TRUE; else Py_RETURN_FALSE;
case Py_NE:
if (*self->cxx==*other_->cxx) Py_RETURN_FALSE; else Py_RETURN_TRUE;
default:
Py_INCREF(Py_NotImplemented);
return Py_NotImplemented;
}
BOB_CATCH_MEMBER("cannot compare PLDATrainer objects", 0)
}
/******************************************************************/
/************ Variables Section ***********************************/
/******************************************************************/
static auto z_second_order = bob::extension::VariableDoc(
"z_second_order",
"array_like <float, 3D>",
"",
""
);
PyObject* PyBobLearnMiscPLDATrainer_get_z_second_order(PyBobLearnMiscPLDATrainerObject* self, void*){
BOB_TRY
return PyBlitzArrayCxx_AsConstNumpy(self->cxx->getZSecondOrder());
BOB_CATCH_MEMBER("z_second_order could not be read", 0)
}
static auto z_second_order_sum = bob::extension::VariableDoc(
"z_second_order_sum",
"array_like <float, 2D>",
"",
""
);
PyObject* PyBobLearnMiscPLDATrainer_get_z_second_order_sum(PyBobLearnMiscPLDATrainerObject* self, void*){
BOB_TRY
return PyBlitzArrayCxx_AsConstNumpy(self->cxx->getZSecondOrderSum());
BOB_CATCH_MEMBER("z_second_order_sum could not be read", 0)
}
static auto z_first_order = bob::extension::VariableDoc(
"z_first_order",
"array_like <float, 2D>",
"",
""
);
PyObject* PyBobLearnMiscPLDATrainer_get_z_first_order(PyBobLearnMiscPLDATrainerObject* self, void*){
BOB_TRY
return PyBlitzArrayCxx_AsConstNumpy(self->cxx->getZFirstOrder());
BOB_CATCH_MEMBER("z_first_order could not be read", 0)
}
static PyGetSetDef PyBobLearnMiscPLDATrainer_getseters[] = {
{
z_first_order.name(),
(getter)PyBobLearnMiscPLDATrainer_get_z_first_order,
0,
z_first_order.doc(),
0
},
{
z_second_order_sum.name(),
(getter)PyBobLearnMiscPLDATrainer_get_z_second_order_sum,
0,
z_second_order_sum.doc(),
0
},
{
z_second_order.name(),
(getter)PyBobLearnMiscPLDATrainer_get_z_second_order,
0,
z_second_order.doc(),
0
},
{0} // Sentinel
};
/******************************************************************/
/************ Functions Section ***********************************/
/******************************************************************/
/*** initialize ***/
static auto initialize = bob::extension::FunctionDoc(
"initialize",
"Initialization before the EM steps",
"",
true
)
.add_prototype("plda_base,data")
.add_parameter("plda_base", ":py:class:`bob.learn.misc.PLDABase`", "PLDAMachine Object")
.add_parameter("data", "list", "");
static PyObject* PyBobLearnMiscPLDATrainer_initialize(PyBobLearnMiscPLDATrainerObject* self, PyObject* args, PyObject* kwargs) {
BOB_TRY
/* Parses input arguments in a single shot */
char** kwlist = initialize.kwlist(0);
PyBobLearnMiscPLDABaseObject* plda_base = 0;
PyObject* data = 0;
if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O!O!", kwlist, &PyBobLearnMiscPLDABase_Type, &plda_base,
&PyList_Type, &data)) Py_RETURN_NONE;
std::vector<blitz::Array<double,2> > data_vector;
if(list_as_vector(data ,data_vector)==0)
self->cxx->initialize(*plda_machine->cxx, data_vector);
BOB_CATCH_MEMBER("cannot perform the initialize method", 0)
Py_RETURN_NONE;
}
/*** e_step ***/
static auto e_step = bob::extension::FunctionDoc(
"e_step",
"e_step before the EM steps",
"",
true
)
.add_prototype("plda_base,data")
.add_parameter("plda_base", ":py:class:`bob.learn.misc.PLDABase`", "PLDAMachine Object")
.add_parameter("data", "list", "");
static PyObject* PyBobLearnMiscPLDATrainer_e_step(PyBobLearnMiscPLDATrainerObject* self, PyObject* args, PyObject* kwargs) {
BOB_TRY
/* Parses input arguments in a single shot */
char** kwlist = e_step.kwlist(0);
PyBobLearnMiscPLDABaseObject* plda_base = 0;
PyObject* data = 0;
if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O!O!", kwlist, &PyBobLearnMiscPLDABase_Type, &plda_base,
&PyList_Type, &data)) Py_RETURN_NONE;
std::vector<blitz::Array<double,2> > data_vector;
if(list_as_vector(data ,data_vector)==0)
self->cxx->e_step(*plda_machine->cxx, data_vector);
BOB_CATCH_MEMBER("cannot perform the e_step method", 0)
Py_RETURN_NONE;
}
/*** m_step ***/
static auto m_step = bob::