Commit 1b855768 authored by Tiago de Freitas Pereira's avatar Tiago de Freitas Pereira
Browse files

Binded ISVTrainer

parent a9c9d87e
......@@ -15,6 +15,7 @@ from .__kmeans_trainer__ import *
from .__ML_gmm_trainer__ import *
from .__MAP_gmm_trainer__ import *
from .__jfa_trainer__ import *
from .__isv_trainer__ import *
def ztnorm_same_value(vect_a, vect_b):
......
#!/usr/bin/env python
# vim: set fileencoding=utf-8 :
# Tiago de Freitas Pereira <tiago.pereira@idiap.ch>
# Mon Fev 02 21:40:10 2015 +0200
#
# Copyright (C) 2011-2015 Idiap Research Institute, Martigny, Switzerland
from ._library import _ISVTrainer
import numpy
# define the class
class ISVTrainer (_ISVTrainer):
def __init__(self, max_iterations=10, relevance_factor=4., convergence_threshold = 0.001):
"""
:py:class:`bob.learn.misc.ISVTrainer` constructor
Keyword Parameters:
max_iterations
Number of maximum iterations
"""
_ISVTrainer.__init__(self, relevance_factor, convergence_threshold)
self._max_iterations = max_iterations
def train(self, isv_base, data):
"""
Train the :py:class:`bob.learn.misc.ISVBase` using data
Keyword Parameters:
jfa_base
The `:py:class:bob.learn.misc.ISVBase` class
data
The data to be trained
"""
#Initialization
self.initialize(isv_base, data);
for i in range(self._max_iterations):
#eStep
self.eStep(isv_base, data);
#mStep
self.mStep(isv_base);
# copy the documentation from the base class
__doc__ = _ISVTrainer.__doc__
......@@ -19,42 +19,38 @@
//////////////////////////// ISVTrainer ///////////////////////////
bob::learn::misc::ISVTrainer::ISVTrainer(const size_t max_iterations, const double relevance_factor):
EMTrainer<bob::learn::misc::ISVBase, std::vector<std::vector<boost::shared_ptr<bob::learn::misc::GMMStats> > > >
(0.001, max_iterations, false),
m_relevance_factor(relevance_factor)
{
}
bob::learn::misc::ISVTrainer::ISVTrainer(const double relevance_factor, const double convergence_threshold):
m_relevance_factor(relevance_factor),
m_convergence_threshold(convergence_threshold),
m_rng(new boost::mt19937())
{}
bob::learn::misc::ISVTrainer::ISVTrainer(const bob::learn::misc::ISVTrainer& other):
EMTrainer<bob::learn::misc::ISVBase, std::vector<std::vector<boost::shared_ptr<bob::learn::misc::GMMStats> > > >
(other.m_convergence_threshold, other.m_max_iterations,
other.m_compute_likelihood),
m_relevance_factor(other.m_relevance_factor)
{
}
m_convergence_threshold(other.m_convergence_threshold),
m_relevance_factor(other.m_relevance_factor),
m_rng(other.m_rng)
{}
bob::learn::misc::ISVTrainer::~ISVTrainer()
{
}
{}
bob::learn::misc::ISVTrainer& bob::learn::misc::ISVTrainer::operator=
(const bob::learn::misc::ISVTrainer& other)
{
if (this != &other)
{
bob::learn::misc::EMTrainer<bob::learn::misc::ISVBase,
std::vector<std::vector<boost::shared_ptr<bob::learn::misc::GMMStats> > > >::operator=(other);
m_relevance_factor = other.m_relevance_factor;
m_convergence_threshold = other.m_convergence_threshold;
m_rng = other.m_rng;
m_relevance_factor = other.m_relevance_factor;
}
return *this;
}
bool bob::learn::misc::ISVTrainer::operator==(const bob::learn::misc::ISVTrainer& b) const
{
return bob::learn::misc::EMTrainer<bob::learn::misc::ISVBase,
std::vector<std::vector<boost::shared_ptr<bob::learn::misc::GMMStats> > > >::operator==(b) &&
m_relevance_factor == b.m_relevance_factor;
return m_convergence_threshold == b.m_convergence_threshold &&
m_rng == b.m_rng &&
m_relevance_factor == b.m_relevance_factor;
}
bool bob::learn::misc::ISVTrainer::operator!=(const bob::learn::misc::ISVTrainer& b) const
......@@ -65,9 +61,9 @@ bool bob::learn::misc::ISVTrainer::operator!=(const bob::learn::misc::ISVTrainer
bool bob::learn::misc::ISVTrainer::is_similar_to(const bob::learn::misc::ISVTrainer& b,
const double r_epsilon, const double a_epsilon) const
{
return bob::learn::misc::EMTrainer<bob::learn::misc::ISVBase,
std::vector<std::vector<boost::shared_ptr<bob::learn::misc::GMMStats> > > >::is_similar_to(b, r_epsilon, a_epsilon) &&
m_relevance_factor == b.m_relevance_factor;
return m_convergence_threshold == b.m_convergence_threshold &&
m_rng == b.m_rng &&
m_relevance_factor == b.m_relevance_factor;
}
void bob::learn::misc::ISVTrainer::initialize(bob::learn::misc::ISVBase& machine,
......@@ -89,11 +85,6 @@ void bob::learn::misc::ISVTrainer::initializeD(bob::learn::misc::ISVBase& machin
d = sqrt(machine.getBase().getUbmVariance() / m_relevance_factor);
}
void bob::learn::misc::ISVTrainer::finalize(bob::learn::misc::ISVBase& machine,
const std::vector<std::vector<boost::shared_ptr<bob::learn::misc::GMMStats> > >& ar)
{
}
void bob::learn::misc::ISVTrainer::eStep(bob::learn::misc::ISVBase& machine,
const std::vector<std::vector<boost::shared_ptr<bob::learn::misc::GMMStats> > >& ar)
{
......@@ -105,8 +96,7 @@ void bob::learn::misc::ISVTrainer::eStep(bob::learn::misc::ISVBase& machine,
m_base_trainer.computeAccumulatorsU(base, ar);
}
void bob::learn::misc::ISVTrainer::mStep(bob::learn::misc::ISVBase& machine,
const std::vector<std::vector<boost::shared_ptr<bob::learn::misc::GMMStats> > >& ar)
void bob::learn::misc::ISVTrainer::mStep(bob::learn::misc::ISVBase& machine)
{
blitz::Array<double,2>& U = machine.updateU();
m_base_trainer.updateU(U);
......
......@@ -19,7 +19,7 @@
//////////////////////////// JFATrainer ///////////////////////////
bob::learn::misc::JFATrainer::JFATrainer(const size_t max_iterations):
bob::learn::misc::JFATrainer::JFATrainer():
m_rng(new boost::mt19937())
{}
......
......@@ -25,14 +25,13 @@
namespace bob { namespace learn { namespace misc {
class ISVTrainer
{
public:
/**
* @brief Constructor
*/
ISVTrainer(const size_t max_iterations=10, const double relevance_factor=4.);
ISVTrainer(const double relevance_factor=4., const double convergence_threshold = 0.001);
/**
* @brief Copy onstructor
......@@ -70,11 +69,6 @@ class ISVTrainer
*/
virtual void initialize(bob::learn::misc::ISVBase& machine,
const std::vector<std::vector<boost::shared_ptr<bob::learn::misc::GMMStats> > >& ar);
/**
* @brief This methods performs some actions after the EM loop.
*/
virtual void finalize(bob::learn::misc::ISVBase& machine,
const std::vector<std::vector<boost::shared_ptr<bob::learn::misc::GMMStats> > >& ar);
/**
* @brief Calculates and saves statistics across the dataset
......@@ -87,8 +81,7 @@ class ISVTrainer
* @brief Performs a maximization step to update the parameters of the
* factor analysis model.
*/
virtual void mStep(bob::learn::misc::ISVBase& machine,
const std::vector<std::vector<boost::shared_ptr<bob::learn::misc::GMMStats> > >& ar);
virtual void mStep(bob::learn::misc::ISVBase& machine);
/**
* @brief Computes the average log likelihood using the current estimates
......@@ -150,7 +143,12 @@ class ISVTrainer
// Attributes
bob::learn::misc::FABaseTrainer m_base_trainer;
double m_relevance_factor;
double m_convergence_threshold; ///< convergence threshold
boost::shared_ptr<boost::mt19937> m_rng; ///< The random number generator for the inialization};
};
} } } // namespaces
......
......@@ -31,7 +31,7 @@ class JFATrainer
/**
* @brief Constructor
*/
JFATrainer(const size_t max_iterations=10);
JFATrainer();
/**
* @brief Copy onstructor
......
/**
* @author Tiago de Freitas Pereira <tiago.pereira@idiap.ch>
* @date Mon 02 Fev 20:20:00 2015
*
* @brief Python API for bob::learn::em
*
* Copyright (C) 2011-2014 Idiap Research Institute, Martigny, Switzerland
*/
#include "main.h"
#include <boost/make_shared.hpp>
/******************************************************************/
/************ Constructor Section *********************************/
/******************************************************************/
static int extract_GMMStats_1d(PyObject *list,
std::vector<boost::shared_ptr<bob::learn::misc::GMMStats> >& training_data)
{
for (int i=0; i<PyList_GET_SIZE(list); i++){
PyBobLearnMiscGMMStatsObject* stats;
if (!PyArg_Parse(PyList_GetItem(list, i), "O!", &PyBobLearnMiscGMMStats_Type, &stats)){
PyErr_Format(PyExc_RuntimeError, "Expected GMMStats objects");
return -1;
}
training_data.push_back(stats->cxx);
}
return 0;
}
static int extract_GMMStats_2d(PyObject *list,
std::vector<std::vector<boost::shared_ptr<bob::learn::misc::GMMStats> > >& training_data)
{
for (int i=0; i<PyList_GET_SIZE(list); i++)
{
PyObject* another_list;
PyArg_Parse(PyList_GetItem(list, i), "O!", &PyList_Type, &another_list);
std::vector<boost::shared_ptr<bob::learn::misc::GMMStats> > another_training_data;
for (int j=0; j<PyList_GET_SIZE(another_list); j++){
PyBobLearnMiscGMMStatsObject* stats;
if (!PyArg_Parse(PyList_GetItem(another_list, j), "O!", &PyBobLearnMiscGMMStats_Type, &stats)){
PyErr_Format(PyExc_RuntimeError, "Expected GMMStats objects");
return -1;
}
another_training_data.push_back(stats->cxx);
}
training_data.push_back(another_training_data);
}
return 0;
}
template <int N>
static PyObject* vector_as_list(const std::vector<blitz::Array<double,N> >& vec)
{
PyObject* list = PyList_New(vec.size());
for(size_t i=0; i<vec.size(); i++){
blitz::Array<double,N> numpy_array = vec[i];
PyObject* numpy_py_object = PyBlitzArrayCxx_AsNumpy(numpy_array);
PyList_SET_ITEM(list, i, numpy_py_object);
}
return list;
}
template <int N>
int list_as_vector(PyObject* list, std::vector<blitz::Array<double,N> >& vec)
{
for (int i=0; i<PyList_GET_SIZE(list); i++)
{
PyBlitzArrayObject* blitz_object;
if (!PyArg_Parse(PyList_GetItem(list, i), "O&", &PyBlitzArray_Converter, &blitz_object)){
PyErr_Format(PyExc_RuntimeError, "Expected numpy array object");
return -1;
}
auto blitz_object_ = make_safe(blitz_object);
vec.push_back(*PyBlitzArrayCxx_AsBlitz<double,N>(blitz_object));
}
return 0;
}
static auto ISVTrainer_doc = bob::extension::ClassDoc(
BOB_EXT_MODULE_PREFIX ".ISVTrainer",
"ISVTrainer"
"References: [Vogt2008,McCool2013]",
""
).add_constructor(
bob::extension::FunctionDoc(
"__init__",
"Constructor. Builds a new ISVTrainer",
"",
true
)
.add_prototype("relevance_factor,convergence_threshold","")
.add_prototype("other","")
.add_prototype("","")
.add_parameter("other", ":py:class:`bob.learn.misc.ISVTrainer`", "A ISVTrainer object to be copied.")
.add_parameter("relevance_factor", "double", "")
.add_parameter("convergence_threshold", "double", "")
);
static int PyBobLearnMiscISVTrainer_init_copy(PyBobLearnMiscISVTrainerObject* self, PyObject* args, PyObject* kwargs) {
char** kwlist = ISVTrainer_doc.kwlist(1);
PyBobLearnMiscISVTrainerObject* o;
if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O!", kwlist, &PyBobLearnMiscISVTrainer_Type, &o)){
ISVTrainer_doc.print_usage();
return -1;
}
self->cxx.reset(new bob::learn::misc::ISVTrainer(*o->cxx));
return 0;
}
static int PyBobLearnMiscISVTrainer_init_number(PyBobLearnMiscISVTrainerObject* self, PyObject* args, PyObject* kwargs) {
char** kwlist = ISVTrainer_doc.kwlist(0);
double relevance_factor = 4.;
double convergence_threshold = 0.001;
//Parsing the input argments
if (!PyArg_ParseTupleAndKeywords(args, kwargs, "dd", kwlist, &relevance_factor, &convergence_threshold))
return -1;
if(relevance_factor < 0){
PyErr_Format(PyExc_TypeError, "gaussians argument must be greater than zero");
return -1;
}
if(convergence_threshold < 0){
PyErr_Format(PyExc_TypeError, "convergence_threshold argument must be greater than zero");
return -1;
}
self->cxx.reset(new bob::learn::misc::ISVTrainer(relevance_factor, convergence_threshold));
return 0;
}
static int PyBobLearnMiscISVTrainer_init(PyBobLearnMiscISVTrainerObject* self, PyObject* args, PyObject* kwargs) {
BOB_TRY
// get the number of command line arguments
int nargs = (args?PyTuple_Size(args):0) + (kwargs?PyDict_Size(kwargs):0);
switch(nargs){
case 0:{
self->cxx.reset(new bob::learn::misc::ISVTrainer());
return 0;
}
case 1:{
// If the constructor input is ISVTrainer object
return PyBobLearnMiscISVTrainer_init_copy(self, args, kwargs);
}
case 2:{
// If the constructor input is ISVTrainer object
return PyBobLearnMiscISVTrainer_init_number(self, args, kwargs);
}
default:{
PyErr_Format(PyExc_RuntimeError, "number of arguments mismatch - %s requires only 0, 1 or 2 arguments, but you provided %d (see help)", Py_TYPE(self)->tp_name, nargs);
ISVTrainer_doc.print_usage();
return -1;
}
}
BOB_CATCH_MEMBER("cannot create ISVTrainer", 0)
return 0;
}
static void PyBobLearnMiscISVTrainer_delete(PyBobLearnMiscISVTrainerObject* self) {
self->cxx.reset();
Py_TYPE(self)->tp_free((PyObject*)self);
}
int PyBobLearnMiscISVTrainer_Check(PyObject* o) {
return PyObject_IsInstance(o, reinterpret_cast<PyObject*>(&PyBobLearnMiscISVTrainer_Type));
}
static PyObject* PyBobLearnMiscISVTrainer_RichCompare(PyBobLearnMiscISVTrainerObject* self, PyObject* other, int op) {
BOB_TRY
if (!PyBobLearnMiscISVTrainer_Check(other)) {
PyErr_Format(PyExc_TypeError, "cannot compare `%s' with `%s'", Py_TYPE(self)->tp_name, Py_TYPE(other)->tp_name);
return 0;
}
auto other_ = reinterpret_cast<PyBobLearnMiscISVTrainerObject*>(other);
switch (op) {
case Py_EQ:
if (*self->cxx==*other_->cxx) Py_RETURN_TRUE; else Py_RETURN_FALSE;
case Py_NE:
if (*self->cxx==*other_->cxx) Py_RETURN_FALSE; else Py_RETURN_TRUE;
default:
Py_INCREF(Py_NotImplemented);
return Py_NotImplemented;
}
BOB_CATCH_MEMBER("cannot compare ISVTrainer objects", 0)
}
/******************************************************************/
/************ Variables Section ***********************************/
/******************************************************************/
static auto acc_u_a1 = bob::extension::VariableDoc(
"acc_u_a1",
"array_like <float, 3D>",
"Accumulator updated during the E-step",
""
);
PyObject* PyBobLearnMiscISVTrainer_get_acc_u_a1(PyBobLearnMiscISVTrainerObject* self, void*){
BOB_TRY
return PyBlitzArrayCxx_AsConstNumpy(self->cxx->getAccUA1());
BOB_CATCH_MEMBER("acc_u_a1 could not be read", 0)
}
int PyBobLearnMiscISVTrainer_set_acc_u_a1(PyBobLearnMiscISVTrainerObject* self, PyObject* value, void*){
BOB_TRY
PyBlitzArrayObject* o;
if (!PyBlitzArray_Converter(value, &o)){
PyErr_Format(PyExc_RuntimeError, "%s %s expects a 3D array of floats", Py_TYPE(self)->tp_name, acc_u_a1.name());
return -1;
}
auto o_ = make_safe(o);
auto b = PyBlitzArrayCxx_AsBlitz<double,3>(o, "acc_u_a1");
if (!b) return -1;
self->cxx->setAccUA1(*b);
return 0;
BOB_CATCH_MEMBER("acc_u_a1 could not be set", -1)
}
static auto acc_u_a2 = bob::extension::VariableDoc(
"acc_u_a2",
"array_like <float, 2D>",
"Accumulator updated during the E-step",
""
);
PyObject* PyBobLearnMiscISVTrainer_get_acc_u_a2(PyBobLearnMiscISVTrainerObject* self, void*){
BOB_TRY
return PyBlitzArrayCxx_AsConstNumpy(self->cxx->getAccUA2());
BOB_CATCH_MEMBER("acc_u_a2 could not be read", 0)
}
int PyBobLearnMiscISVTrainer_set_acc_u_a2(PyBobLearnMiscISVTrainerObject* self, PyObject* value, void*){
BOB_TRY
PyBlitzArrayObject* o;
if (!PyBlitzArray_Converter(value, &o)){
PyErr_Format(PyExc_RuntimeError, "%s %s expects a 2D array of floats", Py_TYPE(self)->tp_name, acc_u_a2.name());
return -1;
}
auto o_ = make_safe(o);
auto b = PyBlitzArrayCxx_AsBlitz<double,2>(o, "acc_u_a2");
if (!b) return -1;
self->cxx->setAccUA2(*b);
return 0;
BOB_CATCH_MEMBER("acc_u_a2 could not be set", -1)
}
static auto __X__ = bob::extension::VariableDoc(
"__X__",
"list",
"",
""
);
PyObject* PyBobLearnMiscISVTrainer_get_X(PyBobLearnMiscISVTrainerObject* self, void*){
BOB_TRY
return vector_as_list(self->cxx->getX());
BOB_CATCH_MEMBER("__X__ could not be read", 0)
}
int PyBobLearnMiscISVTrainer_set_X(PyBobLearnMiscISVTrainerObject* self, PyObject* value, void*){
BOB_TRY
// Parses input arguments in a single shot
if (!PyList_Check(value)){
PyErr_Format(PyExc_TypeError, "Expected a list in `%s'", __X__.name());
return -1;
}
std::vector<blitz::Array<double,2> > data;
if(list_as_vector(value ,data)==0){
self->cxx->setX(data);
}
return 0;
BOB_CATCH_MEMBER("__X__ could not be written", 0)
}
static auto __Z__ = bob::extension::VariableDoc(
"__Z__",
"list",
"",
""
);
PyObject* PyBobLearnMiscISVTrainer_get_Z(PyBobLearnMiscISVTrainerObject* self, void*){
BOB_TRY
return vector_as_list(self->cxx->getZ());
BOB_CATCH_MEMBER("__Z__ could not be read", 0)
}
int PyBobLearnMiscISVTrainer_set_Z(PyBobLearnMiscISVTrainerObject* self, PyObject* value, void*){
BOB_TRY
// Parses input arguments in a single shot
if (!PyList_Check(value)){
PyErr_Format(PyExc_TypeError, "Expected a list in `%s'", __Z__.name());
return -1;
}
std::vector<blitz::Array<double,1> > data;
if(list_as_vector(value ,data)==0){
self->cxx->setZ(data);
}
return 0;
BOB_CATCH_MEMBER("__Z__ could not be written", 0)
}
static PyGetSetDef PyBobLearnMiscISVTrainer_getseters[] = {
{
acc_u_a1.name(),
(getter)PyBobLearnMiscISVTrainer_get_acc_u_a1,
(setter)PyBobLearnMiscISVTrainer_get_acc_u_a1,
acc_u_a1.doc(),
0
},
{
acc_u_a2.name(),
(getter)PyBobLearnMiscISVTrainer_get_acc_u_a2,
(setter)PyBobLearnMiscISVTrainer_get_acc_u_a2,
acc_u_a2.doc(),
0
},
{
__X__.name(),
(getter)PyBobLearnMiscISVTrainer_get_X,
(setter)PyBobLearnMiscISVTrainer_set_X,
__X__.doc(),
0
},
{
__Z__.name(),
(getter)PyBobLearnMiscISVTrainer_get_Z,
(setter)PyBobLearnMiscISVTrainer_set_Z,
__Z__.doc(),
0
},
{0} // Sentinel
};
/******************************************************************/
/************ Functions Section ***********************************/
/******************************************************************/
/*** initialize ***/
static auto initialize = bob::extension::FunctionDoc(
"initialize",
"Initialization before the EM steps",
"",
true
)
.add_prototype("isv_base,stats")
.add_parameter("isv_base", ":py:class:`bob.learn.misc.ISVBase`", "ISVBase Object")
.add_parameter("stats", ":py:class:`bob.learn.misc.GMMStats`", "GMMStats Object");