diff --git a/bob/io/base/include/bob.io.base/blitz_array.h b/bob/io/base/include/bob.io.base/blitz_array.h index c592b7f3844aadbbc755a30a16de9e0ef4d5b05f..083ec1d3245d2b3bc1f92e821b71eaeffe1de0b3 100644 --- a/bob/io/base/include/bob.io.base/blitz_array.h +++ b/bob/io/base/include/bob.io.base/blitz_array.h @@ -244,7 +244,7 @@ namespace bob { namespace io { namespace base { namespace array { * data at. Only get the number of dimensions right! */ template <typename T, int N> blitz::Array<T,N> cast() const { - return bob::core::array::cast<T,N>(*this); + return bob::io::base::array::cast<T,N>(*this); } private: //representation diff --git a/bob/io/base/test.cpp b/bob/io/base/test.cpp new file mode 100644 index 0000000000000000000000000000000000000000..9dd3d956666033e0acc95aaf278eb8ec9b2d4ec5 --- /dev/null +++ b/bob/io/base/test.cpp @@ -0,0 +1,100 @@ +/** + * @author Manuel Gunther + * @date Tue Sep 13 13:01:31 MDT 2016 + * + * @brief Tests for bob::io::base + */ + +#include <bob.io.base/api.h> +#include <bob.blitz/cleanup.h> +#include <bob.extension/documentation.h> + +#include <boost/format.hpp> +#include <boost/filesystem.hpp> + +static auto s_test_api = bob::extension::FunctionDoc( + "_test_api", + "Some tests for API functions" +) +.add_prototype("tempdir") +.add_parameter("tempdir", "str", "A temporary directory to write data to") +; +static PyObject* _test_api(PyObject*, PyObject *args, PyObject* kwds){ +BOB_TRY + static char** kwlist = s_test_api.kwlist(); + + const char* tempdir; + if (!PyArg_ParseTupleAndKeywords(args, kwds, "s", kwlist, &tempdir)) return 0; + + blitz::Array<uint8_t, 1> test_data(5); + for (int i = 0; i < 5; ++i){ + test_data(i) = i+1; + } + + auto h5file = bob::io::base::CodecRegistry::instance()->findByExtension(".hdf5"); + + boost::filesystem::path hdf5(tempdir); hdf5 /= std::string("test.h5"); + + auto output = h5file(hdf5.string().c_str(), 'w'); + output->write(test_data); + output.reset(); + + auto input = h5file(hdf5.string().c_str(), 'r'); + blitz::Array<uint8_t,1> read_data = input->read<uint8_t,1>(0); + + // Does not compile at the moment + blitz::Array<uint16_t,1> read_data_2 = input->cast<uint16_t,1>(0); + + input.reset(); + + if (blitz::any(test_data - read_data)) + throw std::runtime_error("The CSV IO test did not succeed"); + + Py_RETURN_NONE; +BOB_CATCH_FUNCTION("_test_api", 0) +} + +static PyMethodDef module_methods[] = { + { + s_test_api.name(), + (PyCFunction)_test_api, + METH_VARARGS|METH_KEYWORDS, + s_test_api.doc(), + }, + {0} /* Sentinel */ +}; + +PyDoc_STRVAR(module_docstr, "Tests for bob::io::base"); + +#if PY_VERSION_HEX >= 0x03000000 +static PyModuleDef module_definition = { + PyModuleDef_HEAD_INIT, + BOB_EXT_MODULE_NAME, + module_docstr, + -1, + module_methods, + 0, 0, 0, 0 +}; +#endif + +static PyObject* create_module (void) { + +# if PY_VERSION_HEX >= 0x03000000 + PyObject* m = PyModule_Create(&module_definition); + auto m_ = make_xsafe(m); + const char* ret = "O"; +# else + PyObject* m = Py_InitModule3(BOB_EXT_MODULE_NAME, module_methods, module_docstr); + const char* ret = "N"; +# endif + if (!m) return 0; + + return Py_BuildValue(ret, m); +} + +PyMODINIT_FUNC BOB_EXT_ENTRY_NAME (void) { +# if PY_VERSION_HEX >= 0x03000000 + return +# endif + create_module(); +} diff --git a/bob/io/base/test_cpp.py b/bob/io/base/test_cpp.py new file mode 100644 index 0000000000000000000000000000000000000000..a9f4e06bea775e012ede089dbc9034918a9414c0 --- /dev/null +++ b/bob/io/base/test_cpp.py @@ -0,0 +1,11 @@ +from bob.io.base._test import _test_api + +import tempfile +import shutil + +def test_api(): + temp_dir = tempfile.mkdtemp() + try: + _test_api(temp_dir) + finally: + shutil.rmtree(temp_dir) diff --git a/setup.py b/setup.py index 610485d1dd77a6a24ae5f27ae7fb481663ff1fb5..08ab4e55623d44102c14c2ed9045203ff27a6c3c 100644 --- a/setup.py +++ b/setup.py @@ -227,6 +227,20 @@ setup( packages = packages, boost_modules = boost_modules, ), + + Extension("bob.io.base._test", + [ + "bob/io/base/test.cpp", + ], + library_dirs = library_dirs, + libraries = libraries, + define_macros = define_macros, + system_include_dirs = system_include_dirs, + version = version, + bob_packages = bob_packages, + packages = packages, + boost_modules = boost_modules, + ), ], cmdclass = {