Commit dbeacb4e authored by André Anjos's avatar André Anjos 💬

Finished implementation of the Machine bit

parent ce453f7d
......@@ -298,6 +298,222 @@ PyObject* PyBobLearnLibsvmFile_eof(PyBobLearnLibsvmFileObject* self) {
Py_RETURN_FALSE;
}
PyDoc_STRVAR(s_read_str, "read");
PyDoc_STRVAR(s_read_doc,
"o.read([values]) -> (int, array)\n\
\n\
Reads a single line from the file and returns a tuple\n\
containing the label and a numpy array of ``float64``\n\
elements. The :py:class:`numpy.ndarray` has a shape\n\
as defined by the :py:attr:`File.shape` attribute of\n\
the current file. If the file has finished, this method\n\
returns ``None`` instead.\n\
\n\
If the output array ``values`` is provided, it must be a\n\
64-bit float array with a shape matching the file shape as\n\
defined by :py:attr:`File.shape`. Providing an output\n\
array avoids constant memory re-allocation.\n\
\n\
");
static PyObject* PyBobLearnLibsvmFile_read
(PyBobLearnLibsvmFileObject* self, PyObject* args, PyObject* kwds) {
// before doing anything, check file status and returns if that is the case
if (!self->cxx->good()) Py_RETURN_NONE;
static const char* const_kwlist[] = {"values", 0};
static char** kwlist = const_cast<char**>(const_kwlist);
PyBlitzArrayObject* values = 0;
if (!PyArg_ParseTupleAndKeywords(args, kwds, "|O&", kwlist,
&PyBlitzArray_OutputConverter, &values
)) return 0;
//protects acquired resources through this scope
auto values_ = make_xsafe(values);
if (values && values->type_num != NPY_FLOAT64) {
PyErr_Format(PyExc_TypeError, "`%s' only supports 64-bit float arrays for output array `values'", Py_TYPE(self)->tp_name);
return 0;
}
if (values && values->ndim != 1) {
PyErr_Format(PyExc_RuntimeError, "Output arrays should always be 1D but you provided an object with %" PY_FORMAT_SIZE_T "d dimensions", values->ndim);
return 0;
}
if (values && values->shape[0] != (Py_ssize_t)self->cxx->shape()) {
PyErr_Format(PyExc_RuntimeError, "1D `values' array should have %" PY_FORMAT_SIZE_T "d elements matching the shape of this file, not %" PY_FORMAT_SIZE_T "d rows", self->cxx->shape(), values->shape[0]);
return 0;
}
/** if ``values`` was not pre-allocated, do it now **/
if (!values) {
Py_ssize_t osize = self->cxx->shape();
values = (PyBlitzArrayObject*)PyBlitzArray_SimpleNew(NPY_FLOAT64, 1, &osize);
values_ = make_safe(values);
}
/** all basic checks are done, can call the machine now **/
int label = 0;
try {
auto bz = PyBlitzArrayCxx_AsBlitz<double,1>(values);
bool ok = self->cxx->read_(label, *bz);
if (!ok) Py_RETURN_NONE; ///< error condition or end-of-file
}
catch (std::exception& e) {
PyErr_SetString(PyExc_RuntimeError, e.what());
return 0;
}
catch (...) {
PyErr_Format(PyExc_RuntimeError, "%s cannot read data: unknown exception caught", Py_TYPE(self)->tp_name);
return 0;
}
Py_INCREF(values);
return Py_BuildValue("iO",
label,
PyBlitzArray_NUMPY_WRAP(reinterpret_cast<PyObject*>(values))
);
}
PyDoc_STRVAR(s_read_all_str, "read_all");
PyDoc_STRVAR(s_read_all_doc,
"o.read_all([labels, [values]) -> (array, array)\n\
\n\
Reads all contents of the file into the output arrays\n\
``labels`` (used for storing each entry's label) and\n\
``values`` (used to store each entry's features).\n\
The array ``labels``, if provided, must be a 1D\n\
:py:class:`numpy.ndarray` with data type ``int64``,\n\
containing as many positions as entries in the file, as\n\
returned by the attribute :py:attr:`File.samples`. The\n\
array ``values``, if provided, must be a 2D array with\n\
data type ``float64``, as many rows as entries in the\n\
file and as many columns as features in each entry, as\n\
defined by the attribute :py:attr:`File.shape`.\n\
\n\
If the output arrays ``labels`` and/or ``values`` are not\n\
provided, they will be allocated internally and returned.\n\
\n\
.. note::\n\
\n\
This method is intended to be used for reading the\n\
whole contents of the input file. The file will be\n\
reset as by calling :py:meth:`File.reset` before the\n\
readout starts.\n\
\n\
");
static PyObject* PyBobLearnLibsvmFile_read_all
(PyBobLearnLibsvmFileObject* self, PyObject* args, PyObject* kwds) {
// before doing anything, check file status and returns if that is the case
if (!self->cxx->good()) Py_RETURN_NONE;
static const char* const_kwlist[] = {"labels", "values", 0};
static char** kwlist = const_cast<char**>(const_kwlist);
PyBlitzArrayObject* labels = 0;
PyBlitzArrayObject* values = 0;
if (!PyArg_ParseTupleAndKeywords(args, kwds, "|O&O&", kwlist,
&PyBlitzArray_OutputConverter, &labels,
&PyBlitzArray_OutputConverter, &values
)) return 0;
//protects acquired resources through this scope
auto labels_ = make_xsafe(labels);
auto values_ = make_xsafe(values);
if (labels && labels->type_num != NPY_INT64) {
PyErr_Format(PyExc_TypeError, "`%s' only supports 64-bit integer arrays for output array `labels'", Py_TYPE(self)->tp_name);
return 0;
}
if (values && values->type_num != NPY_FLOAT64) {
PyErr_Format(PyExc_TypeError, "`%s' only supports 64-bit float arrays for output array `values'", Py_TYPE(self)->tp_name);
return 0;
}
if (labels && labels->ndim != 1) {
PyErr_Format(PyExc_RuntimeError, "Output array `labels' should always be 1D but you provided an object with %" PY_FORMAT_SIZE_T "d dimensions", labels->ndim);
return 0;
}
if (values && values->ndim != 2) {
PyErr_Format(PyExc_RuntimeError, "Output array `values' should always be 2D but you provided an object with %" PY_FORMAT_SIZE_T "d dimensions", values->ndim);
return 0;
}
if (labels && labels->shape[0] != (Py_ssize_t)self->cxx->samples()) {
PyErr_Format(PyExc_RuntimeError, "1D `labels' array should have %" PY_FORMAT_SIZE_T "d elements matching the number of samples in this file, not %" PY_FORMAT_SIZE_T "d rows", self->cxx->samples(), labels->shape[0]);
return 0;
}
if (values && values->shape[0] != (Py_ssize_t)self->cxx->samples()) {
PyErr_Format(PyExc_RuntimeError, "2D `values' array should have %" PY_FORMAT_SIZE_T "d rows matching the number of samples in this file, not %" PY_FORMAT_SIZE_T "d rows", self->cxx->samples(), values->shape[0]);
return 0;
}
if (values && values->shape[1] != (Py_ssize_t)self->cxx->shape()) {
PyErr_Format(PyExc_RuntimeError, "2D `values' array should have %" PY_FORMAT_SIZE_T "d columns matching the shape of this file, not %" PY_FORMAT_SIZE_T "d rows", self->cxx->shape(), values->shape[0]);
return 0;
}
/** if ``labels`` was not pre-allocated, do it now **/
if (!labels) {
Py_ssize_t osize = self->cxx->samples();
labels = (PyBlitzArrayObject*)PyBlitzArray_SimpleNew(NPY_INT64, 1, &osize);
labels_ = make_safe(labels);
}
/** if ``values`` was not pre-allocated, do it now **/
if (!values) {
Py_ssize_t osize[2];
osize[0] = self->cxx->samples();
osize[1] = self->cxx->shape();
values = (PyBlitzArrayObject*)PyBlitzArray_SimpleNew(NPY_FLOAT64, 2, osize);
values_ = make_safe(values);
}
/** all basic checks are done, can call the machine now **/
try {
self->cxx->reset();
auto bzlab = PyBlitzArrayCxx_AsBlitz<int64_t,1>(labels);
auto bzval = PyBlitzArrayCxx_AsBlitz<double,2>(values);
blitz::Range all = blitz::Range::all();
int k = 0;
while (self->cxx->good()) {
blitz::Array<double,1> v_ = (*bzval)(k, all);
int label = 0;
bool ok = self->cxx->read_(label, v_);
if (ok) (*bzlab)(k) = label;
++k;
}
}
catch (std::exception& e) {
PyErr_SetString(PyExc_RuntimeError, e.what());
return 0;
}
catch (...) {
PyErr_Format(PyExc_RuntimeError, "%s cannot read data: unknown exception caught", Py_TYPE(self)->tp_name);
return 0;
}
Py_INCREF(labels);
Py_INCREF(values);
return Py_BuildValue("OO",
PyBlitzArray_NUMPY_WRAP(reinterpret_cast<PyObject*>(labels)),
PyBlitzArray_NUMPY_WRAP(reinterpret_cast<PyObject*>(values))
);
}
static PyMethodDef PyBobLearnLibsvmFile_methods[] = {
{
s_reset_str,
......@@ -323,6 +539,18 @@ static PyMethodDef PyBobLearnLibsvmFile_methods[] = {
METH_NOARGS,
s_eof_doc
},
{
s_read_str,
(PyCFunction)PyBobLearnLibsvmFile_read,
METH_VARARGS|METH_KEYWORDS,
s_read_doc
},
{
s_read_all_str,
(PyCFunction)PyBobLearnLibsvmFile_read_all,
METH_VARARGS|METH_KEYWORDS,
s_read_all_doc
},
{0} /* Sentinel */
};
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment