diff --git a/bob/learn/misc/__init__.py b/bob/learn/misc/__init__.py index 877a303f7cef35896eaa07d086a3cc0098352ba5..2c3dd9d94c29e36af8fe91c53e8e4f3151813366 100644 --- a/bob/learn/misc/__init__.py +++ b/bob/learn/misc/__init__.py @@ -15,6 +15,7 @@ from .__kmeans_trainer__ import * from .__ML_gmm_trainer__ import * from .__MAP_gmm_trainer__ import * from .__jfa_trainer__ import * +from .__isv_trainer__ import * def ztnorm_same_value(vect_a, vect_b): diff --git a/bob/learn/misc/__isv_trainer__.py b/bob/learn/misc/__isv_trainer__.py new file mode 100644 index 0000000000000000000000000000000000000000..b77e06d63ee3ec8d47b6a6b761cfe10671a142fe --- /dev/null +++ b/bob/learn/misc/__isv_trainer__.py @@ -0,0 +1,49 @@ +#!/usr/bin/env python +# vim: set fileencoding=utf-8 : +# Tiago de Freitas Pereira <tiago.pereira@idiap.ch> +# Mon Fev 02 21:40:10 2015 +0200 +# +# Copyright (C) 2011-2015 Idiap Research Institute, Martigny, Switzerland + +from ._library import _ISVTrainer +import numpy + +# define the class +class ISVTrainer (_ISVTrainer): + + def __init__(self, max_iterations=10, relevance_factor=4., convergence_threshold = 0.001): + """ + :py:class:`bob.learn.misc.ISVTrainer` constructor + + Keyword Parameters: + max_iterations + Number of maximum iterations + """ + _ISVTrainer.__init__(self, relevance_factor, convergence_threshold) + self._max_iterations = max_iterations + + + def train(self, isv_base, data): + """ + Train the :py:class:`bob.learn.misc.ISVBase` using data + + Keyword Parameters: + jfa_base + The `:py:class:bob.learn.misc.ISVBase` class + data + The data to be trained + """ + + #Initialization + self.initialize(isv_base, data); + + for i in range(self._max_iterations): + #eStep + self.eStep(isv_base, data); + #mStep + self.mStep(isv_base); + + + +# copy the documentation from the base class +__doc__ = _ISVTrainer.__doc__ diff --git a/bob/learn/misc/cpp/ISVTrainer.cpp b/bob/learn/misc/cpp/ISVTrainer.cpp index d0455e96ab66b214a55eeec1c986eb410c6557b9..c7891b9f15dd8fd8ae4ef86e59283e6050849a5e 100644 --- a/bob/learn/misc/cpp/ISVTrainer.cpp +++ b/bob/learn/misc/cpp/ISVTrainer.cpp @@ -19,42 +19,38 @@ //////////////////////////// ISVTrainer /////////////////////////// -bob::learn::misc::ISVTrainer::ISVTrainer(const size_t max_iterations, const double relevance_factor): - EMTrainer<bob::learn::misc::ISVBase, std::vector<std::vector<boost::shared_ptr<bob::learn::misc::GMMStats> > > > - (0.001, max_iterations, false), - m_relevance_factor(relevance_factor) -{ -} +bob::learn::misc::ISVTrainer::ISVTrainer(const double relevance_factor, const double convergence_threshold): + m_relevance_factor(relevance_factor), + m_convergence_threshold(convergence_threshold), + m_rng(new boost::mt19937()) +{} bob::learn::misc::ISVTrainer::ISVTrainer(const bob::learn::misc::ISVTrainer& other): - EMTrainer<bob::learn::misc::ISVBase, std::vector<std::vector<boost::shared_ptr<bob::learn::misc::GMMStats> > > > - (other.m_convergence_threshold, other.m_max_iterations, - other.m_compute_likelihood), - m_relevance_factor(other.m_relevance_factor) -{ -} + m_convergence_threshold(other.m_convergence_threshold), + m_relevance_factor(other.m_relevance_factor), + m_rng(other.m_rng) +{} bob::learn::misc::ISVTrainer::~ISVTrainer() -{ -} +{} bob::learn::misc::ISVTrainer& bob::learn::misc::ISVTrainer::operator= (const bob::learn::misc::ISVTrainer& other) { if (this != &other) { - bob::learn::misc::EMTrainer<bob::learn::misc::ISVBase, - std::vector<std::vector<boost::shared_ptr<bob::learn::misc::GMMStats> > > >::operator=(other); - m_relevance_factor = other.m_relevance_factor; + m_convergence_threshold = other.m_convergence_threshold; + m_rng = other.m_rng; + m_relevance_factor = other.m_relevance_factor; } return *this; } bool bob::learn::misc::ISVTrainer::operator==(const bob::learn::misc::ISVTrainer& b) const { - return bob::learn::misc::EMTrainer<bob::learn::misc::ISVBase, - std::vector<std::vector<boost::shared_ptr<bob::learn::misc::GMMStats> > > >::operator==(b) && - m_relevance_factor == b.m_relevance_factor; + return m_convergence_threshold == b.m_convergence_threshold && + m_rng == b.m_rng && + m_relevance_factor == b.m_relevance_factor; } bool bob::learn::misc::ISVTrainer::operator!=(const bob::learn::misc::ISVTrainer& b) const @@ -65,9 +61,9 @@ bool bob::learn::misc::ISVTrainer::operator!=(const bob::learn::misc::ISVTrainer bool bob::learn::misc::ISVTrainer::is_similar_to(const bob::learn::misc::ISVTrainer& b, const double r_epsilon, const double a_epsilon) const { - return bob::learn::misc::EMTrainer<bob::learn::misc::ISVBase, - std::vector<std::vector<boost::shared_ptr<bob::learn::misc::GMMStats> > > >::is_similar_to(b, r_epsilon, a_epsilon) && - m_relevance_factor == b.m_relevance_factor; + return m_convergence_threshold == b.m_convergence_threshold && + m_rng == b.m_rng && + m_relevance_factor == b.m_relevance_factor; } void bob::learn::misc::ISVTrainer::initialize(bob::learn::misc::ISVBase& machine, @@ -89,11 +85,6 @@ void bob::learn::misc::ISVTrainer::initializeD(bob::learn::misc::ISVBase& machin d = sqrt(machine.getBase().getUbmVariance() / m_relevance_factor); } -void bob::learn::misc::ISVTrainer::finalize(bob::learn::misc::ISVBase& machine, - const std::vector<std::vector<boost::shared_ptr<bob::learn::misc::GMMStats> > >& ar) -{ -} - void bob::learn::misc::ISVTrainer::eStep(bob::learn::misc::ISVBase& machine, const std::vector<std::vector<boost::shared_ptr<bob::learn::misc::GMMStats> > >& ar) { @@ -105,8 +96,7 @@ void bob::learn::misc::ISVTrainer::eStep(bob::learn::misc::ISVBase& machine, m_base_trainer.computeAccumulatorsU(base, ar); } -void bob::learn::misc::ISVTrainer::mStep(bob::learn::misc::ISVBase& machine, - const std::vector<std::vector<boost::shared_ptr<bob::learn::misc::GMMStats> > >& ar) +void bob::learn::misc::ISVTrainer::mStep(bob::learn::misc::ISVBase& machine) { blitz::Array<double,2>& U = machine.updateU(); m_base_trainer.updateU(U); diff --git a/bob/learn/misc/cpp/JFATrainer.cpp b/bob/learn/misc/cpp/JFATrainer.cpp index 6b79da8b9d609ed88bbf1b363b751d54f42f7786..ecb23a2b0b3e6a64bf6af871b4b15d7781a70d0d 100644 --- a/bob/learn/misc/cpp/JFATrainer.cpp +++ b/bob/learn/misc/cpp/JFATrainer.cpp @@ -19,7 +19,7 @@ //////////////////////////// JFATrainer /////////////////////////// -bob::learn::misc::JFATrainer::JFATrainer(const size_t max_iterations): +bob::learn::misc::JFATrainer::JFATrainer(): m_rng(new boost::mt19937()) {} diff --git a/bob/learn/misc/include/bob.learn.misc/ISVTrainer.h b/bob/learn/misc/include/bob.learn.misc/ISVTrainer.h index e6b5285fd0dd8355a448e9c860251f62619671da..1f98e7477960e799979e41982c4651c2b0d17c98 100644 --- a/bob/learn/misc/include/bob.learn.misc/ISVTrainer.h +++ b/bob/learn/misc/include/bob.learn.misc/ISVTrainer.h @@ -25,14 +25,13 @@ namespace bob { namespace learn { namespace misc { - class ISVTrainer { public: /** * @brief Constructor */ - ISVTrainer(const size_t max_iterations=10, const double relevance_factor=4.); + ISVTrainer(const double relevance_factor=4., const double convergence_threshold = 0.001); /** * @brief Copy onstructor @@ -70,11 +69,6 @@ class ISVTrainer */ virtual void initialize(bob::learn::misc::ISVBase& machine, const std::vector<std::vector<boost::shared_ptr<bob::learn::misc::GMMStats> > >& ar); - /** - * @brief This methods performs some actions after the EM loop. - */ - virtual void finalize(bob::learn::misc::ISVBase& machine, - const std::vector<std::vector<boost::shared_ptr<bob::learn::misc::GMMStats> > >& ar); /** * @brief Calculates and saves statistics across the dataset @@ -87,8 +81,7 @@ class ISVTrainer * @brief Performs a maximization step to update the parameters of the * factor analysis model. */ - virtual void mStep(bob::learn::misc::ISVBase& machine, - const std::vector<std::vector<boost::shared_ptr<bob::learn::misc::GMMStats> > >& ar); + virtual void mStep(bob::learn::misc::ISVBase& machine); /** * @brief Computes the average log likelihood using the current estimates @@ -150,7 +143,12 @@ class ISVTrainer // Attributes bob::learn::misc::FABaseTrainer m_base_trainer; + double m_relevance_factor; + + double m_convergence_threshold; ///< convergence threshold + + boost::shared_ptr<boost::mt19937> m_rng; ///< The random number generator for the inialization}; }; } } } // namespaces diff --git a/bob/learn/misc/include/bob.learn.misc/JFATrainer.h b/bob/learn/misc/include/bob.learn.misc/JFATrainer.h index f8c818e9fd2c0995afa51faef71bc63092ca67a9..99070b5b2ae6299c7ce2354e8f84783e456bd46b 100644 --- a/bob/learn/misc/include/bob.learn.misc/JFATrainer.h +++ b/bob/learn/misc/include/bob.learn.misc/JFATrainer.h @@ -31,7 +31,7 @@ class JFATrainer /** * @brief Constructor */ - JFATrainer(const size_t max_iterations=10); + JFATrainer(); /** * @brief Copy onstructor diff --git a/bob/learn/misc/isv_trainer.cpp b/bob/learn/misc/isv_trainer.cpp new file mode 100644 index 0000000000000000000000000000000000000000..f288fbb557c8c3729f502299ce92631fe9607735 --- /dev/null +++ b/bob/learn/misc/isv_trainer.cpp @@ -0,0 +1,561 @@ +/** + * @author Tiago de Freitas Pereira <tiago.pereira@idiap.ch> + * @date Mon 02 Fev 20:20:00 2015 + * + * @brief Python API for bob::learn::em + * + * Copyright (C) 2011-2014 Idiap Research Institute, Martigny, Switzerland + */ + +#include "main.h" +#include <boost/make_shared.hpp> + +/******************************************************************/ +/************ Constructor Section *********************************/ +/******************************************************************/ + +static int extract_GMMStats_1d(PyObject *list, + std::vector<boost::shared_ptr<bob::learn::misc::GMMStats> >& training_data) +{ + for (int i=0; i<PyList_GET_SIZE(list); i++){ + + PyBobLearnMiscGMMStatsObject* stats; + if (!PyArg_Parse(PyList_GetItem(list, i), "O!", &PyBobLearnMiscGMMStats_Type, &stats)){ + PyErr_Format(PyExc_RuntimeError, "Expected GMMStats objects"); + return -1; + } + training_data.push_back(stats->cxx); + } + return 0; +} + +static int extract_GMMStats_2d(PyObject *list, + std::vector<std::vector<boost::shared_ptr<bob::learn::misc::GMMStats> > >& training_data) +{ + for (int i=0; i<PyList_GET_SIZE(list); i++) + { + PyObject* another_list; + PyArg_Parse(PyList_GetItem(list, i), "O!", &PyList_Type, &another_list); + + std::vector<boost::shared_ptr<bob::learn::misc::GMMStats> > another_training_data; + for (int j=0; j<PyList_GET_SIZE(another_list); j++){ + + PyBobLearnMiscGMMStatsObject* stats; + if (!PyArg_Parse(PyList_GetItem(another_list, j), "O!", &PyBobLearnMiscGMMStats_Type, &stats)){ + PyErr_Format(PyExc_RuntimeError, "Expected GMMStats objects"); + return -1; + } + another_training_data.push_back(stats->cxx); + } + training_data.push_back(another_training_data); + } + return 0; +} + +template <int N> +static PyObject* vector_as_list(const std::vector<blitz::Array<double,N> >& vec) +{ + PyObject* list = PyList_New(vec.size()); + for(size_t i=0; i<vec.size(); i++){ + blitz::Array<double,N> numpy_array = vec[i]; + PyObject* numpy_py_object = PyBlitzArrayCxx_AsNumpy(numpy_array); + PyList_SET_ITEM(list, i, numpy_py_object); + } + return list; +} + +template <int N> +int list_as_vector(PyObject* list, std::vector<blitz::Array<double,N> >& vec) +{ + for (int i=0; i<PyList_GET_SIZE(list); i++) + { + PyBlitzArrayObject* blitz_object; + if (!PyArg_Parse(PyList_GetItem(list, i), "O&", &PyBlitzArray_Converter, &blitz_object)){ + PyErr_Format(PyExc_RuntimeError, "Expected numpy array object"); + return -1; + } + auto blitz_object_ = make_safe(blitz_object); + vec.push_back(*PyBlitzArrayCxx_AsBlitz<double,N>(blitz_object)); + } + return 0; +} + + + +static auto ISVTrainer_doc = bob::extension::ClassDoc( + BOB_EXT_MODULE_PREFIX ".ISVTrainer", + "ISVTrainer" + "References: [Vogt2008,McCool2013]", + "" +).add_constructor( + bob::extension::FunctionDoc( + "__init__", + "Constructor. Builds a new ISVTrainer", + "", + true + ) + .add_prototype("relevance_factor,convergence_threshold","") + .add_prototype("other","") + .add_prototype("","") + .add_parameter("other", ":py:class:`bob.learn.misc.ISVTrainer`", "A ISVTrainer object to be copied.") + .add_parameter("relevance_factor", "double", "") + .add_parameter("convergence_threshold", "double", "") +); + + +static int PyBobLearnMiscISVTrainer_init_copy(PyBobLearnMiscISVTrainerObject* self, PyObject* args, PyObject* kwargs) { + + char** kwlist = ISVTrainer_doc.kwlist(1); + PyBobLearnMiscISVTrainerObject* o; + if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O!", kwlist, &PyBobLearnMiscISVTrainer_Type, &o)){ + ISVTrainer_doc.print_usage(); + return -1; + } + + self->cxx.reset(new bob::learn::misc::ISVTrainer(*o->cxx)); + return 0; +} + + +static int PyBobLearnMiscISVTrainer_init_number(PyBobLearnMiscISVTrainerObject* self, PyObject* args, PyObject* kwargs) { + + char** kwlist = ISVTrainer_doc.kwlist(0); + double relevance_factor = 4.; + double convergence_threshold = 0.001; + //Parsing the input argments + if (!PyArg_ParseTupleAndKeywords(args, kwargs, "dd", kwlist, &relevance_factor, &convergence_threshold)) + return -1; + + if(relevance_factor < 0){ + PyErr_Format(PyExc_TypeError, "gaussians argument must be greater than zero"); + return -1; + } + + if(convergence_threshold < 0){ + PyErr_Format(PyExc_TypeError, "convergence_threshold argument must be greater than zero"); + return -1; + } + + self->cxx.reset(new bob::learn::misc::ISVTrainer(relevance_factor, convergence_threshold)); + return 0; +} + + +static int PyBobLearnMiscISVTrainer_init(PyBobLearnMiscISVTrainerObject* self, PyObject* args, PyObject* kwargs) { + BOB_TRY + + // get the number of command line arguments + int nargs = (args?PyTuple_Size(args):0) + (kwargs?PyDict_Size(kwargs):0); + + switch(nargs){ + case 0:{ + self->cxx.reset(new bob::learn::misc::ISVTrainer()); + return 0; + } + case 1:{ + // If the constructor input is ISVTrainer object + return PyBobLearnMiscISVTrainer_init_copy(self, args, kwargs); + } + case 2:{ + // If the constructor input is ISVTrainer object + return PyBobLearnMiscISVTrainer_init_number(self, args, kwargs); + } + default:{ + PyErr_Format(PyExc_RuntimeError, "number of arguments mismatch - %s requires only 0, 1 or 2 arguments, but you provided %d (see help)", Py_TYPE(self)->tp_name, nargs); + ISVTrainer_doc.print_usage(); + return -1; + } + } + BOB_CATCH_MEMBER("cannot create ISVTrainer", 0) + return 0; +} + + +static void PyBobLearnMiscISVTrainer_delete(PyBobLearnMiscISVTrainerObject* self) { + self->cxx.reset(); + Py_TYPE(self)->tp_free((PyObject*)self); +} + + +int PyBobLearnMiscISVTrainer_Check(PyObject* o) { + return PyObject_IsInstance(o, reinterpret_cast<PyObject*>(&PyBobLearnMiscISVTrainer_Type)); +} + + +static PyObject* PyBobLearnMiscISVTrainer_RichCompare(PyBobLearnMiscISVTrainerObject* self, PyObject* other, int op) { + BOB_TRY + + if (!PyBobLearnMiscISVTrainer_Check(other)) { + PyErr_Format(PyExc_TypeError, "cannot compare `%s' with `%s'", Py_TYPE(self)->tp_name, Py_TYPE(other)->tp_name); + return 0; + } + auto other_ = reinterpret_cast<PyBobLearnMiscISVTrainerObject*>(other); + switch (op) { + case Py_EQ: + if (*self->cxx==*other_->cxx) Py_RETURN_TRUE; else Py_RETURN_FALSE; + case Py_NE: + if (*self->cxx==*other_->cxx) Py_RETURN_FALSE; else Py_RETURN_TRUE; + default: + Py_INCREF(Py_NotImplemented); + return Py_NotImplemented; + } + BOB_CATCH_MEMBER("cannot compare ISVTrainer objects", 0) +} + + +/******************************************************************/ +/************ Variables Section ***********************************/ +/******************************************************************/ + +static auto acc_u_a1 = bob::extension::VariableDoc( + "acc_u_a1", + "array_like <float, 3D>", + "Accumulator updated during the E-step", + "" +); +PyObject* PyBobLearnMiscISVTrainer_get_acc_u_a1(PyBobLearnMiscISVTrainerObject* self, void*){ + BOB_TRY + return PyBlitzArrayCxx_AsConstNumpy(self->cxx->getAccUA1()); + BOB_CATCH_MEMBER("acc_u_a1 could not be read", 0) +} +int PyBobLearnMiscISVTrainer_set_acc_u_a1(PyBobLearnMiscISVTrainerObject* self, PyObject* value, void*){ + BOB_TRY + PyBlitzArrayObject* o; + if (!PyBlitzArray_Converter(value, &o)){ + PyErr_Format(PyExc_RuntimeError, "%s %s expects a 3D array of floats", Py_TYPE(self)->tp_name, acc_u_a1.name()); + return -1; + } + auto o_ = make_safe(o); + auto b = PyBlitzArrayCxx_AsBlitz<double,3>(o, "acc_u_a1"); + if (!b) return -1; + self->cxx->setAccUA1(*b); + return 0; + BOB_CATCH_MEMBER("acc_u_a1 could not be set", -1) +} + + +static auto acc_u_a2 = bob::extension::VariableDoc( + "acc_u_a2", + "array_like <float, 2D>", + "Accumulator updated during the E-step", + "" +); +PyObject* PyBobLearnMiscISVTrainer_get_acc_u_a2(PyBobLearnMiscISVTrainerObject* self, void*){ + BOB_TRY + return PyBlitzArrayCxx_AsConstNumpy(self->cxx->getAccUA2()); + BOB_CATCH_MEMBER("acc_u_a2 could not be read", 0) +} +int PyBobLearnMiscISVTrainer_set_acc_u_a2(PyBobLearnMiscISVTrainerObject* self, PyObject* value, void*){ + BOB_TRY + PyBlitzArrayObject* o; + if (!PyBlitzArray_Converter(value, &o)){ + PyErr_Format(PyExc_RuntimeError, "%s %s expects a 2D array of floats", Py_TYPE(self)->tp_name, acc_u_a2.name()); + return -1; + } + auto o_ = make_safe(o); + auto b = PyBlitzArrayCxx_AsBlitz<double,2>(o, "acc_u_a2"); + if (!b) return -1; + self->cxx->setAccUA2(*b); + return 0; + BOB_CATCH_MEMBER("acc_u_a2 could not be set", -1) +} + + + + + +static auto __X__ = bob::extension::VariableDoc( + "__X__", + "list", + "", + "" +); +PyObject* PyBobLearnMiscISVTrainer_get_X(PyBobLearnMiscISVTrainerObject* self, void*){ + BOB_TRY + return vector_as_list(self->cxx->getX()); + BOB_CATCH_MEMBER("__X__ could not be read", 0) +} +int PyBobLearnMiscISVTrainer_set_X(PyBobLearnMiscISVTrainerObject* self, PyObject* value, void*){ + BOB_TRY + + // Parses input arguments in a single shot + if (!PyList_Check(value)){ + PyErr_Format(PyExc_TypeError, "Expected a list in `%s'", __X__.name()); + return -1; + } + + std::vector<blitz::Array<double,2> > data; + if(list_as_vector(value ,data)==0){ + self->cxx->setX(data); + } + + return 0; + BOB_CATCH_MEMBER("__X__ could not be written", 0) +} + + +static auto __Z__ = bob::extension::VariableDoc( + "__Z__", + "list", + "", + "" +); +PyObject* PyBobLearnMiscISVTrainer_get_Z(PyBobLearnMiscISVTrainerObject* self, void*){ + BOB_TRY + return vector_as_list(self->cxx->getZ()); + BOB_CATCH_MEMBER("__Z__ could not be read", 0) +} +int PyBobLearnMiscISVTrainer_set_Z(PyBobLearnMiscISVTrainerObject* self, PyObject* value, void*){ + BOB_TRY + + // Parses input arguments in a single shot + if (!PyList_Check(value)){ + PyErr_Format(PyExc_TypeError, "Expected a list in `%s'", __Z__.name()); + return -1; + } + + std::vector<blitz::Array<double,1> > data; + if(list_as_vector(value ,data)==0){ + self->cxx->setZ(data); + } + + return 0; + BOB_CATCH_MEMBER("__Z__ could not be written", 0) +} + + + + +static PyGetSetDef PyBobLearnMiscISVTrainer_getseters[] = { + { + acc_u_a1.name(), + (getter)PyBobLearnMiscISVTrainer_get_acc_u_a1, + (setter)PyBobLearnMiscISVTrainer_get_acc_u_a1, + acc_u_a1.doc(), + 0 + }, + { + acc_u_a2.name(), + (getter)PyBobLearnMiscISVTrainer_get_acc_u_a2, + (setter)PyBobLearnMiscISVTrainer_get_acc_u_a2, + acc_u_a2.doc(), + 0 + }, + { + __X__.name(), + (getter)PyBobLearnMiscISVTrainer_get_X, + (setter)PyBobLearnMiscISVTrainer_set_X, + __X__.doc(), + 0 + }, + { + __Z__.name(), + (getter)PyBobLearnMiscISVTrainer_get_Z, + (setter)PyBobLearnMiscISVTrainer_set_Z, + __Z__.doc(), + 0 + }, + + + + {0} // Sentinel +}; + + +/******************************************************************/ +/************ Functions Section ***********************************/ +/******************************************************************/ + +/*** initialize ***/ +static auto initialize = bob::extension::FunctionDoc( + "initialize", + "Initialization before the EM steps", + "", + true +) +.add_prototype("isv_base,stats") +.add_parameter("isv_base", ":py:class:`bob.learn.misc.ISVBase`", "ISVBase Object") +.add_parameter("stats", ":py:class:`bob.learn.misc.GMMStats`", "GMMStats Object"); +static PyObject* PyBobLearnMiscISVTrainer_initialize(PyBobLearnMiscISVTrainerObject* self, PyObject* args, PyObject* kwargs) { + BOB_TRY + + /* Parses input arguments in a single shot */ + char** kwlist = initialize.kwlist(0); + + PyBobLearnMiscISVBaseObject* isv_base = 0; + PyObject* stats = 0; + + if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O!O!", kwlist, &PyBobLearnMiscISVBase_Type, &isv_base, + &PyList_Type, &stats)) Py_RETURN_NONE; + + std::vector<std::vector<boost::shared_ptr<bob::learn::misc::GMMStats> > > training_data; + if(extract_GMMStats_2d(stats ,training_data)==0) + self->cxx->initialize(*isv_base->cxx, training_data); + + BOB_CATCH_MEMBER("cannot perform the initialize method", 0) + + Py_RETURN_NONE; +} + + +/*** e_step ***/ +static auto e_step = bob::extension::FunctionDoc( + "e_step", + "Call the e-step procedure (for the U subspace).", + "", + true +) +.add_prototype("isv_base,stats") +.add_parameter("isv_base", ":py:class:`bob.learn.misc.ISVBase`", "ISVBase Object") +.add_parameter("stats", ":py:class:`bob.learn.misc.GMMStats`", "GMMStats Object"); +static PyObject* PyBobLearnMiscISVTrainer_e_step(PyBobLearnMiscISVTrainerObject* self, PyObject* args, PyObject* kwargs) { + BOB_TRY + + // Parses input arguments in a single shot + char** kwlist = e_step.kwlist(0); + + PyBobLearnMiscISVBaseObject* isv_base = 0; + PyObject* stats = 0; + + if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O!O!", kwlist, &PyBobLearnMiscISVBase_Type, &isv_base, + &PyList_Type, &stats)) Py_RETURN_NONE; + + std::vector<std::vector<boost::shared_ptr<bob::learn::misc::GMMStats> > > training_data; + if(extract_GMMStats_2d(stats ,training_data)==0) + self->cxx->eStep(*isv_base->cxx, training_data); + + BOB_CATCH_MEMBER("cannot perform the e_step method", 0) + + Py_RETURN_NONE; +} + + +/*** m_step ***/ +static auto m_step = bob::extension::FunctionDoc( + "m_step", + "Call the m-step procedure (for the U subspace).", + "", + true +) +.add_prototype("isv_base,stats") +.add_parameter("isv_base", ":py:class:`bob.learn.misc.ISVBase`", "ISVBase Object"); +static PyObject* PyBobLearnMiscISVTrainer_m_step(PyBobLearnMiscISVTrainerObject* self, PyObject* args, PyObject* kwargs) { + BOB_TRY + + // Parses input arguments in a single shot + char** kwlist = m_step.kwlist(0); + + PyBobLearnMiscISVBaseObject* isv_base = 0; + + if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O!", kwlist, &PyBobLearnMiscISVBase_Type, &isv_base)) Py_RETURN_NONE; + + self->cxx->mStep(*isv_base->cxx); + + BOB_CATCH_MEMBER("cannot perform the m_step method", 0) + + Py_RETURN_NONE; +} + + + +/*** enrol ***/ +static auto enrol = bob::extension::FunctionDoc( + "enrol", + "", + "", + true +) +.add_prototype("isv_machine,features,n_iter","") +.add_parameter("isv_machine", ":py:class:`bob.learn.misc.ISVMachine`", "ISVMachine Object") +.add_parameter("features", "list(:py:class:`bob.learn.misc.GMMStats`)`", "") +.add_parameter("n_iter", "int", "Number of iterations"); +static PyObject* PyBobLearnMiscISVTrainer_enrol(PyBobLearnMiscISVTrainerObject* self, PyObject* args, PyObject* kwargs) { + BOB_TRY + + // Parses input arguments in a single shot + char** kwlist = enrol.kwlist(0); + + PyBobLearnMiscISVMachineObject* isv_machine = 0; + PyObject* stats = 0; + int n_iter = 1; + + + if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O!O!i", kwlist, &PyBobLearnMiscISVMachine_Type, &isv_machine, + &PyList_Type, &stats, &n_iter)) Py_RETURN_NONE; + + std::vector<boost::shared_ptr<bob::learn::misc::GMMStats> > training_data; + if(extract_GMMStats_1d(stats ,training_data)==0) + self->cxx->enrol(*isv_machine->cxx, training_data, n_iter); + + BOB_CATCH_MEMBER("cannot perform the enrol method", 0) + + Py_RETURN_NONE; +} + + + +static PyMethodDef PyBobLearnMiscISVTrainer_methods[] = { + { + initialize.name(), + (PyCFunction)PyBobLearnMiscISVTrainer_initialize, + METH_VARARGS|METH_KEYWORDS, + initialize.doc() + }, + { + e_step.name(), + (PyCFunction)PyBobLearnMiscISVTrainer_e_step, + METH_VARARGS|METH_KEYWORDS, + e_step.doc() + }, + { + m_step.name(), + (PyCFunction)PyBobLearnMiscISVTrainer_m_step, + METH_VARARGS|METH_KEYWORDS, + m_step.doc() + }, + { + enrol.name(), + (PyCFunction)PyBobLearnMiscISVTrainer_enrol, + METH_VARARGS|METH_KEYWORDS, + enrol.doc() + }, + {0} /* Sentinel */ +}; + + +/******************************************************************/ +/************ Module Section **************************************/ +/******************************************************************/ + +// Define the Gaussian type struct; will be initialized later +PyTypeObject PyBobLearnMiscISVTrainer_Type = { + PyVarObject_HEAD_INIT(0,0) + 0 +}; + +bool init_BobLearnMiscISVTrainer(PyObject* module) +{ + // initialize the type struct + PyBobLearnMiscISVTrainer_Type.tp_name = ISVTrainer_doc.name(); + PyBobLearnMiscISVTrainer_Type.tp_basicsize = sizeof(PyBobLearnMiscISVTrainerObject); + PyBobLearnMiscISVTrainer_Type.tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE;//Enable the class inheritance; + PyBobLearnMiscISVTrainer_Type.tp_doc = ISVTrainer_doc.doc(); + + // set the functions + PyBobLearnMiscISVTrainer_Type.tp_new = PyType_GenericNew; + PyBobLearnMiscISVTrainer_Type.tp_init = reinterpret_cast<initproc>(PyBobLearnMiscISVTrainer_init); + PyBobLearnMiscISVTrainer_Type.tp_dealloc = reinterpret_cast<destructor>(PyBobLearnMiscISVTrainer_delete); + PyBobLearnMiscISVTrainer_Type.tp_richcompare = reinterpret_cast<richcmpfunc>(PyBobLearnMiscISVTrainer_RichCompare); + PyBobLearnMiscISVTrainer_Type.tp_methods = PyBobLearnMiscISVTrainer_methods; + PyBobLearnMiscISVTrainer_Type.tp_getset = PyBobLearnMiscISVTrainer_getseters; + //PyBobLearnMiscISVTrainer_Type.tp_call = reinterpret_cast<ternaryfunc>(PyBobLearnMiscISVTrainer_compute_likelihood); + + + // check that everything is fine + if (PyType_Ready(&PyBobLearnMiscISVTrainer_Type) < 0) return false; + + // add the type to the module + Py_INCREF(&PyBobLearnMiscISVTrainer_Type); + return PyModule_AddObject(module, "_ISVTrainer", (PyObject*)&PyBobLearnMiscISVTrainer_Type) >= 0; +} + diff --git a/bob/learn/misc/jfa_trainer.cpp b/bob/learn/misc/jfa_trainer.cpp index 8c838a872edf8fa5a1d4c00185917d76e46edf01..8d46e94f35e6311847b795e0380b89a4af75d124 100644 --- a/bob/learn/misc/jfa_trainer.cpp +++ b/bob/learn/misc/jfa_trainer.cpp @@ -987,7 +987,7 @@ PyTypeObject PyBobLearnMiscJFATrainer_Type = { bool init_BobLearnMiscJFATrainer(PyObject* module) { - // initialize the type struct + // initialize the type JFATrainer PyBobLearnMiscJFATrainer_Type.tp_name = JFATrainer_doc.name(); PyBobLearnMiscJFATrainer_Type.tp_basicsize = sizeof(PyBobLearnMiscJFATrainerObject); PyBobLearnMiscJFATrainer_Type.tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE;//Enable the class inheritance; diff --git a/bob/learn/misc/main.cpp b/bob/learn/misc/main.cpp index b5748880cc8a679c43e191054e9c4680b97f13f7..54a5d486de04af187877354d6e3a958f3d596a79 100644 --- a/bob/learn/misc/main.cpp +++ b/bob/learn/misc/main.cpp @@ -72,8 +72,10 @@ static PyObject* create_module (void) { if (!init_BobLearnMiscJFABase(module)) return 0; if (!init_BobLearnMiscJFAMachine(module)) return 0; if (!init_BobLearnMiscJFATrainer(module)) return 0; + if (!init_BobLearnMiscISVBase(module)) return 0; if (!init_BobLearnMiscISVMachine(module)) return 0; + if (!init_BobLearnMiscISVTrainer(module)) return 0; if (!init_BobLearnMiscIVectorMachine(module)) return 0; if (!init_BobLearnMiscPLDABase(module)) return 0; diff --git a/bob/learn/misc/main.h b/bob/learn/misc/main.h index d6049debc9cdac6538b7c3e579bf784b23047669..b3a4fb50531de118a146c14769e6157b153936e8 100644 --- a/bob/learn/misc/main.h +++ b/bob/learn/misc/main.h @@ -30,8 +30,10 @@ #include <bob.learn.misc/JFABase.h> #include <bob.learn.misc/JFAMachine.h> #include <bob.learn.misc/JFATrainer.h> + #include <bob.learn.misc/ISVBase.h> #include <bob.learn.misc/ISVMachine.h> +#include <bob.learn.misc/ISVTrainer.h> #include <bob.learn.misc/IVectorMachine.h> @@ -221,6 +223,15 @@ extern PyTypeObject PyBobLearnMiscISVMachine_Type; bool init_BobLearnMiscISVMachine(PyObject* module); int PyBobLearnMiscISVMachine_Check(PyObject* o); +// ISVTrainer +typedef struct { + PyObject_HEAD + boost::shared_ptr<bob::learn::misc::ISVTrainer> cxx; +} PyBobLearnMiscISVTrainerObject; + +extern PyTypeObject PyBobLearnMiscISVTrainer_Type; +bool init_BobLearnMiscISVTrainer(PyObject* module); +int PyBobLearnMiscISVTrainer_Check(PyObject* o); // IVectorMachine typedef struct { diff --git a/bob/learn/misc/test_jfa_trainer.py b/bob/learn/misc/test_jfa_trainer.py index c094a6da05de832d31440886346f2d67a4963d65..c2505b2b61b67faae63398cef3f7e45950364400 100644 --- a/bob/learn/misc/test_jfa_trainer.py +++ b/bob/learn/misc/test_jfa_trainer.py @@ -13,8 +13,7 @@ import numpy.linalg import bob.core.random -from . import GMMStats, GMMMachine, JFABase, JFAMachine, ISVBase, ISVMachine, JFATrainer -#, ISVTrainer +from . import GMMStats, GMMMachine, JFABase, JFAMachine, ISVBase, ISVMachine, JFATrainer, ISVTrainer def equals(x, y, epsilon): @@ -233,7 +232,6 @@ def test_ISVTrainAndEnrol(): for i in range(10): t.e_step(mb, TRAINING_STATS) t.m_step(mb, TRAINING_STATS) - t.finalize(mb, TRAINING_STATS) assert numpy.allclose(mb.d, d_ref, eps) assert numpy.allclose(mb.u, u_ref, eps) diff --git a/setup.py b/setup.py index f0d5bcfc0bc9980cb8d13fdc89720e14dbb12ca4..088252fe4b66182f8c24adbed78cb0cee10f3308 100644 --- a/setup.py +++ b/setup.py @@ -71,7 +71,7 @@ setup( "bob/learn/misc/cpp/FABaseTrainer.cpp", "bob/learn/misc/cpp/JFATrainer.cpp", - #"bob/learn/misc/cpp/ISVTrainer.cpp", + "bob/learn/misc/cpp/ISVTrainer.cpp", #"bob/learn/misc/cpp/EMPCATrainer.cpp", "bob/learn/misc/cpp/GMMBaseTrainer.cpp", @@ -120,8 +120,10 @@ setup( "bob/learn/misc/jfa_base.cpp", "bob/learn/misc/jfa_machine.cpp", "bob/learn/misc/jfa_trainer.cpp", + "bob/learn/misc/isv_base.cpp", "bob/learn/misc/isv_machine.cpp", + "bob/learn/misc/isv_trainer.cpp", "bob/learn/misc/ivector_machine.cpp", "bob/learn/misc/plda_base.cpp",