Skip to content
Snippets Groups Projects
Commit cc846902 authored by Tiago de Freitas Pereira's avatar Tiago de Freitas Pereira
Browse files

Pickle iVectorMachine

parent d6a9cbc2
No related branches found
No related tags found
1 merge request!39Pickling Objects
Pipeline #44961 failed
...@@ -15,12 +15,14 @@ from ._library import ISVBase as _ISVBase_C ...@@ -15,12 +15,14 @@ from ._library import ISVBase as _ISVBase_C
from ._library import ISVMachine as _ISVMachine_C from ._library import ISVMachine as _ISVMachine_C
from ._library import KMeansMachine as _KMeansMachine_C from ._library import KMeansMachine as _KMeansMachine_C
from ._library import GMMStats as _GMMStats_C from ._library import GMMStats as _GMMStats_C
from ._library import IVectorMachine as _IVectorMachine_C
from . import version from . import version
from .version import module as __version__ from .version import module as __version__
from .version import api as __api_version__ from .version import api as __api_version__
from .train import * from .train import *
def get_config(): def get_config():
"""Returns a string containing the configuration information. """Returns a string containing the configuration information.
""" """
...@@ -37,7 +39,7 @@ class GMMMachine(_GMMMachine_C): ...@@ -37,7 +39,7 @@ class GMMMachine(_GMMMachine_C):
def update_dict(self, d): def update_dict(self, d):
self.means = d["means"] self.means = d["means"]
self.variances = d["variances"] self.variances = d["variances"]
self.means = d["means"] self.weights = d["weights"]
@staticmethod @staticmethod
def gmm_shape_from_dict(d): def gmm_shape_from_dict(d):
...@@ -182,3 +184,32 @@ class GMMStats(_GMMStats_C): ...@@ -182,3 +184,32 @@ class GMMStats(_GMMStats_C):
self.log_likelihood = d["log_likelihood"] self.log_likelihood = d["log_likelihood"]
self.sum_px = d["sum_px"] self.sum_px = d["sum_px"]
self.sum_pxx = d["sum_pxx"] self.sum_pxx = d["sum_pxx"]
class IVectorMachine(_IVectorMachine_C):
__doc__ = _IVectorMachine_C.__doc__
@staticmethod
def to_dict(ivector_machine):
ivector_data = dict()
ivector_data["gmm"] = GMMMachine.to_dict(ivector_machine.ubm)
ivector_data["sigma"] = ivector_machine.sigma
ivector_data["t"] = ivector_machine.t
return ivector_data
def update_dict(self, d):
ubm = GMMMachine.create_from_dict(d["gmm"])
t = d["t"]
self.__init__(ubm, t.shape[1])
self.sigma = d["sigma"]
self.t = t
def __getstate__(self):
d = dict(self.__dict__)
d.update(self.__class__.to_dict(self))
return d
def __setstate__(self, d):
self.__dict__ = d
self.update_dict(d)
...@@ -651,7 +651,7 @@ bool init_BobLearnEMIVectorMachine(PyObject* module) ...@@ -651,7 +651,7 @@ bool init_BobLearnEMIVectorMachine(PyObject* module)
// initialize the type struct // initialize the type struct
PyBobLearnEMIVectorMachine_Type.tp_name = IVectorMachine_doc.name(); PyBobLearnEMIVectorMachine_Type.tp_name = IVectorMachine_doc.name();
PyBobLearnEMIVectorMachine_Type.tp_basicsize = sizeof(PyBobLearnEMIVectorMachineObject); PyBobLearnEMIVectorMachine_Type.tp_basicsize = sizeof(PyBobLearnEMIVectorMachineObject);
PyBobLearnEMIVectorMachine_Type.tp_flags = Py_TPFLAGS_DEFAULT; PyBobLearnEMIVectorMachine_Type.tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE;
PyBobLearnEMIVectorMachine_Type.tp_doc = IVectorMachine_doc.doc(); PyBobLearnEMIVectorMachine_Type.tp_doc = IVectorMachine_doc.doc();
// set the functions // set the functions
......
...@@ -2,26 +2,39 @@ ...@@ -2,26 +2,39 @@
# vim: set fileencoding=utf-8 : # vim: set fileencoding=utf-8 :
# Tiago de Freitas Pereira <tiago.pereira@idiap.ch> # Tiago de Freitas Pereira <tiago.pereira@idiap.ch>
from bob.learn.em import GMMMachine, ISVBase, ISVMachine, KMeansMachine, GMMStats from bob.learn.em import (
GMMMachine,
ISVBase,
ISVMachine,
KMeansMachine,
GMMStats,
IVectorMachine,
)
import numpy import numpy
import pickle import pickle
def test_gmm_machine(): def test_gmm_machine():
gmm_machine = GMMMachine(3,3) gmm_machine = GMMMachine(3, 3)
gmm_machine.means = numpy.arange(9).reshape(3,3).astype("float") gmm_machine.means = numpy.arange(9).reshape(3, 3).astype("float")
gmm_machine_after_pickle = pickle.loads(pickle.dumps(gmm_machine)) gmm_machine_after_pickle = pickle.loads(pickle.dumps(gmm_machine))
assert numpy.allclose(gmm_machine_after_pickle.means, gmm_machine_after_pickle.means, 10e-3) assert numpy.allclose(
assert numpy.allclose(gmm_machine_after_pickle.variances, gmm_machine_after_pickle.variances, 10e-3) gmm_machine_after_pickle.means, gmm_machine_after_pickle.means, 10e-3
assert numpy.allclose(gmm_machine_after_pickle.weights, gmm_machine_after_pickle.weights, 10e-3) )
assert numpy.allclose(
gmm_machine_after_pickle.variances, gmm_machine_after_pickle.variances, 10e-3
)
assert numpy.allclose(
gmm_machine_after_pickle.weights, gmm_machine_after_pickle.weights, 10e-3
)
def test_isv_base(): def test_isv_base():
ubm = GMMMachine(3,3) ubm = GMMMachine(3, 3)
ubm.means = numpy.arange(9).reshape(3,3).astype("float") ubm.means = numpy.arange(9).reshape(3, 3).astype("float")
isv_base = ISVBase(ubm, 2) isv_base = ISVBase(ubm, 2)
isv_base.u = numpy.arange(18).reshape(9,2).astype("float") isv_base.u = numpy.arange(18).reshape(9, 2).astype("float")
isv_base.d = numpy.arange(9).astype("float") isv_base.d = numpy.arange(9).astype("float")
isv_base_after_pickle = pickle.loads(pickle.dumps(isv_base)) isv_base_after_pickle = pickle.loads(pickle.dumps(isv_base))
...@@ -33,32 +46,36 @@ def test_isv_base(): ...@@ -33,32 +46,36 @@ def test_isv_base():
def test_isv_machine(): def test_isv_machine():
# Creates a UBM # Creates a UBM
weights = numpy.array([0.4, 0.6], 'float64') weights = numpy.array([0.4, 0.6], "float64")
means = numpy.array([[1, 6, 2], [4, 3, 2]], 'float64') means = numpy.array([[1, 6, 2], [4, 3, 2]], "float64")
variances = numpy.array([[1, 2, 1], [2, 1, 2]], 'float64') variances = numpy.array([[1, 2, 1], [2, 1, 2]], "float64")
ubm = GMMMachine(2,3) ubm = GMMMachine(2, 3)
ubm.weights = weights ubm.weights = weights
ubm.means = means ubm.means = means
ubm.variances = variances ubm.variances = variances
# Creates a ISVBaseMachine # Creates a ISVBaseMachine
U = numpy.array([[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12]], 'float64') U = numpy.array([[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12]], "float64")
#V = numpy.array([[0], [0], [0], [0], [0], [0]], 'float64') # V = numpy.array([[0], [0], [0], [0], [0], [0]], 'float64')
d = numpy.array([0, 1, 0, 1, 0, 1], 'float64') d = numpy.array([0, 1, 0, 1, 0, 1], "float64")
base = ISVBase(ubm,2) base = ISVBase(ubm, 2)
base.u = U base.u = U
base.d = d base.d = d
# Creates a ISVMachine # Creates a ISVMachine
z = numpy.array([3,4,1,2,0,1], 'float64') z = numpy.array([3, 4, 1, 2, 0, 1], "float64")
x = numpy.array([1,2], 'float64') x = numpy.array([1, 2], "float64")
isv_machine = ISVMachine(base) isv_machine = ISVMachine(base)
isv_machine.z = z isv_machine.z = z
isv_machine.x = x isv_machine.x = x
isv_machine_after_pickle = pickle.loads(pickle.dumps(isv_machine)) isv_machine_after_pickle = pickle.loads(pickle.dumps(isv_machine))
assert numpy.allclose(isv_machine_after_pickle.isv_base.u, isv_machine.isv_base.u, 10e-3) assert numpy.allclose(
assert numpy.allclose(isv_machine_after_pickle.isv_base.d, isv_machine.isv_base.d, 10e-3) isv_machine_after_pickle.isv_base.u, isv_machine.isv_base.u, 10e-3
)
assert numpy.allclose(
isv_machine_after_pickle.isv_base.d, isv_machine.isv_base.d, 10e-3
)
assert numpy.allclose(isv_machine_after_pickle.x, isv_machine.x, 10e-3) assert numpy.allclose(isv_machine_after_pickle.x, isv_machine.x, 10e-3)
assert numpy.allclose(isv_machine_after_pickle.z, isv_machine.z, 10e-3) assert numpy.allclose(isv_machine_after_pickle.z, isv_machine.z, 10e-3)
...@@ -66,30 +83,60 @@ def test_isv_machine(): ...@@ -66,30 +83,60 @@ def test_isv_machine():
def test_kmeans_machine(): def test_kmeans_machine():
# Test a KMeansMachine # Test a KMeansMachine
means = numpy.array([[3, 70, 0], [4, 72, 0]], 'float64') means = numpy.array([[3, 70, 0], [4, 72, 0]], "float64")
mean = numpy.array([3,70,1], 'float64') mean = numpy.array([3, 70, 1], "float64")
# Initializes a KMeansMachine # Initializes a KMeansMachine
kmeans_machine = KMeansMachine(2,3) kmeans_machine = KMeansMachine(2, 3)
kmeans_machine.means = means kmeans_machine.means = means
kmeans_machine_after_pickle = pickle.loads(pickle.dumps(kmeans_machine)) kmeans_machine_after_pickle = pickle.loads(pickle.dumps(kmeans_machine))
assert numpy.allclose(kmeans_machine_after_pickle.means, kmeans_machine.means, 10e-3) assert numpy.allclose(
kmeans_machine_after_pickle.means, kmeans_machine.means, 10e-3
)
def test_gmmstats(): def test_gmmstats():
gs = GMMStats(2,3) gs = GMMStats(2, 3)
log_likelihood = -3. log_likelihood = -3.0
T = 1 T = 1
n = numpy.array([0.4, 0.6], numpy.float64) n = numpy.array([0.4, 0.6], numpy.float64)
sumpx = numpy.array([[1., 2., 3.], [2., 4., 3.]], numpy.float64) sumpx = numpy.array([[1.0, 2.0, 3.0], [2.0, 4.0, 3.0]], numpy.float64)
sumpxx = numpy.array([[10., 20., 30.], [40., 50., 60.]], numpy.float64) sumpxx = numpy.array([[10.0, 20.0, 30.0], [40.0, 50.0, 60.0]], numpy.float64)
gs.log_likelihood = log_likelihood gs.log_likelihood = log_likelihood
gs.t = T gs.t = T
gs.n = n gs.n = n
gs.sum_px = sumpx gs.sum_px = sumpx
gs.sum_pxx = sumpxx gs.sum_pxx = sumpxx
gs_after_pickle = pickle.loads(pickle.dumps(gs)) gs_after_pickle = pickle.loads(pickle.dumps(gs))
assert gs == gs_after_pickle assert gs == gs_after_pickle
def test_ivector_machine():
# Ubm
ubm = GMMMachine(2, 3)
ubm.weights = numpy.array([0.4, 0.6])
ubm.means = numpy.array([[1.0, 7, 4], [4, 5, 3]])
ubm.variances = numpy.array([[0.5, 1.0, 1.5], [1.0, 1.5, 2.0]])
ivector_machine = IVectorMachine(ubm, 2)
t = numpy.array([[1.0, 2], [4, 1], [0, 3], [5, 8], [7, 10], [11, 1]])
sigma = numpy.array([1.0, 2.0, 1.0, 3.0, 2.0, 4.0])
ivector_machine.t = t
ivector_machine.sigma = sigma
ivector_after_pickle = pickle.loads(pickle.dumps(ivector_machine))
assert numpy.allclose(ivector_after_pickle.sigma, ivector_machine.sigma, 10e-3)
assert numpy.allclose(ivector_after_pickle.t, ivector_machine.t, 10e-3)
assert numpy.allclose(
ivector_after_pickle.ubm.means, ivector_machine.ubm.means, 10e-3
)
assert numpy.allclose(
ivector_after_pickle.ubm.variances, ivector_machine.ubm.variances, 10e-3
)
assert numpy.allclose(
ivector_after_pickle.ubm.weights, ivector_machine.ubm.weights, 10e-3
)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment