Serialized ISVMachine

parent 7a1dfc76
Pipeline #44911 failed with stage
in 2 minutes and 49 seconds
......@@ -12,6 +12,7 @@ bob.extension.load_bob_library("bob.learn.em", __file__)
from ._library import *
from ._library import GMMMachine as _GMMMachine_C
from ._library import ISVBase as _ISVBase_C
from ._library import ISVMachine as _ISVMachine_C
from . import version
from .version import module as __version__
......@@ -90,7 +91,7 @@ class ISVBase(_ISVBase_C):
@staticmethod
def to_dict(isv_base):
isv_data = dict()
isv_data["gmm"] = GMMMachine.to_dict(isv_base.ubm)
isv_data["gmm"] = GMMMachine.to_dict(isv_base.ubm)
isv_data["u"] = isv_base.u
isv_data["d"] = isv_base.d
......@@ -103,12 +104,48 @@ class ISVBase(_ISVBase_C):
self.u = u
self.d = d["d"]
@classmethod
def create_from_dict(cls, d):
ubm = GMMMachine.create_from_dict(d["gmm"])
ru = d["u"].shape[1]
isv_base = ISVBase(ubm, ru)
isv_base.u = d["u"]
isv_base.d = d["d"]
return isv_base
def __getstate__(self):
d = dict(self.__dict__)
d.update(self.__class__.to_dict(self))
return d
def __setstate__(self, d):
def __setstate__(self, d):
self.__dict__ = d
self.update_dict(d)
class ISVMachine(_ISVMachine_C):
__doc__ = _ISVMachine_C.__doc__
@staticmethod
def to_dict(isv_machine):
isv_data = dict()
isv_data["x"] = isv_machine.x
isv_data["z"] = isv_machine.z
isv_data["isv_base"] = ISVBase.to_dict(isv_machine.isv_base)
return isv_data
def update_dict(self, d):
isv_base = ISVBase.create_from_dict(d["isv_base"])
self.__init__(isv_base)
self.x = d["x"]
self.z = d["z"]
def __getstate__(self):
d = dict(self.__dict__)
d.update(self.__class__.to_dict(self))
return d
def __setstate__(self, d):
self.__dict__ = d
self.update_dict(d)
......@@ -108,6 +108,19 @@ void bob::learn::em::ISVMachine::setZ(const blitz::Array<double,1>& z)
updateCache();
}
void bob::learn::em::ISVMachine::setX(const blitz::Array<double,1>& x)
{
if(x.extent(0) != m_cache_x.extent(0)) { //checks dimension
boost::format m("size of input vector `x' (%d) does not match the expected size (%d)");
m % x.extent(0) % m_cache_x.extent(0);
throw std::runtime_error(m.str());
}
m_cache_x.reference(bob::core::array::ccopy(x));
// update cache
updateCache();
}
void bob::learn::em::ISVMachine::setISVBase(const boost::shared_ptr<bob::learn::em::ISVBase> isv_base)
{
if (!isv_base->getUbm())
......
......@@ -147,6 +147,13 @@ class ISVMachine
*/
void setZ(const blitz::Array<double,1>& z);
/**
* @brief Sets the session variable
*/
void setX(const blitz::Array<double,1>& x);
/**
* @brief Returns the ISVBase
*/
......
......@@ -240,6 +240,37 @@ PyObject* PyBobLearnEMISVMachine_getX(PyBobLearnEMISVMachineObject* self, void*)
return PyBlitzArrayCxx_AsConstNumpy(self->cxx->getX());
BOB_CATCH_MEMBER("`x` could not be read", 0)
}
int PyBobLearnEMISVMachine_setX(PyBobLearnEMISVMachineObject* self, PyObject* value, void*){
BOB_TRY
PyBlitzArrayObject* input;
if (!PyBlitzArray_Converter(value, &input)){
PyErr_Format(PyExc_RuntimeError, "%s %s expects a 1D array of floats", Py_TYPE(self)->tp_name, X.name());
return -1;
}
auto o_ = make_safe(input);
// perform check on the input
if (input->type_num != NPY_FLOAT64){
PyErr_Format(PyExc_TypeError, "`%s' only supports 64-bit float arrays for input array `%s`", Py_TYPE(self)->tp_name, X.name());
return -1;
}
if (input->ndim != 1){
PyErr_Format(PyExc_TypeError, "`%s' only processes 1D arrays of float64 for `%s`", Py_TYPE(self)->tp_name, X.name());
return -1;
}
if (input->shape[0] != (Py_ssize_t)self->cxx->getX().extent(0)) {
PyErr_Format(PyExc_TypeError, "`%s' 1D `input` array should have %" PY_FORMAT_SIZE_T "d, elements, not %" PY_FORMAT_SIZE_T "d for `%s`", Py_TYPE(self)->tp_name, (Py_ssize_t)self->cxx->getX().extent(0), (Py_ssize_t)input->shape[0], X.name());
return -1;
}
auto b = PyBlitzArrayCxx_AsBlitz<double,1>(input, "x");
if (!b) return -1;
self->cxx->setX(*b);
return 0;
BOB_CATCH_MEMBER("`x` vector could not be set", -1)
}
/***** isv_base *****/
......@@ -318,7 +349,7 @@ static PyGetSetDef PyBobLearnEMISVMachine_getseters[] = {
{
X.name(),
(getter)PyBobLearnEMISVMachine_getX,
0,
(setter)PyBobLearnEMISVMachine_setX,
X.doc(),
0
},
......
......@@ -2,7 +2,7 @@
# vim: set fileencoding=utf-8 :
# Tiago de Freitas Pereira <tiago.pereira@idiap.ch>
from bob.learn.em import GMMMachine, ISVBase
from bob.learn.em import GMMMachine, ISVBase, ISVMachine
import numpy
import pickle
......@@ -17,7 +17,7 @@ def test_gmm_machine():
assert numpy.allclose(gmm_machine_after_pickle.weights, gmm_machine_after_pickle.weights, 10e-3)
def test_isv():
def test_isv_base():
ubm = GMMMachine(3,3)
ubm.means = numpy.arange(9).reshape(3,3).astype("float")
isv_base = ISVBase(ubm, 2)
......@@ -27,4 +27,37 @@ def test_isv():
isv_base_after_pickle = pickle.loads(pickle.dumps(isv_base))
assert numpy.allclose(isv_base.u, isv_base_after_pickle.u, 10e-3)
assert numpy.allclose(isv_base.d, isv_base_after_pickle.d, 10e-3)
\ No newline at end of file
assert numpy.allclose(isv_base.d, isv_base_after_pickle.d, 10e-3)
def test_isv_machine():
# Creates a UBM
weights = numpy.array([0.4, 0.6], 'float64')
means = numpy.array([[1, 6, 2], [4, 3, 2]], 'float64')
variances = numpy.array([[1, 2, 1], [2, 1, 2]], 'float64')
ubm = GMMMachine(2,3)
ubm.weights = weights
ubm.means = means
ubm.variances = variances
# Creates a ISVBaseMachine
U = numpy.array([[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12]], 'float64')
#V = numpy.array([[0], [0], [0], [0], [0], [0]], 'float64')
d = numpy.array([0, 1, 0, 1, 0, 1], 'float64')
base = ISVBase(ubm,2)
base.u = U
base.d = d
# Creates a ISVMachine
z = numpy.array([3,4,1,2,0,1], 'float64')
x = numpy.array([1,2], 'float64')
isv_machine = ISVMachine(base)
isv_machine.z = z
isv_machine.x = x
isv_machine_after_pickle = pickle.loads(pickle.dumps(isv_machine))
assert numpy.allclose(isv_machine_after_pickle.isv_base.u, isv_machine.isv_base.u, 10e-3)
assert numpy.allclose(isv_machine_after_pickle.isv_base.d, isv_machine.isv_base.d, 10e-3)
assert numpy.allclose(isv_machine_after_pickle.x, isv_machine.x, 10e-3)
assert numpy.allclose(isv_machine_after_pickle.z, isv_machine.z, 10e-3)
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment