Pickle GMMStats

parent fb8d501d
Pipeline #44956 failed with stage
in 36 minutes and 50 seconds
......@@ -14,6 +14,7 @@ from ._library import GMMMachine as _GMMMachine_C
from ._library import ISVBase as _ISVBase_C
from ._library import ISVMachine as _ISVMachine_C
from ._library import KMeansMachine as _KMeansMachine_C
from ._library import GMMStats as _GMMStats_C
from . import version
from .version import module as __version__
......@@ -153,3 +154,31 @@ class KMeansMachine(_KMeansMachine_C):
means = d["means"]
self.__init__(means.shape[0], means.shape[1])
self.means = means
class GMMStats(_GMMStats_C):
__doc__ = _GMMStats_C.__doc__
@staticmethod
def to_dict(gmm_stats):
gmm_stats_data = dict()
gmm_stats_data["log_likelihood"] = gmm_stats.log_likelihood
gmm_stats_data["t"] = gmm_stats.t
gmm_stats_data["n"] = gmm_stats.n
gmm_stats_data["sum_px"] = gmm_stats.sum_px
gmm_stats_data["sum_pxx"] = gmm_stats.sum_pxx
return gmm_stats_data
def __getstate__(self):
d = dict(self.__dict__)
d.update(self.__class__.to_dict(self))
return d
def __setstate__(self, d):
shape = d["sum_pxx"].shape
self.__init__(shape[0], shape[1])
self.t = d["t"]
self.n = d["n"]
self.log_likelihood = d["log_likelihood"]
self.sum_px = d["sum_px"]
self.sum_pxx = d["sum_pxx"]
......@@ -648,7 +648,7 @@ bool init_BobLearnEMGMMStats(PyObject* module)
// initialize the type struct
PyBobLearnEMGMMStats_Type.tp_name = GMMStats_doc.name();
PyBobLearnEMGMMStats_Type.tp_basicsize = sizeof(PyBobLearnEMGMMStatsObject);
PyBobLearnEMGMMStats_Type.tp_flags = Py_TPFLAGS_DEFAULT;
PyBobLearnEMGMMStats_Type.tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE;
PyBobLearnEMGMMStats_Type.tp_doc = GMMStats_doc.doc();
// set the functions
......
......@@ -2,7 +2,7 @@
# vim: set fileencoding=utf-8 :
# Tiago de Freitas Pereira <tiago.pereira@idiap.ch>
from bob.learn.em import GMMMachine, ISVBase, ISVMachine, KMeansMachine
from bob.learn.em import GMMMachine, ISVBase, ISVMachine, KMeansMachine, GMMStats
import numpy
import pickle
......@@ -75,3 +75,21 @@ def test_kmeans_machine():
kmeans_machine_after_pickle = pickle.loads(pickle.dumps(kmeans_machine))
assert numpy.allclose(kmeans_machine_after_pickle.means, kmeans_machine.means, 10e-3)
def test_gmmstats():
gs = GMMStats(2,3)
log_likelihood = -3.
T = 1
n = numpy.array([0.4, 0.6], numpy.float64)
sumpx = numpy.array([[1., 2., 3.], [2., 4., 3.]], numpy.float64)
sumpxx = numpy.array([[10., 20., 30.], [40., 50., 60.]], numpy.float64)
gs.log_likelihood = log_likelihood
gs.t = T
gs.n = n
gs.sum_px = sumpx
gs.sum_pxx = sumpxx
gs_after_pickle = pickle.loads(pickle.dumps(gs))
assert gs == gs_after_pickle
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment