Skip to content
Snippets Groups Projects
Commit fb8d501d authored by Tiago de Freitas Pereira's avatar Tiago de Freitas Pereira
Browse files

Pickling KMeans

parent 4cfcee92
No related branches found
No related tags found
1 merge request!39Pickling Objects
Pipeline #44943 failed
......@@ -13,29 +13,13 @@ from ._library import *
from ._library import GMMMachine as _GMMMachine_C
from ._library import ISVBase as _ISVBase_C
from ._library import ISVMachine as _ISVMachine_C
from ._library import KMeansMachine as _KMeansMachine_C
from . import version
from .version import module as __version__
from .version import api as __api_version__
from .train import *
def ztnorm_same_value(vect_a, vect_b):
"""Computes the matrix of boolean D for the ZT-norm, which indicates where
the client ids of the T-Norm models and Z-Norm samples match.
vect_a An (ordered) list of client_id corresponding to the T-Norm models
vect_b An (ordered) list of client_id corresponding to the Z-Norm impostor samples
"""
import numpy
sameMatrix = numpy.ndarray((len(vect_a), len(vect_b)), "bool")
for j in range(len(vect_a)):
for i in range(len(vect_b)):
sameMatrix[j, i] = vect_a[j] == vect_b[i]
return sameMatrix
def get_config():
"""Returns a string containing the configuration information.
"""
......@@ -149,3 +133,23 @@ class ISVMachine(_ISVMachine_C):
def __setstate__(self, d):
self.__dict__ = d
self.update_dict(d)
class KMeansMachine(_KMeansMachine_C):
__doc__ = _KMeansMachine_C.__doc__
@staticmethod
def to_dict(kmeans_machine):
kmeans_data = dict()
kmeans_data["means"] = kmeans_machine.means
return kmeans_data
def __getstate__(self):
d = dict(self.__dict__)
d.update(self.__class__.to_dict(self))
return d
def __setstate__(self, d):
means = d["means"]
self.__init__(means.shape[0], means.shape[1])
self.means = means
......@@ -843,7 +843,7 @@ bool init_BobLearnEMKMeansMachine(PyObject* module)
// initialize the type struct
PyBobLearnEMKMeansMachine_Type.tp_name = KMeansMachine_doc.name();
PyBobLearnEMKMeansMachine_Type.tp_basicsize = sizeof(PyBobLearnEMKMeansMachineObject);
PyBobLearnEMKMeansMachine_Type.tp_flags = Py_TPFLAGS_DEFAULT;
PyBobLearnEMKMeansMachine_Type.tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE;
PyBobLearnEMKMeansMachine_Type.tp_doc = KMeansMachine_doc.doc();
// set the functions
......
......@@ -2,7 +2,7 @@
# vim: set fileencoding=utf-8 :
# Tiago de Freitas Pereira <tiago.pereira@idiap.ch>
from bob.learn.em import GMMMachine, ISVBase, ISVMachine
from bob.learn.em import GMMMachine, ISVBase, ISVMachine, KMeansMachine
import numpy
import pickle
......@@ -61,3 +61,17 @@ def test_isv_machine():
assert numpy.allclose(isv_machine_after_pickle.isv_base.d, isv_machine.isv_base.d, 10e-3)
assert numpy.allclose(isv_machine_after_pickle.x, isv_machine.x, 10e-3)
assert numpy.allclose(isv_machine_after_pickle.z, isv_machine.z, 10e-3)
def test_kmeans_machine():
# Test a KMeansMachine
means = numpy.array([[3, 70, 0], [4, 72, 0]], 'float64')
mean = numpy.array([3,70,1], 'float64')
# Initializes a KMeansMachine
kmeans_machine = KMeansMachine(2,3)
kmeans_machine.means = means
kmeans_machine_after_pickle = pickle.loads(pickle.dumps(kmeans_machine))
assert numpy.allclose(kmeans_machine_after_pickle.means, kmeans_machine.means, 10e-3)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment