diff --git a/bob/learn/em/__init__.py b/bob/learn/em/__init__.py index f1edef0e24e3f7d39b4d9ba067dff187ed665399..eb7df695bfbaacc3052acbd98e4fb1deae19b0c8 100644 --- a/bob/learn/em/__init__.py +++ b/bob/learn/em/__init__.py @@ -12,6 +12,7 @@ bob.extension.load_bob_library("bob.learn.em", __file__) from ._library import * from ._library import GMMMachine as _GMMMachine_C from ._library import ISVBase as _ISVBase_C +from ._library import ISVMachine as _ISVMachine_C from . import version from .version import module as __version__ @@ -90,7 +91,7 @@ class ISVBase(_ISVBase_C): @staticmethod def to_dict(isv_base): isv_data = dict() - isv_data["gmm"] = GMMMachine.to_dict(isv_base.ubm) + isv_data["gmm"] = GMMMachine.to_dict(isv_base.ubm) isv_data["u"] = isv_base.u isv_data["d"] = isv_base.d @@ -103,12 +104,48 @@ class ISVBase(_ISVBase_C): self.u = u self.d = d["d"] + @classmethod + def create_from_dict(cls, d): + ubm = GMMMachine.create_from_dict(d["gmm"]) + ru = d["u"].shape[1] + isv_base = ISVBase(ubm, ru) + isv_base.u = d["u"] + isv_base.d = d["d"] + return isv_base + def __getstate__(self): d = dict(self.__dict__) d.update(self.__class__.to_dict(self)) return d - def __setstate__(self, d): + def __setstate__(self, d): self.__dict__ = d self.update_dict(d) + +class ISVMachine(_ISVMachine_C): + __doc__ = _ISVMachine_C.__doc__ + + @staticmethod + def to_dict(isv_machine): + isv_data = dict() + isv_data["x"] = isv_machine.x + isv_data["z"] = isv_machine.z + isv_data["isv_base"] = ISVBase.to_dict(isv_machine.isv_base) + + return isv_data + + def update_dict(self, d): + isv_base = ISVBase.create_from_dict(d["isv_base"]) + self.__init__(isv_base) + self.x = d["x"] + self.z = d["z"] + + def __getstate__(self): + d = dict(self.__dict__) + d.update(self.__class__.to_dict(self)) + return d + + def __setstate__(self, d): + self.__dict__ = d + self.update_dict(d) diff --git a/bob/learn/em/cpp/ISVMachine.cpp b/bob/learn/em/cpp/ISVMachine.cpp index 894126d2cd20360feeb3418459ecb8879f0aa4e7..578e5d29f61f04af197c0b27aa27f32bb70e72b7 100644 --- a/bob/learn/em/cpp/ISVMachine.cpp +++ b/bob/learn/em/cpp/ISVMachine.cpp @@ -108,6 +108,19 @@ void bob::learn::em::ISVMachine::setZ(const blitz::Array<double,1>& z) updateCache(); } + +void bob::learn::em::ISVMachine::setX(const blitz::Array<double,1>& x) +{ + if(x.extent(0) != m_cache_x.extent(0)) { //checks dimension + boost::format m("size of input vector `x' (%d) does not match the expected size (%d)"); + m % x.extent(0) % m_cache_x.extent(0); + throw std::runtime_error(m.str()); + } + m_cache_x.reference(bob::core::array::ccopy(x)); + // update cache + updateCache(); +} + void bob::learn::em::ISVMachine::setISVBase(const boost::shared_ptr<bob::learn::em::ISVBase> isv_base) { if (!isv_base->getUbm()) diff --git a/bob/learn/em/include/bob.learn.em/ISVMachine.h b/bob/learn/em/include/bob.learn.em/ISVMachine.h index e4a5cc9c434809a882843ec7606fbb871d689b4a..f51f4dbf008638e91cd3cc850ba0f9f9e6cff172 100644 --- a/bob/learn/em/include/bob.learn.em/ISVMachine.h +++ b/bob/learn/em/include/bob.learn.em/ISVMachine.h @@ -147,6 +147,13 @@ class ISVMachine */ void setZ(const blitz::Array<double,1>& z); + + /** + * @brief Sets the session variable + */ + void setX(const blitz::Array<double,1>& x); + + /** * @brief Returns the ISVBase */ diff --git a/bob/learn/em/isv_machine.cpp b/bob/learn/em/isv_machine.cpp index 4a860a5d009e71de790efa260fc473415c441560..097467a553d36b617fe1be644e3fddc04ba52bbf 100644 --- a/bob/learn/em/isv_machine.cpp +++ b/bob/learn/em/isv_machine.cpp @@ -240,6 +240,37 @@ PyObject* PyBobLearnEMISVMachine_getX(PyBobLearnEMISVMachineObject* self, void*) return PyBlitzArrayCxx_AsConstNumpy(self->cxx->getX()); BOB_CATCH_MEMBER("`x` could not be read", 0) } +int PyBobLearnEMISVMachine_setX(PyBobLearnEMISVMachineObject* self, PyObject* value, void*){ + BOB_TRY + PyBlitzArrayObject* input; + if (!PyBlitzArray_Converter(value, &input)){ + PyErr_Format(PyExc_RuntimeError, "%s %s expects a 1D array of floats", Py_TYPE(self)->tp_name, X.name()); + return -1; + } + auto o_ = make_safe(input); + + // perform check on the input + if (input->type_num != NPY_FLOAT64){ + PyErr_Format(PyExc_TypeError, "`%s' only supports 64-bit float arrays for input array `%s`", Py_TYPE(self)->tp_name, X.name()); + return -1; + } + + if (input->ndim != 1){ + PyErr_Format(PyExc_TypeError, "`%s' only processes 1D arrays of float64 for `%s`", Py_TYPE(self)->tp_name, X.name()); + return -1; + } + + if (input->shape[0] != (Py_ssize_t)self->cxx->getX().extent(0)) { + PyErr_Format(PyExc_TypeError, "`%s' 1D `input` array should have %" PY_FORMAT_SIZE_T "d, elements, not %" PY_FORMAT_SIZE_T "d for `%s`", Py_TYPE(self)->tp_name, (Py_ssize_t)self->cxx->getX().extent(0), (Py_ssize_t)input->shape[0], X.name()); + return -1; + } + + auto b = PyBlitzArrayCxx_AsBlitz<double,1>(input, "x"); + if (!b) return -1; + self->cxx->setX(*b); + return 0; + BOB_CATCH_MEMBER("`x` vector could not be set", -1) +} /***** isv_base *****/ @@ -318,7 +349,7 @@ static PyGetSetDef PyBobLearnEMISVMachine_getseters[] = { { X.name(), (getter)PyBobLearnEMISVMachine_getX, - 0, + (setter)PyBobLearnEMISVMachine_setX, X.doc(), 0 }, diff --git a/bob/learn/em/test/test_picklability.py b/bob/learn/em/test/test_picklability.py index dba70d55fd10d9c32d2fde645bdc5d0cbdd55e89..bb8a91b31655ba275169a1e16f7d4e5198e14c23 100644 --- a/bob/learn/em/test/test_picklability.py +++ b/bob/learn/em/test/test_picklability.py @@ -2,7 +2,7 @@ # vim: set fileencoding=utf-8 : # Tiago de Freitas Pereira <tiago.pereira@idiap.ch> -from bob.learn.em import GMMMachine, ISVBase +from bob.learn.em import GMMMachine, ISVBase, ISVMachine import numpy import pickle @@ -17,7 +17,7 @@ def test_gmm_machine(): assert numpy.allclose(gmm_machine_after_pickle.weights, gmm_machine_after_pickle.weights, 10e-3) -def test_isv(): +def test_isv_base(): ubm = GMMMachine(3,3) ubm.means = numpy.arange(9).reshape(3,3).astype("float") isv_base = ISVBase(ubm, 2) @@ -27,4 +27,37 @@ def test_isv(): isv_base_after_pickle = pickle.loads(pickle.dumps(isv_base)) assert numpy.allclose(isv_base.u, isv_base_after_pickle.u, 10e-3) - assert numpy.allclose(isv_base.d, isv_base_after_pickle.d, 10e-3) \ No newline at end of file + assert numpy.allclose(isv_base.d, isv_base_after_pickle.d, 10e-3) + + +def test_isv_machine(): + + # Creates a UBM + weights = numpy.array([0.4, 0.6], 'float64') + means = numpy.array([[1, 6, 2], [4, 3, 2]], 'float64') + variances = numpy.array([[1, 2, 1], [2, 1, 2]], 'float64') + ubm = GMMMachine(2,3) + ubm.weights = weights + ubm.means = means + ubm.variances = variances + + # Creates a ISVBaseMachine + U = numpy.array([[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12]], 'float64') + #V = numpy.array([[0], [0], [0], [0], [0], [0]], 'float64') + d = numpy.array([0, 1, 0, 1, 0, 1], 'float64') + base = ISVBase(ubm,2) + base.u = U + base.d = d + + # Creates a ISVMachine + z = numpy.array([3,4,1,2,0,1], 'float64') + x = numpy.array([1,2], 'float64') + isv_machine = ISVMachine(base) + isv_machine.z = z + isv_machine.x = x + + isv_machine_after_pickle = pickle.loads(pickle.dumps(isv_machine)) + assert numpy.allclose(isv_machine_after_pickle.isv_base.u, isv_machine.isv_base.u, 10e-3) + assert numpy.allclose(isv_machine_after_pickle.isv_base.d, isv_machine.isv_base.d, 10e-3) + assert numpy.allclose(isv_machine_after_pickle.x, isv_machine.x, 10e-3) + assert numpy.allclose(isv_machine_after_pickle.z, isv_machine.z, 10e-3)