Pickling KMeans

parent 4cfcee92
Pipeline #44943 failed with stage
in 2 minutes and 24 seconds
......@@ -13,29 +13,13 @@ from ._library import *
from ._library import GMMMachine as _GMMMachine_C
from ._library import ISVBase as _ISVBase_C
from ._library import ISVMachine as _ISVMachine_C
from ._library import KMeansMachine as _KMeansMachine_C
from . import version
from .version import module as __version__
from .version import api as __api_version__
from .train import *
def ztnorm_same_value(vect_a, vect_b):
"""Computes the matrix of boolean D for the ZT-norm, which indicates where
the client ids of the T-Norm models and Z-Norm samples match.
vect_a An (ordered) list of client_id corresponding to the T-Norm models
vect_b An (ordered) list of client_id corresponding to the Z-Norm impostor samples
import numpy
sameMatrix = numpy.ndarray((len(vect_a), len(vect_b)), "bool")
for j in range(len(vect_a)):
for i in range(len(vect_b)):
sameMatrix[j, i] = vect_a[j] == vect_b[i]
return sameMatrix
def get_config():
"""Returns a string containing the configuration information.
......@@ -149,3 +133,23 @@ class ISVMachine(_ISVMachine_C):
def __setstate__(self, d):
self.__dict__ = d
class KMeansMachine(_KMeansMachine_C):
__doc__ = _KMeansMachine_C.__doc__
def to_dict(kmeans_machine):
kmeans_data = dict()
kmeans_data["means"] = kmeans_machine.means
return kmeans_data
def __getstate__(self):
d = dict(self.__dict__)
return d
def __setstate__(self, d):
means = d["means"]
self.__init__(means.shape[0], means.shape[1])
self.means = means
......@@ -843,7 +843,7 @@ bool init_BobLearnEMKMeansMachine(PyObject* module)
// initialize the type struct
PyBobLearnEMKMeansMachine_Type.tp_name = KMeansMachine_doc.name();
PyBobLearnEMKMeansMachine_Type.tp_basicsize = sizeof(PyBobLearnEMKMeansMachineObject);
PyBobLearnEMKMeansMachine_Type.tp_flags = Py_TPFLAGS_DEFAULT;
PyBobLearnEMKMeansMachine_Type.tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE;
PyBobLearnEMKMeansMachine_Type.tp_doc = KMeansMachine_doc.doc();
// set the functions
......@@ -2,7 +2,7 @@
# vim: set fileencoding=utf-8 :
# Tiago de Freitas Pereira <tiago.pereira@idiap.ch>
from bob.learn.em import GMMMachine, ISVBase, ISVMachine
from bob.learn.em import GMMMachine, ISVBase, ISVMachine, KMeansMachine
import numpy
import pickle
......@@ -61,3 +61,17 @@ def test_isv_machine():
assert numpy.allclose(isv_machine_after_pickle.isv_base.d, isv_machine.isv_base.d, 10e-3)
assert numpy.allclose(isv_machine_after_pickle.x, isv_machine.x, 10e-3)
assert numpy.allclose(isv_machine_after_pickle.z, isv_machine.z, 10e-3)
def test_kmeans_machine():
# Test a KMeansMachine
means = numpy.array([[3, 70, 0], [4, 72, 0]], 'float64')
mean = numpy.array([3,70,1], 'float64')
# Initializes a KMeansMachine
kmeans_machine = KMeansMachine(2,3)
kmeans_machine.means = means
kmeans_machine_after_pickle = pickle.loads(pickle.dumps(kmeans_machine))
assert numpy.allclose(kmeans_machine_after_pickle.means, kmeans_machine.means, 10e-3)
Markdown is supported
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment