From d2f2fd5e6814af3a0531513889e101a09f789f0e Mon Sep 17 00:00:00 2001
From: Andre Anjos <andre.dos.anjos@gmail.com>
Date: Sun, 10 Nov 2013 09:50:24 +0100
Subject: [PATCH] Implemented slicing through mapping protocol

---
 xbob/io/file.cpp          | 62 +++++++++++++++++++++++-----
 xbob/io/test/test_file.py | 86 ++++++++++++++++++++++++++++++++++++---
 2 files changed, 133 insertions(+), 15 deletions(-)

diff --git a/xbob/io/file.cpp b/xbob/io/file.cpp
index 77ab0aa..f551810 100644
--- a/xbob/io/file.cpp
+++ b/xbob/io/file.cpp
@@ -199,7 +199,9 @@ int PyBobIo_AsTypenum (bob::core::array::ElementType type) {
 
 }
 
-static PyObject* PyBobIoFile_GetItem (PyBobIoFileObject* self, Py_ssize_t i) {
+static PyObject* PyBobIoFile_GetIndex (PyBobIoFileObject* self, Py_ssize_t i) {
+
+  if (i < 0) i += self->f->size(); ///< adjust for negative indexing
 
   if (i < 0 || i >= self->f->size()) {
     PyErr_Format(PyExc_IndexError, "file index out of range - `%s' only contains %" PY_FORMAT_SIZE_T "d object(s)", self->f->filename().c_str(), self->f->size());
@@ -236,12 +238,52 @@ static PyObject* PyBobIoFile_GetItem (PyBobIoFileObject* self, Py_ssize_t i) {
 
 }
 
-static PySequenceMethods PyBobIoFile_Sequence = {
-    (lenfunc)PyBobIoFile_Len,
-    0, /* concat */
-    0, /* repeat */
-    (ssizeargfunc)PyBobIoFile_GetItem,
-    0 /* slice */
+static PyObject* PyBobIoFile_GetSlice (PyBobIoFileObject* self, Py_ssize_t start, Py_ssize_t stop, Py_ssize_t step, Py_ssize_t length) {
+
+  PyObject* retval = PyList_New(length);
+  if (!retval) return 0;
+
+  Py_ssize_t counter = 0;
+  for (auto i = start; (start<=stop)?i<stop:i>stop; i+=step) {
+
+    PyObject* item = PyBobIoFile_GetIndex(self, i);
+    if (!item) {
+      Py_DECREF(retval);
+      return 0;
+    }
+
+    PyList_SET_ITEM(retval, counter++, item);
+
+  }
+
+  return retval;
+
+}
+
+static PyObject* PyBobIoFile_GetItem (PyBobIoFileObject* self, PyObject* item) {
+   if (PyIndex_Check(item)) {
+     Py_ssize_t i = PyNumber_AsSsize_t(item, PyExc_IndexError);
+     if (i == -1 && PyErr_Occurred()) return 0;
+     return PyBobIoFile_GetIndex(self, i);
+   }
+   if (PySlice_Check(item)) {
+     Py_ssize_t start, stop, step, slicelength;
+     if (PySlice_GetIndicesEx((PySliceObject*)item, self->f->size(), 
+           &start, &stop, &step, &slicelength) < 0) return 0;
+     if (slicelength <= 0) return PyList_New(0);
+     return PyBobIoFile_GetSlice(self, start, stop, step, slicelength);
+   }
+   else {
+     PyErr_Format(PyExc_TypeError, "File indices must be integers, not %.200s",
+         item->ob_type->tp_name);
+     return 0;
+   }
+}
+
+static PyMappingMethods PyBobIoFile_Mapping = {
+    (lenfunc)PyBobIoFile_Len, //mp_lenght
+    (binaryfunc)PyBobIoFile_GetItem, //mp_subscript
+    0 /* (objobjargproc)PyBobIoFile_SetItem //mp_ass_subscript */
 };
 
 static PyObject* PyBobIoFile_Read(PyBobIoFileObject* self, PyObject *args, PyObject* kwds) {
@@ -264,7 +306,7 @@ static PyObject* PyBobIoFile_Read(PyBobIoFileObject* self, PyObject *args, PyObj
       return 0;
     }
 
-    return PyBobIoFile_GetItem(self, i);
+    return PyBobIoFile_GetIndex(self, i);
 
   }
 
@@ -538,8 +580,8 @@ PyTypeObject PyBobIoFile_Type = {
     0,                                          /*tp_compare*/
     (reprfunc)PyBobIoFile_Repr,                 /*tp_repr*/
     0,                                          /*tp_as_number*/
-    &PyBobIoFile_Sequence,                      /*tp_as_sequence*/
-    0,                                          /*tp_as_mapping*/
+    0,                                          /*tp_as_sequence*/
+    &PyBobIoFile_Mapping,                       /*tp_as_mapping*/
     0,                                          /*tp_hash */
     0,                                          /*tp_call*/
     (reprfunc)PyBobIoFile_Repr,                 /*tp_str*/
diff --git a/xbob/io/test/test_file.py b/xbob/io/test/test_file.py
index 0fa83f9..6046492 100644
--- a/xbob/io/test/test_file.py
+++ b/xbob/io/test/test_file.py
@@ -17,6 +17,87 @@ import nose.tools
 from .. import load, write, peek, peek_all, File
 from . import utils as testutils
 
+def test_peek():
+  
+  f = testutils.datafile('test1.hdf5', __name__)
+  assert peek(f) == (numpy.uint16, (3,), (1,))
+  assert peek_all(f) == (numpy.uint16, (3,3), (3,1))
+
+def test_indexing():
+  
+  f = File(testutils.datafile('matlab_2d.hdf5', __name__), 'r')
+  nose.tools.eq_(len(f), 512)
+
+  objs = f[:]
+  nose.tools.eq_(len(f), len(objs))
+  obj0 = f[0]
+  obj1 = f[1]
+
+  # simple indexing
+  assert numpy.allclose(objs[0], obj0)
+  assert numpy.allclose(objs[1], obj1)
+  assert numpy.allclose(f[len(f)-1], f[-1])
+  assert numpy.allclose(f[len(f)-2], f[-2])
+
+  # get slice
+  s1 = f[3:10:2]
+  nose.tools.eq_(len(s1), 4)
+  assert numpy.allclose(s1[0], objs[3])
+  assert numpy.allclose(s1[1], objs[5])
+  assert numpy.allclose(s1[2], objs[7])
+  assert numpy.allclose(s1[3], objs[9])
+
+  # get negative slicing
+  s2 = f[-10:-2:3]
+  nose.tools.eq_(len(s2), 3)
+  assert numpy.allclose(s2[0], f[len(f)-10])
+  assert numpy.allclose(s2[1], f[len(f)-7])
+  assert numpy.allclose(s2[2], f[len(f)-4])
+
+  # get negative stepping slice
+  s3 = f[20:10:-3]
+  nose.tools.eq_(len(s3), 4)
+  assert numpy.allclose(s3[0], f[20])
+  assert numpy.allclose(s3[1], f[17])
+  assert numpy.allclose(s3[2], f[14])
+  assert numpy.allclose(s3[3], f[11])
+
+  # get negative indexing and positive stepping slice
+  s4 = f[-20:-10:3]
+  nose.tools.eq_(len(s4), 4)
+  assert numpy.allclose(s4[0], f[len(f)-20])
+  assert numpy.allclose(s4[1], f[len(f)-17])
+  assert numpy.allclose(s4[2], f[len(f)-14])
+  assert numpy.allclose(s4[3], f[len(f)-11])
+
+  # get all negative slice
+  s5 = f[-10:-20:-3]
+  nose.tools.eq_(len(s5), 4)
+  assert numpy.allclose(s5[0], f[len(f)-10])
+  assert numpy.allclose(s5[1], f[len(f)-13])
+  assert numpy.allclose(s5[2], f[len(f)-16])
+  assert numpy.allclose(s5[3], f[len(f)-19])
+
+@nose.tools.raises(TypeError)
+def test_indexing_type_check():
+  
+  f = File(testutils.datafile('matlab_2d.hdf5', __name__), 'r')
+  nose.tools.eq_(len(f), 512)
+  f[4.5]
+
+@nose.tools.raises(IndexError)
+def test_indexing_boundaries():
+  
+  f = File(testutils.datafile('matlab_2d.hdf5', __name__), 'r')
+  nose.tools.eq_(len(f), 512)
+  f[512]
+
+@nose.tools.raises(IndexError)
+def test_indexing_negative_boundaries():
+  f = File(testutils.datafile('matlab_2d.hdf5', __name__), 'r')
+  nose.tools.eq_(len(f), 512)
+  f[-513]
+
 def transcode(filename):
   """Runs a complete transcoding test, to and from the binary format."""
 
@@ -273,8 +354,3 @@ def test_csv():
   arrayset_readwrite(".csv", a2, close=True)
   arrayset_readwrite('.csv', a3, close=True)
 
-def test_peek():
-  
-  f = testutils.datafile('test1.hdf5', __name__)
-  assert peek(f) == (numpy.uint16, (3,), (1,))
-  assert peek_all(f) == (numpy.uint16, (3,3), (3,1))
-- 
GitLab