Made bob.learn.linear.Machine picklable

parent 2738dc2e
Pipeline #39205 failed with stage
in 8 minutes and 30 seconds
......@@ -2,6 +2,7 @@
import bob.io.base
import bob.math
import bob.learn.activation
import numpy
# import our own Library
import bob.extension
......@@ -11,6 +12,7 @@ from ._library import *
from . import version
from .version import module as __version__
from .version import api as __api_version__
from ._library import Machine as _Machine_C
from .auxiliary import *
from .GFK import GFKMachine, GFKTrainer
......@@ -22,3 +24,35 @@ def get_config():
# gets sphinx autodoc done right - don't remove it
__all__ = [_ for _ in dir() if not _.startswith('_')]
class Machine(_Machine_C):
__doc__ = _Machine_C.__doc__
def update_dict(self, d):
self.input_sub = numpy.array([d["input_sub"]])
self.input_div = numpy.array([d["input_div"]])
@classmethod
def create_from_dict(cls, d):
machine = cls(numpy.array(d["weights"]))
machine.update_dict(d)
return machine
@staticmethod
def to_dict(machine):
machine_data = dict()
machine_data["input_sub"] = machine.input_sub
machine_data["input_div"] = machine.input_div
machine_data["weights"] = machine.weights
return machine_data
def __getstate__(self):
d = dict(self.__dict__)
d.update(self.__class__.to_dict(self))
return d
def __setstate__(self, d):
self.__dict__ = d
self.__init__(numpy.array(d["weights"]))
self.update_dict(d)
......@@ -857,7 +857,7 @@ bool init_BobLearnLinearMachine(PyObject* module)
// Linear Machine
PyBobLearnLinearMachine_Type.tp_name = Machine_doc.name();
PyBobLearnLinearMachine_Type.tp_basicsize = sizeof(PyBobLearnLinearMachineObject);
PyBobLearnLinearMachine_Type.tp_flags = Py_TPFLAGS_DEFAULT;
PyBobLearnLinearMachine_Type.tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE;
PyBobLearnLinearMachine_Type.tp_doc = Machine_doc.doc();
// set the functions
......
#!/usr/bin/env python
# vim: set fileencoding=utf-8 :
# Tiago de Freitas Pereira <tiago.pereira@idiap.ch>
from bob.pipelines.utils import assert_picklable
from bob.learn.linear import Machine
import numpy
import pickle
def test_machine():
machine = Machine(10,3)
machine.weights = numpy.arange(30).reshape(10,3).astype("float")
machine.input_div = numpy.arange(3).astype("float")
machine.input_sub = numpy.arange(3).astype("float")
machine_after_pickle = pickle.loads(pickle.dumps(machine))
assert numpy.allclose(machine.weights, machine_after_pickle.weights, 10e-3)
assert numpy.allclose(machine.input_div, machine_after_pickle.input_div, 10e-3)
assert numpy.allclose(machine.input_sub, machine_after_pickle.input_sub, 10e-3)
\ No newline at end of file
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment