Making GMMMachine picklable

parent 06cbbae4
Pipeline #39191 failed with stage
in 8 minutes
......@@ -9,6 +9,8 @@ import bob.extension
bob.extension.load_bob_library('bob.learn.em', __file__)
from ._library import *
from ._library import GMMMachine as _GMMMachine_C
from . import version
from .version import module as __version__
from .version import api as __api_version__
......@@ -38,3 +40,22 @@ def get_config():
# gets sphinx autodoc done right - don't remove it
__all__ = [_ for _ in dir() if not _.startswith('_')]
class GMMMachine(_GMMMachine_C):
__doc__ = _GMMMachine_C.__doc__
def __getstate__(self):
d = dict(self.__dict__)
d["means"] = self.means
d["variances"] = self.variances
d["weights"] = self.weights
return d
def __setstate__(self, d):
means = d["means"]
self.__init__(means.shape[0], means.shape[1])
self.means = means
self.variances = d["variances"]
self.means = d["means"]
self.__dict__ = d
......@@ -966,7 +966,7 @@ bool init_BobLearnEMGMMMachine(PyObject* module)
// initialize the type struct
PyBobLearnEMGMMMachine_Type.tp_name = GMMMachine_doc.name();
PyBobLearnEMGMMMachine_Type.tp_basicsize = sizeof(PyBobLearnEMGMMMachineObject);
PyBobLearnEMGMMMachine_Type.tp_flags = Py_TPFLAGS_DEFAULT;
PyBobLearnEMGMMMachine_Type.tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE;
PyBobLearnEMGMMMachine_Type.tp_doc = GMMMachine_doc.doc();
// set the functions
......
#!/usr/bin/env python
# vim: set fileencoding=utf-8 :
# Tiago de Freitas Pereira <tiago.pereira@idiap.ch>
#!/usr/bin/env python
# vim: set fileencoding=utf-8 :
# Tiago de Freitas Pereira <tiago.pereira@idiap.ch>import numpy
from bob.pipelines.utils import assert_picklable
from bob.learn.em import GMMMachine
from .test_em import equals
import numpy
import pickle
def test_gmm_machine():
gmm_machine = GMMMachine(3,3)
gmm_machine.means = numpy.arange(9).reshape(3,3).astype("float")
gmm_machine_after_pickle = pickle.loads(pickle.dumps(gmm_machine))
assert equals(gmm_machine_after_pickle.means, gmm_machine_after_pickle.means, 10e-3)
assert equals(gmm_machine_after_pickle.variances, gmm_machine_after_pickle.variances, 10e-3)
assert equals(gmm_machine_after_pickle.weights, gmm_machine_after_pickle.weights, 10e-3)
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment