From d2f2fd5e6814af3a0531513889e101a09f789f0e Mon Sep 17 00:00:00 2001 From: Andre Anjos <andre.dos.anjos@gmail.com> Date: Sun, 10 Nov 2013 09:50:24 +0100 Subject: [PATCH] Implemented slicing through mapping protocol --- xbob/io/file.cpp | 62 +++++++++++++++++++++++----- xbob/io/test/test_file.py | 86 ++++++++++++++++++++++++++++++++++++--- 2 files changed, 133 insertions(+), 15 deletions(-) diff --git a/xbob/io/file.cpp b/xbob/io/file.cpp index 77ab0aa..f551810 100644 --- a/xbob/io/file.cpp +++ b/xbob/io/file.cpp @@ -199,7 +199,9 @@ int PyBobIo_AsTypenum (bob::core::array::ElementType type) { } -static PyObject* PyBobIoFile_GetItem (PyBobIoFileObject* self, Py_ssize_t i) { +static PyObject* PyBobIoFile_GetIndex (PyBobIoFileObject* self, Py_ssize_t i) { + + if (i < 0) i += self->f->size(); ///< adjust for negative indexing if (i < 0 || i >= self->f->size()) { PyErr_Format(PyExc_IndexError, "file index out of range - `%s' only contains %" PY_FORMAT_SIZE_T "d object(s)", self->f->filename().c_str(), self->f->size()); @@ -236,12 +238,52 @@ static PyObject* PyBobIoFile_GetItem (PyBobIoFileObject* self, Py_ssize_t i) { } -static PySequenceMethods PyBobIoFile_Sequence = { - (lenfunc)PyBobIoFile_Len, - 0, /* concat */ - 0, /* repeat */ - (ssizeargfunc)PyBobIoFile_GetItem, - 0 /* slice */ +static PyObject* PyBobIoFile_GetSlice (PyBobIoFileObject* self, Py_ssize_t start, Py_ssize_t stop, Py_ssize_t step, Py_ssize_t length) { + + PyObject* retval = PyList_New(length); + if (!retval) return 0; + + Py_ssize_t counter = 0; + for (auto i = start; (start<=stop)?i<stop:i>stop; i+=step) { + + PyObject* item = PyBobIoFile_GetIndex(self, i); + if (!item) { + Py_DECREF(retval); + return 0; + } + + PyList_SET_ITEM(retval, counter++, item); + + } + + return retval; + +} + +static PyObject* PyBobIoFile_GetItem (PyBobIoFileObject* self, PyObject* item) { + if (PyIndex_Check(item)) { + Py_ssize_t i = PyNumber_AsSsize_t(item, PyExc_IndexError); + if (i == -1 && PyErr_Occurred()) return 0; + return PyBobIoFile_GetIndex(self, i); + } + if (PySlice_Check(item)) { + Py_ssize_t start, stop, step, slicelength; + if (PySlice_GetIndicesEx((PySliceObject*)item, self->f->size(), + &start, &stop, &step, &slicelength) < 0) return 0; + if (slicelength <= 0) return PyList_New(0); + return PyBobIoFile_GetSlice(self, start, stop, step, slicelength); + } + else { + PyErr_Format(PyExc_TypeError, "File indices must be integers, not %.200s", + item->ob_type->tp_name); + return 0; + } +} + +static PyMappingMethods PyBobIoFile_Mapping = { + (lenfunc)PyBobIoFile_Len, //mp_lenght + (binaryfunc)PyBobIoFile_GetItem, //mp_subscript + 0 /* (objobjargproc)PyBobIoFile_SetItem //mp_ass_subscript */ }; static PyObject* PyBobIoFile_Read(PyBobIoFileObject* self, PyObject *args, PyObject* kwds) { @@ -264,7 +306,7 @@ static PyObject* PyBobIoFile_Read(PyBobIoFileObject* self, PyObject *args, PyObj return 0; } - return PyBobIoFile_GetItem(self, i); + return PyBobIoFile_GetIndex(self, i); } @@ -538,8 +580,8 @@ PyTypeObject PyBobIoFile_Type = { 0, /*tp_compare*/ (reprfunc)PyBobIoFile_Repr, /*tp_repr*/ 0, /*tp_as_number*/ - &PyBobIoFile_Sequence, /*tp_as_sequence*/ - 0, /*tp_as_mapping*/ + 0, /*tp_as_sequence*/ + &PyBobIoFile_Mapping, /*tp_as_mapping*/ 0, /*tp_hash */ 0, /*tp_call*/ (reprfunc)PyBobIoFile_Repr, /*tp_str*/ diff --git a/xbob/io/test/test_file.py b/xbob/io/test/test_file.py index 0fa83f9..6046492 100644 --- a/xbob/io/test/test_file.py +++ b/xbob/io/test/test_file.py @@ -17,6 +17,87 @@ import nose.tools from .. import load, write, peek, peek_all, File from . import utils as testutils +def test_peek(): + + f = testutils.datafile('test1.hdf5', __name__) + assert peek(f) == (numpy.uint16, (3,), (1,)) + assert peek_all(f) == (numpy.uint16, (3,3), (3,1)) + +def test_indexing(): + + f = File(testutils.datafile('matlab_2d.hdf5', __name__), 'r') + nose.tools.eq_(len(f), 512) + + objs = f[:] + nose.tools.eq_(len(f), len(objs)) + obj0 = f[0] + obj1 = f[1] + + # simple indexing + assert numpy.allclose(objs[0], obj0) + assert numpy.allclose(objs[1], obj1) + assert numpy.allclose(f[len(f)-1], f[-1]) + assert numpy.allclose(f[len(f)-2], f[-2]) + + # get slice + s1 = f[3:10:2] + nose.tools.eq_(len(s1), 4) + assert numpy.allclose(s1[0], objs[3]) + assert numpy.allclose(s1[1], objs[5]) + assert numpy.allclose(s1[2], objs[7]) + assert numpy.allclose(s1[3], objs[9]) + + # get negative slicing + s2 = f[-10:-2:3] + nose.tools.eq_(len(s2), 3) + assert numpy.allclose(s2[0], f[len(f)-10]) + assert numpy.allclose(s2[1], f[len(f)-7]) + assert numpy.allclose(s2[2], f[len(f)-4]) + + # get negative stepping slice + s3 = f[20:10:-3] + nose.tools.eq_(len(s3), 4) + assert numpy.allclose(s3[0], f[20]) + assert numpy.allclose(s3[1], f[17]) + assert numpy.allclose(s3[2], f[14]) + assert numpy.allclose(s3[3], f[11]) + + # get negative indexing and positive stepping slice + s4 = f[-20:-10:3] + nose.tools.eq_(len(s4), 4) + assert numpy.allclose(s4[0], f[len(f)-20]) + assert numpy.allclose(s4[1], f[len(f)-17]) + assert numpy.allclose(s4[2], f[len(f)-14]) + assert numpy.allclose(s4[3], f[len(f)-11]) + + # get all negative slice + s5 = f[-10:-20:-3] + nose.tools.eq_(len(s5), 4) + assert numpy.allclose(s5[0], f[len(f)-10]) + assert numpy.allclose(s5[1], f[len(f)-13]) + assert numpy.allclose(s5[2], f[len(f)-16]) + assert numpy.allclose(s5[3], f[len(f)-19]) + +@nose.tools.raises(TypeError) +def test_indexing_type_check(): + + f = File(testutils.datafile('matlab_2d.hdf5', __name__), 'r') + nose.tools.eq_(len(f), 512) + f[4.5] + +@nose.tools.raises(IndexError) +def test_indexing_boundaries(): + + f = File(testutils.datafile('matlab_2d.hdf5', __name__), 'r') + nose.tools.eq_(len(f), 512) + f[512] + +@nose.tools.raises(IndexError) +def test_indexing_negative_boundaries(): + f = File(testutils.datafile('matlab_2d.hdf5', __name__), 'r') + nose.tools.eq_(len(f), 512) + f[-513] + def transcode(filename): """Runs a complete transcoding test, to and from the binary format.""" @@ -273,8 +354,3 @@ def test_csv(): arrayset_readwrite(".csv", a2, close=True) arrayset_readwrite('.csv', a3, close=True) -def test_peek(): - - f = testutils.datafile('test1.hdf5', __name__) - assert peek(f) == (numpy.uint16, (3,), (1,)) - assert peek_all(f) == (numpy.uint16, (3,3), (3,1)) -- GitLab