// main.cpp -- Python bindings exposing the C++ remove_highlights routine
// include directly and indirectly dependent libraries
#ifdef NO_IMPORT_ARRAY
#undef NO_IMPORT_ARRAY
#endif


#include <bob.blitz/cppapi.h>
#include <bob.blitz/cleanup.h>
#include <bob.extension/documentation.h>


// declare the C++ function wrapped by the Python binding below
// Prototype only: the definition is not part of this section of the file.
//
// As used by the Python binding in this file:
//   img      - input array (validated there as 3D float32)
//   diff     - receives the binding's "diffuse_img" buffer
//   sfi      - receives the binding's "speckle_free_img" buffer
//   residue  - receives the binding's "speckle_img" buffer
//   epsilon  - the "startEps" argument, narrowed to float
//
// The three buffers are allocated as 3 x height x width and zero-filled by
// the caller; presumably this routine fills them in place -- confirm against
// the implementation.
void remove_highlights(   blitz::Array<float ,3> &img,
                          blitz::Array<float ,3> &diff,
                          blitz::Array<float ,3> &sfi,
                          blitz::Array<float ,3> &residue,
                          float  epsilon);

// define the Python binding
// Python entry point: remove_highlights(array, startEps=0.5)
//
// Arguments (positional or keyword):
//   array     - 3D float32 array; anything else raises TypeError
//   startEps  - optional double (default 0.5), passed to the C++ routine as
//               a float
//
// Returns a new 3-tuple (speckle_free_img, diffuse_img, speckle_img) of
// freshly allocated 3 x height x width float32 arrays, where height and width
// are the 2nd and 3rd extents of the input.
//
// On failure returns 0 with a Python exception set: argument parsing errors,
// the TypeError above, a failed array conversion, or whatever
// BOB_CATCH_FUNCTION reports for a C++ exception.
static PyObject* PyRemoveHighlights(PyObject*, PyObject* args, PyObject* kwargs) {

  BOB_TRY

  static const char* const_kwlist[] = {"array", "startEps", 0};
  static char** kwlist = const_cast<char**>(const_kwlist);

  PyBlitzArrayObject* array = 0;
  double epsilon = 0.5;

  if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O&|d", kwlist,
                                  &PyBlitzArray_Converter, &array,
                                  &epsilon)) return 0;

  // the converter hands back a new reference: make sure it is released on
  // every exit path (early returns and exceptions included)
  auto array_ = make_safe(array);

  // check that the array has the expected properties
  if (array->type_num != NPY_FLOAT32 || array->ndim != 3) {
    PyErr_Format(PyExc_TypeError,
                "remove_highlights : Only 3D arrays of type float32 are allowed");
    return 0;
  }

  // extract the actual blitz array from the Python type
  blitz::Array<float ,3> img = *PyBlitzArrayCxx_AsBlitz<float , 3>(array);

  // NOTE(review): the outputs are hard-wired to 3 planes while the first
  // extent of the input is not checked -- confirm whether remove_highlights
  // requires img.shape()[0] == 3.
  const int dim_x = img.shape()[2];
  const int dim_y = img.shape()[1];

  // results, zero-initialised before being handed to the C++ routine
  blitz::Array<float ,3> diffuse_img(3, dim_y, dim_x);
  blitz::Array<float ,3> speckle_free_img(3, dim_y, dim_x);
  blitz::Array<float ,3> speckle_img(3, dim_y, dim_x);

  diffuse_img       = 0;
  speckle_free_img  = 0;
  speckle_img       = 0;

  // call the C++ function
  remove_highlights(img, diffuse_img, speckle_free_img, speckle_img,
                    static_cast<float>(epsilon));

  // convert the blitz arrays back to numpy and return them as a tuple.
  // "N" steals the new reference of each converted array; if any conversion
  // returned NULL, Py_BuildValue returns NULL instead of a half-filled tuple.
  return Py_BuildValue("NNN",
                       PyBlitzArrayCxx_AsNumpy(speckle_free_img),
                       PyBlitzArrayCxx_AsNumpy(diffuse_img),
                       PyBlitzArrayCxx_AsNumpy(speckle_img));

  // handle exceptions that occurred in this function
  BOB_CATCH_FUNCTION("remove_highlights", 0)
}


//////////////////////////////////////////////////////////////////////////
/////// Python module declaration ////////////////////////////////////////
//////////////////////////////////////////////////////////////////////////

// module-wide methods
// Function table of the extension module: one entry per exported Python
// callable, closed by an all-null entry.
static PyMethodDef module_methods[] = {
  { "remove_highlights",
    reinterpret_cast<PyCFunction>(PyRemoveHighlights),
    METH_VARARGS | METH_KEYWORDS,
    "remove_highlights [doc]" },
  { NULL, NULL, 0, NULL }  // Sentinel
};

// module documentation
PyDoc_STRVAR(module_docstr, "Exemplary Python Bindings");

// Python 3 only: static module description consumed by PyModule_Create()
#if PY_VERSION_HEX >= 0x03000000
static PyModuleDef module_definition = {
  PyModuleDef_HEAD_INIT,
  BOB_EXT_MODULE_NAME,  // module name
  module_docstr,        // module docstring
  -1,                   // per-module state size
  module_methods,       // function table
  NULL,                 // remaining optional hooks are left unset
  NULL,
  NULL,
  NULL
};
#endif

// create the module
// Builds the extension module object: creates it, attaches "__version__" and
// imports the bob.blitz C-API.
//
// Returns the module through Py_BuildValue, or NULL as soon as one of the
// steps fails.
static PyObject* create_module (void) {

# if PY_VERSION_HEX >= 0x03000000
  PyObject* module = PyModule_Create(&module_definition);
  // presumably releases the creation reference at scope exit (see
  // bob.blitz/cleanup.h); the "O" format below takes its own reference
  auto module_guard = make_xsafe(module);
  const char* build_format = "O";
# else
  PyObject* module = Py_InitModule3(BOB_EXT_MODULE_NAME, module_methods, module_docstr);
  const char* build_format = "N";
# endif
  if (module == NULL) {
    return NULL;
  }

  // expose the package version as a module attribute
  const int version_status =
      PyModule_AddStringConstant(module, "__version__", BOB_EXT_MODULE_VERSION);
  if (version_status < 0) {
    return NULL;
  }

  /* imports bob.blitz C-API + dependencies */
  const int import_status = import_bob_blitz();
  if (import_status < 0) {
    return NULL;
  }

  return Py_BuildValue(build_format, module);
}

// Module initialisation entry point. Under Python 3 the result of
// create_module() is returned to the caller; under Python 2 it is discarded.
PyMODINIT_FUNC BOB_EXT_ENTRY_NAME (void) {
# if PY_VERSION_HEX >= 0x03000000
  return create_module();
# else
  create_module();
# endif
}