Pickling ISV Base

parent 322070d8
Pipeline #39197 failed with stage
in 19 minutes and 32 seconds
......@@ -6,10 +6,12 @@ import bob.sp
# import our own Library
import bob.extension
bob.extension.load_bob_library('bob.learn.em', __file__)
bob.extension.load_bob_library("bob.learn.em", __file__)
from ._library import *
from ._library import GMMMachine as _GMMMachine_C
from ._library import ISVBase as _ISVBase_C
from . import version
from .version import module as __version__
......@@ -18,44 +20,95 @@ from .train import *
def ztnorm_same_value(vect_a, vect_b):
"""Computes the matrix of boolean D for the ZT-norm, which indicates where
"""Computes the matrix of boolean D for the ZT-norm, which indicates where
the client ids of the T-Norm models and Z-Norm samples match.
vect_a An (ordered) list of client_id corresponding to the T-Norm models
vect_b An (ordered) list of client_id corresponding to the Z-Norm impostor samples
"""
import numpy
sameMatrix = numpy.ndarray((len(vect_a), len(vect_b)), 'bool')
for j in range(len(vect_a)):
for i in range(len(vect_b)):
sameMatrix[j, i] = (vect_a[j] == vect_b[i])
return sameMatrix
"""
import numpy
sameMatrix = numpy.ndarray((len(vect_a), len(vect_b)), "bool")
for j in range(len(vect_a)):
for i in range(len(vect_b)):
sameMatrix[j, i] = vect_a[j] == vect_b[i]
return sameMatrix
def get_config():
"""Returns a string containing the configuration information.
"""
return bob.extension.get_config(__name__, version.externals, version.api)
"""Returns a string containing the configuration information.
"""
return bob.extension.get_config(__name__, version.externals, version.api)
# gets sphinx autodoc done right - don't remove it
__all__ = [_ for _ in dir() if not _.startswith('_')]
__all__ = [_ for _ in dir() if not _.startswith("_")]
class GMMMachine(_GMMMachine_C):
__doc__ = _GMMMachine_C.__doc__
def __getstate__(self):
def update_dict(self, d):
self.means = d["means"]
self.variances = d["variances"]
self.means = d["means"]
@staticmethod
def gmm_shape_from_dict(d):
return d["means"].shape
@classmethod
def create_from_dict(cls, d):
shape = GMMMachine.gmm_shape_from_dict(d)
gmm_machine = cls(shape[0], shape[1])
gmm_machine.update_dict(d)
return gmm_machine
@staticmethod
def to_dict(gmm_machine):
gmm_data = dict()
gmm_data["means"] = gmm_machine.means
gmm_data["variances"] = gmm_machine.variances
gmm_data["weights"] = gmm_machine.weights
return gmm_data
def __getstate__(self):
d = dict(self.__dict__)
d["means"] = self.means
d["variances"] = self.variances
d["weights"] = self.weights
d.update(self.__class__.to_dict(self))
return d
def __setstate__(self, d):
means = d["means"]
self.__init__(means.shape[0], means.shape[1])
self.means = means
self.variances = d["variances"]
self.means = d["means"]
self.__dict__ = d
shape = self.gmm_shape_from_dict(d)
self.__init__(shape[0], shape[1])
self.update_dict(d)
class ISVBase(_ISVBase_C):
__doc__ = _ISVBase_C.__doc__
@staticmethod
def to_dict(isv_base):
isv_data = dict()
isv_data["gmm"] = GMMMachine.to_dict(isv_base.ubm)
isv_data["u"] = isv_base.u
isv_data["d"] = isv_base.d
return isv_data
def update_dict(self, d):
ubm = GMMMachine.create_from_dict(d["gmm"])
u = d["u"]
self.__init__(ubm, u.shape[1])
self.u = u
self.d = d["d"]
def __getstate__(self):
d = dict(self.__dict__)
d.update(self.__class__.to_dict(self))
return d
def __setstate__(self, d):
self.__dict__ = d
self.update_dict(d)
......@@ -539,7 +539,7 @@ bool init_BobLearnEMISVBase(PyObject* module)
// initialize the type struct
PyBobLearnEMISVBase_Type.tp_name = ISVBase_doc.name();
PyBobLearnEMISVBase_Type.tp_basicsize = sizeof(PyBobLearnEMISVBaseObject);
PyBobLearnEMISVBase_Type.tp_flags = Py_TPFLAGS_DEFAULT;
PyBobLearnEMISVBase_Type.tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE;
PyBobLearnEMISVBase_Type.tp_doc = ISVBase_doc.doc();
// set the functions
......
......@@ -648,7 +648,7 @@ bool init_BobLearnEMISVMachine(PyObject* module)
// initialize the type struct
PyBobLearnEMISVMachine_Type.tp_name = ISVMachine_doc.name();
PyBobLearnEMISVMachine_Type.tp_basicsize = sizeof(PyBobLearnEMISVMachineObject);
PyBobLearnEMISVMachine_Type.tp_flags = Py_TPFLAGS_DEFAULT;
PyBobLearnEMISVMachine_Type.tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE;
PyBobLearnEMISVMachine_Type.tp_doc = ISVMachine_doc.doc();
// set the functions
......
#!/usr/bin/env python
# vim: set fileencoding=utf-8 :
# Tiago de Freitas Pereira <tiago.pereira@idiap.ch>import numpy
# Tiago de Freitas Pereira <tiago.pereira@idiap.ch>
from bob.pipelines.utils import assert_picklable
from bob.learn.em import GMMMachine
from bob.learn.em import GMMMachine, ISVBase
from .test_em import equals
import numpy
import pickle
......@@ -17,3 +17,16 @@ def test_gmm_machine():
assert equals(gmm_machine_after_pickle.means, gmm_machine_after_pickle.means, 10e-3)
assert equals(gmm_machine_after_pickle.variances, gmm_machine_after_pickle.variances, 10e-3)
assert equals(gmm_machine_after_pickle.weights, gmm_machine_after_pickle.weights, 10e-3)
def test_isv():
ubm = GMMMachine(3,3)
ubm.means = numpy.arange(9).reshape(3,3).astype("float")
isv_base = ISVBase(ubm, 2)
isv_base.u = numpy.arange(18).reshape(9,2).astype("float")
isv_base.d = numpy.arange(9).astype("float")
isv_base_after_pickle = pickle.loads(pickle.dumps(isv_base))
assert equals(isv_base.u, isv_base_after_pickle.u, 10e-3)
assert equals(isv_base.d, isv_base_after_pickle.d, 10e-3)
\ No newline at end of file
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment