diff --git a/xbob/learn/mlp/shuffler.cpp b/xbob/learn/mlp/shuffler.cpp index 766249ba7af06957014bfd4c4773242a3c1b5392..71e6e35d292c32fd663c275b381737e938c456a1 100644 --- a/xbob/learn/mlp/shuffler.cpp +++ b/xbob/learn/mlp/shuffler.cpp @@ -11,6 +11,7 @@ #include <xbob.blitz/cppapi.h> #include <xbob.blitz/cleanup.h> #include <xbob.learn.mlp/api.h> +#include <xbob.core/random.h> #include <structmember.h> /********************************************* @@ -155,6 +156,172 @@ int PyBobLearnDataShuffler_Check(PyObject* o) { return PyObject_IsInstance(o, reinterpret_cast<PyObject*>(&PyBobLearnDataShuffler_Type)); } +PyDoc_STRVAR(s_draw_str, "draw"); +PyDoc_STRVAR(s_draw_doc, +"o.draw([n, [data, [target, [rng]]]]) -> (data, target)\n\ +\n\ +Draws a random number of data-target pairs from the input data.\n\ +\n\ +This method will draw a given number ``n`` of data-target pairs\n\ +from the input data, randomly. You can specific the destination\n\ +containers ``data`` and ``target`` which, if provided, must be 2D\n\ +arrays of type `float64`` with as many rows as ``n`` and as\n\ +many columns as the data and target widths provided upon\n\ +construction.\n\ +\n\ +If ``n`` is not specified, than that value is taken from the\n\ +number of rows in either ``data`` or ``target``, whichever is\n\ +provided. It is an error not to provide one of ``data``,\n\ +``target`` or ``n``.\n\ +\n\ +If a random generator ``rng`` is provided, it must of the type\n\ +:py:class:`xbob.core.random.mt19937`. In this case, the shuffler\n\ +is going to use this generator instead of its internal one. This\n\ +mechanism is useful for repeating draws in case of tests.\n\ +\n\ +Independently if ``data`` and/or ``target`` is provided, this\n\ +function will always return a tuple containing the ``data`` and\n\ +``target`` arrays with the random data picked from the user\n\ +input. If either ``data`` or ``target`` are not provided by\n\ +the user, then they are created internally and returned.\n\ +"); + +static PyObject* PyBobLearnDataShuffler_Call +(PyBobLearnDataShufflerObject* self, PyObject* args, PyObject* kwds) { + + static const char* const_kwlist[] = {"n", "data", "target", "rng", 0}; + static char** kwlist = const_cast<char**>(const_kwlist); + + Py_ssize_t n = 0; + PyBlitzArrayObject* data = 0; + PyBlitzArrayObject* target = 0; + PyBoostMt19937Object* rng = 0; + + if (!PyArg_ParseTupleAndKeywords(args, kwds, "|nO&O&O!", kwlist, + &n, + &PyBlitzArray_OutputConverter, &data, + &PyBlitzArray_OutputConverter, &target, + &PyBoostMt19937_Type, &rng + )) return 0; + + //protects acquired resources through this scope + auto data_ = make_xsafe(data); + auto target_ = make_xsafe(target); + + //checks data and target first + if (data) { + if (data->ndim != 2 || data->type_num != NPY_FLOAT64) { + PyErr_Format(PyExc_TypeError, "`%s' functor requires you pass a 2D array with 64-bit floats for input `data', but you passed a %" PY_FORMAT_SIZE_T "dD array with `%s' data type", Py_TYPE(self)->tp_name, data->ndim, PyBlitzArray_TypenumAsString(data->type_num)); + return 0; + } + if (data->shape[1] != (Py_ssize_t)self->cxx->getDataWidth()) { + PyErr_Format(PyExc_RuntimeError, "`%s' functor requires you pass a 2D array with %" PY_FORMAT_SIZE_T "d columns for input `data', but you passed an array with %" PY_FORMAT_SIZE_T "d columns instead", Py_TYPE(self)->tp_name, self->cxx->getDataWidth(), data->shape[1]); + return 0; + } + } + + if (target) { + if (target->ndim != 2 || target->type_num != NPY_FLOAT64) { + PyErr_Format(PyExc_TypeError, "`%s' functor requires you pass a 2D array with 64-bit floats for input `target', but you passed a %" PY_FORMAT_SIZE_T "dD array with `%s' target type", Py_TYPE(self)->tp_name, target->ndim, PyBlitzArray_TypenumAsString(target->type_num)); + return 0; + } + if (target->shape[1] != (Py_ssize_t)self->cxx->getTargetWidth()) { + PyErr_Format(PyExc_RuntimeError, "`%s' functor requires you pass a 2D array with %" PY_FORMAT_SIZE_T "d columns for input `target', but you passed an array with %" PY_FORMAT_SIZE_T "d columns instead", Py_TYPE(self)->tp_name, self->cxx->getDataWidth(), target->shape[1]); + return 0; + } + } + + if (data && target) { + //make sure that the number of rows match + if (data->shape[0] != target->shape[0]) { + PyErr_Format(PyExc_RuntimeError, "`%s' functor requires you pass 2D arrays for both `data' and `target' with the same number of rows, but `data' has %" PY_FORMAT_SIZE_T "d rows and `target' has %" PY_FORMAT_SIZE_T "d rows instead", Py_TYPE(self)->tp_name, data->shape[0], target->shape[0]); + return 0; + } + } + + Py_ssize_t array_length = 0; + if (data) array_length = data->shape[0]; + if (target) array_length = target->shape[0]; + + if (n && array_length) { + if (n != array_length) { + PyErr_Format(PyExc_RuntimeError, "`%s' functor requires you pass 2D arrays for both `data' and `target' with the same number of rows. If a value for `n' is passed, it should match the number of rows in both `data' and `target', but `data' and/or `target' have %" PY_FORMAT_SIZE_T "d rows and `n' is set to %" PY_FORMAT_SIZE_T "d instead - tip: you don't need to specific `n' in this case", Py_TYPE(self)->tp_name, array_length, n); + return 0; + } + } + + if (!n && !array_length) { + PyErr_Format(PyExc_RuntimeError, "`%s' functor expects you either pass `n', for the number of samples to return or `data' and/or `target' arrays to be filled, but you passed neither.", Py_TYPE(self)->tp_name); + return 0; + } + + if (!n) n = array_length; + + //allocates data, if not already there + if (!data) { + Py_ssize_t shape[2]; + shape[0] = n; + shape[1] = self->cxx->getDataWidth(); + data = (PyBlitzArrayObject*)PyBlitzArray_SimpleNew(NPY_FLOAT64, 2, shape); + if (!data) return 0; + data_ = make_safe(data); + } + + //allocates target, if not already there + if (!target) { + Py_ssize_t shape[2]; + shape[0] = n; + shape[1] = self->cxx->getDataWidth(); + target = (PyBlitzArrayObject*)PyBlitzArray_SimpleNew(NPY_FLOAT64, 2, shape); + if (!target) return 0; + target_ = make_safe(target); + } + + //all good, now call the shuffler + try { + if (rng) { + self->cxx->operator()( + *rng->rng, + *PyBlitzArrayCxx_AsBlitz<double,2>(data), + *PyBlitzArrayCxx_AsBlitz<double,2>(target) + ); + } + else { + self->cxx->operator()( + *PyBlitzArrayCxx_AsBlitz<double,2>(data), + *PyBlitzArrayCxx_AsBlitz<double,2>(target) + ); + } + } + catch (std::exception& ex) { + PyErr_SetString(PyExc_RuntimeError, ex.what()); + return 0; + } + catch (...) { + PyErr_Format(PyExc_RuntimeError, "cannot call object of type `%s' - unknown exception thrown", Py_TYPE(self)->tp_name); + return 0; + } + + //and finally we return... + Py_INCREF(data); + Py_INCREF(target); + return Py_BuildValue("OO", + PyBlitzArray_NUMPY_WRAP(reinterpret_cast<PyObject*>(data)), + PyBlitzArray_NUMPY_WRAP(reinterpret_cast<PyObject*>(target)) + ); + +} + +static PyMethodDef PyBobLearnDataShuffler_methods[] = { + { + s_draw_str, + (PyCFunction)PyBobLearnDataShuffler_Call, + METH_VARARGS|METH_KEYWORDS, + s_draw_doc + }, + {0} /* Sentinel */ +}; + PyTypeObject PyBobLearnDataShuffler_Type = { PyVarObject_HEAD_INIT(0, 0) s_shuffler_str, /* tp_name */ @@ -170,7 +337,7 @@ PyTypeObject PyBobLearnDataShuffler_Type = { 0, /* tp_as_sequence */ 0, /* tp_as_mapping */ 0, /* tp_hash */ - 0, /* tp_call */ + (ternaryfunc)PyBobLearnDataShuffler_Call, /* tp_call */ 0, /* tp_str */ 0, /* tp_getattro */ 0, /* tp_setattro */ @@ -183,7 +350,7 @@ PyTypeObject PyBobLearnDataShuffler_Type = { 0, /* tp_weaklistoffset */ 0, /* tp_iter */ 0, /* tp_iternext */ - 0, /* tp_methods */ + PyBobLearnDataShuffler_methods, /* tp_methods */ 0, /* tp_members */ 0, /* tp_getset */ 0, /* tp_base */ diff --git a/xbob/learn/mlp/test_shuffler.py b/xbob/learn/mlp/test_shuffler.py index ffc15b087b0d620d8db94f63c6df5b559c52dd29..3bff296756f7f01e6dbf4ecb94d01827396032f7 100644 --- a/xbob/learn/mlp/test_shuffler.py +++ b/xbob/learn/mlp/test_shuffler.py @@ -152,8 +152,8 @@ def test_seeding(): rng1 = xbob.core.random.mt19937(32) rng2 = xbob.core.random.mt19937(32) - [data1, target1] = shuffle1(rng1, N) - [data2, target2] = shuffle2(rng2, N) + [data1, target1] = shuffle1(N, rng=rng1) + [data2, target2] = shuffle2(N, rng=rng2) assert (data1 == data2).all()