From 16d4d493f3b1becace723c6889998005c3a8e9ef Mon Sep 17 00:00:00 2001
From: Andre Anjos <andre.dos.anjos@gmail.com>
Date: Tue, 13 May 2014 09:26:54 +0200
Subject: [PATCH] Finish full bindings for MLPs

---
 xbob/learn/mlp/rprop.cpp | 100 +++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 100 insertions(+)

diff --git a/xbob/learn/mlp/rprop.cpp b/xbob/learn/mlp/rprop.cpp
index 939b919..43f58d5 100644
--- a/xbob/learn/mlp/rprop.cpp
+++ b/xbob/learn/mlp/rprop.cpp
@@ -730,6 +730,94 @@ static PyObject* PyBobLearnMLPRProp_setPreviousBiasDerivativeOnLayer
 
 }
 
+PyDoc_STRVAR(s_set_delta_str, "set_delta");
+PyDoc_STRVAR(s_set_delta_doc,
+    "Sets the delta for a given weight layer.");
+
+/**
+ * Binding for RProp.set_delta(array, layer): copies a 2D float64 array into
+ * the delta matrix of the given weight layer.  Returns None, or 0 with a
+ * Python exception set on error.
+ */
+static PyObject* PyBobLearnMLPRProp_setDeltaOnLayer
+(PyBobLearnMLPRPropObject* self, PyObject* args, PyObject* kwds) {
+
+  /* Parses input arguments in a single shot */
+  static const char* const_kwlist[] = {"array", "layer", 0};
+  static char** kwlist = const_cast<char**>(const_kwlist);
+
+  PyBlitzArrayObject* array = 0;
+  Py_ssize_t layer = 0;
+
+  if (!PyArg_ParseTupleAndKeywords(args, kwds, "O&n", kwlist,
+        &PyBlitzArray_Converter, &array, &layer)) return 0;
+
+  /* PyBlitzArray_Converter returned a new reference on `array': it must be
+   * released on every exit path below, or each call leaks the array */
+  if (array->type_num != NPY_FLOAT64 || array->ndim != 2) {
+    PyErr_Format(PyExc_TypeError, "`%s.%s' only supports 2D 64-bit float arrays for argument `array' (or any other object coercible to that), but you provided an object with %" PY_FORMAT_SIZE_T "d dimensions and with type `%s' which is not compatible - check your input", Py_TYPE(self)->tp_name, s_set_delta_str, array->ndim, PyBlitzArray_TypenumAsString(array->type_num));
+    Py_DECREF(array);
+    return 0;
+  }
+
+  try {
+    self->cxx->setDelta(*PyBlitzArrayCxx_AsBlitz<double,2>(array), layer);
+  }
+  catch (std::exception& ex) {
+    PyErr_SetString(PyExc_RuntimeError, ex.what());
+    Py_DECREF(array);
+    return 0;
+  }
+  catch (...) {
+    PyErr_Format(PyExc_RuntimeError, "cannot set delta at layer %" PY_FORMAT_SIZE_T "d for `%s': unknown exception caught", layer, Py_TYPE(self)->tp_name);
+    Py_DECREF(array);
+    return 0;
+  }
+
+  Py_DECREF(array);
+  Py_RETURN_NONE;
+
+}
+
+PyDoc_STRVAR(s_set_bias_delta_str, "set_bias_delta");
+PyDoc_STRVAR(s_set_bias_delta_doc,
+    "Sets the bias delta for a given bias layer.");
+
+/**
+ * Binding for RProp.set_bias_delta(array, layer): copies a 1D float64 array
+ * into the bias delta vector of the given bias layer.  Returns None, or 0
+ * with a Python exception set on error.
+ */
+static PyObject* PyBobLearnMLPRProp_setBiasDeltaOnLayer
+(PyBobLearnMLPRPropObject* self, PyObject* args, PyObject* kwds) {
+
+  /* Parses input arguments in a single shot */
+  static const char* const_kwlist[] = {"array", "layer", 0};
+  static char** kwlist = const_cast<char**>(const_kwlist);
+
+  PyBlitzArrayObject* array = 0;
+  Py_ssize_t layer = 0;
+
+  if (!PyArg_ParseTupleAndKeywords(args, kwds, "O&n", kwlist,
+        &PyBlitzArray_Converter, &array, &layer)) return 0;
+
+  /* PyBlitzArray_Converter returned a new reference on `array': it must be
+   * released on every exit path below, or each call leaks the array */
+  if (array->type_num != NPY_FLOAT64 || array->ndim != 1) {
+    PyErr_Format(PyExc_TypeError, "`%s.%s' only supports 1D 64-bit float arrays for argument `array' (or any other object coercible to that), but you provided an object with %" PY_FORMAT_SIZE_T "d dimensions and with type `%s' which is not compatible - check your input", Py_TYPE(self)->tp_name, s_set_bias_delta_str, array->ndim, PyBlitzArray_TypenumAsString(array->type_num));
+    Py_DECREF(array);
+    return 0;
+  }
+
+  try {
+    self->cxx->setBiasDelta(*PyBlitzArrayCxx_AsBlitz<double,1>(array), layer);
+  }
+  catch (std::exception& ex) {
+    PyErr_SetString(PyExc_RuntimeError, ex.what());
+    Py_DECREF(array);
+    return 0;
+  }
+  catch (...) {
+    PyErr_Format(PyExc_RuntimeError, "cannot set bias delta at layer %" PY_FORMAT_SIZE_T "d for `%s': unknown exception caught", layer, Py_TYPE(self)->tp_name);
+    Py_DECREF(array);
+    return 0;
+  }
+
+  Py_DECREF(array);
+  Py_RETURN_NONE;
+
+}
+
 static PyMethodDef PyBobLearnMLPRProp_methods[] = {
   {
     s_reset_str,
@@ -755,6 +843,18 @@ static PyMethodDef PyBobLearnMLPRProp_methods[] = {
     METH_VARARGS|METH_KEYWORDS,
     s_set_previous_bias_derivative_doc,
   },
+  {
+    s_set_delta_str,
+    (PyCFunction)PyBobLearnMLPRProp_setDeltaOnLayer,
+    METH_VARARGS|METH_KEYWORDS,
+    s_set_delta_doc,
+  },
+  {
+    s_set_bias_delta_str,
+    (PyCFunction)PyBobLearnMLPRProp_setBiasDeltaOnLayer,
+    METH_VARARGS|METH_KEYWORDS,
+    s_set_bias_delta_doc,
+  },
   {0} /* Sentinel */
 };
 
-- 
GitLab