scatter.cpp 9.42 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
/**
 * @date Mon Jun 20 11:47:58 2011 +0200
 * @author Andre Anjos <andre.anjos@idiap.ch>
 *
 * @brief Python bindings to statistical methods
 *
 * Copyright (C) 2011-2013 Idiap Research Institute, Martigny, Switzerland
 */

#include "scatter.h"
André Anjos's avatar
André Anjos committed
11
12
#include <bob.blitz/cppapi.h>
#include <bob.blitz/cleanup.h>
13
#include <bob.math/stats.h>
14
15
16
17
18
19
20
21
22
23
24
25

/**
 * Python binding around bob::math::scatter().
 *
 * Keyword arguments:
 *   a - input 2D array of 32 or 64-bit floats (required)
 *   s - optional 2D output array, same data type as `a'
 *   m - optional 1D output array, same data type as `a'
 *
 * Outputs not supplied by the caller are allocated here. `s' gets shape
 * (a.shape[1], a.shape[1]) and `m' gets length a.shape[1]. Only these
 * allocated outputs are returned, wrapped as numpy arrays, in a tuple ordered
 * (s, m). C++ exceptions raised by the computation are translated into
 * RuntimeError. Returns 0 with a Python exception set on any failure.
 */
PyObject* py_scatter (PyObject*, PyObject* args, PyObject* kwds) {

  /* Parses input arguments in a single shot */
  static const char* const_kwlist[] = { "a", "s", "m", 0 /* Sentinel */ };
  static char** kwlist = const_cast<char**>(const_kwlist);

  PyBlitzArrayObject* a = 0;
  PyBlitzArrayObject* s = 0;
  PyBlitzArrayObject* m = 0;

  if (!PyArg_ParseTupleAndKeywords(args, kwds, "O&|O&O&", kwlist,
        &PyBlitzArray_Converter, &a,
        &PyBlitzArray_OutputConverter, &s,
        &PyBlitzArray_OutputConverter, &m
        )) return 0;

  //protects acquired resources through this scope
  auto a_ = make_safe(a);
  auto s_ = make_xsafe(s);
  auto m_ = make_xsafe(m);

  // basic checks
  if (a->ndim != 2 || (a->type_num != NPY_FLOAT32 && a->type_num != NPY_FLOAT64)) {
    PyErr_SetString(PyExc_TypeError, "input data matrix `a' should be either a 32 or 64-bit float 2D array");
    return 0;
  }

  if (s && (s->ndim != 2 || (s->type_num != a->type_num))) {
    PyErr_SetString(PyExc_TypeError, "output data matrix `s' should be either a 32 or 64-bit float 2D array, matching the data type of `a'");
    return 0;
  }

  if (m && (m->ndim != 1 || (m->type_num != a->type_num))) {
    PyErr_SetString(PyExc_TypeError, "output data vector `m' should be either a 32 or 64-bit float 1D array, matching the data type of `a'");
    return 0;
  }

  // allocates data not passed by the user
  bool user_s = s;
  if (!s) {
    Py_ssize_t sshape[2] = {a->shape[1], a->shape[1]};
    s = (PyBlitzArrayObject*)PyBlitzArray_SimpleNew(a->type_num, 2, sshape);
    // on failure an exception is already set; never hand NULL to make_safe(),
    // whose deleter would DECREF it
    if (!s) return 0;
    s_ = make_safe(s);
  }

  bool user_m = m;
  if (!m) {
    m = (PyBlitzArrayObject*)PyBlitzArray_SimpleNew(a->type_num, 1, &a->shape[1]);
    if (!m) return 0;
    m_ = make_safe(m);
  }

  try {
    switch (a->type_num) {
      case NPY_FLOAT32:
        bob::math::scatter(
            *PyBlitzArrayCxx_AsBlitz<float,2>(a),
            *PyBlitzArrayCxx_AsBlitz<float,2>(s),
            *PyBlitzArrayCxx_AsBlitz<float,1>(m)
            );
        break;

      case NPY_FLOAT64:
        bob::math::scatter(
            *PyBlitzArrayCxx_AsBlitz<double,2>(a),
            *PyBlitzArrayCxx_AsBlitz<double,2>(s),
            *PyBlitzArrayCxx_AsBlitz<double,1>(m)
            );
        break;

      default:
        PyErr_Format(PyExc_TypeError, "scatter calculation currently not implemented for type '%s'", PyBlitzArray_TypenumAsString(a->type_num));
        return 0;
    }
  }
  catch (std::exception& e) {
    PyErr_SetString(PyExc_RuntimeError, e.what());
    return 0;
  }
  catch (...) {
    PyErr_SetString(PyExc_RuntimeError, "scatter calculation failed: unknown exception caught");
    return 0;
  }

  // only outputs allocated above are returned to the caller
  int returns = 2 - (user_s + user_m);

  PyObject* retval = PyTuple_New(returns);
  if (!retval) return 0;

  // fill from the back, so the tuple reads (s, m) when both were allocated
  if (!user_m) {
    PyObject* wrapped = PyBlitzArray_NUMPY_WRAP(Py_BuildValue("O", m));
    if (!wrapped) {
      Py_DECREF(retval); // tuple dealloc tolerates still-empty slots
      return 0;
    }
    PyTuple_SET_ITEM(retval, --returns, wrapped);
  }

  if (!user_s) {
    PyObject* wrapped = PyBlitzArray_NUMPY_WRAP(Py_BuildValue("O", s));
    if (!wrapped) {
      Py_DECREF(retval);
      return 0;
    }
    PyTuple_SET_ITEM(retval, --returns, wrapped);
  }

  return retval;

}


/**
 * "O&" converter for PyArg_ParseTupleAndKeywords: converts the input
 * sequence/iterable `o' into a new tuple of PyBlitzArrayObject's, stored in
 * `*a' as a new reference.
 *
 * Requirements:
 *   - `o' holds at least 2 elements.
 *   - Each element is a 2D array of type NPY_FLOAT32 or NPY_FLOAT64.
 *   - All elements share the data type of the first one.
 *
 * Returns 0 (with a Python exception set) if a problem occurs, 1 otherwise.
 */
int BzTuple_Converter(PyObject* o, PyObject** a) {

  PyObject* tmp = PySequence_Tuple(o);
  if (!tmp) return 0;
  auto tmp_ = make_safe(tmp);

  // The size must come from the converted tuple: `o' may be a list or any
  // iterable, and the PyTuple_* macros are only valid on real tuples.
  const Py_ssize_t size = PyTuple_GET_SIZE(tmp);

  if (size < 2) {
    PyErr_SetString(PyExc_TypeError, "input data object must be a sequence or iterable with at least 2 2D arrays with 32 or 64-bit floats");
    return 0;
  }

  PyBlitzArrayObject* first = 0;
  if (!PyBlitzArray_Converter(PyTuple_GET_ITEM(tmp, 0), &first)) {
    return 0;
  }
  auto first_ = make_safe(first);

  if (first->ndim != 2 ||
      (first->type_num != NPY_FLOAT32 && first->type_num != NPY_FLOAT64)) {
    PyErr_SetString(PyExc_TypeError, "input data object must be a sequence or iterable with at least 2 2D arrays with 32 or 64-bit floats - the first array does not conform");
    return 0;
  }

  PyObject* retval = PyTuple_New(size);
  if (!retval) {
    return 0;
  }
  auto retval_ = make_safe(retval);

  // PyTuple_SET_ITEM steals a reference; Py_BuildValue("O", ...) provides one
  PyTuple_SET_ITEM(retval, 0, Py_BuildValue("O", first));

  for (Py_ssize_t i=1; i<size; ++i) {

    PyBlitzArrayObject* next = 0;
    PyObject* item = PyTuple_GET_ITEM(tmp, i); //borrowed
    if (!PyBlitzArray_Converter(item, &next)) {
      return 0;
    }
    auto next_ = make_safe(next);

    // PyErr_Format only understands %zd for Py_ssize_t (PY_FORMAT_SIZE_T is a
    // printf-family modifier and is "I" on Windows)
    if (next->type_num != first->type_num) {
      PyErr_Format(PyExc_TypeError, "array at data[%zd] does not have the same data type as the first array on the sequence (%s != %s)", i, PyBlitzArray_TypenumAsString(next->type_num), PyBlitzArray_TypenumAsString(first->type_num));
      return 0;
    }
    if (next->ndim != 2) {
      PyErr_Format(PyExc_TypeError, "array at data[%zd] does not have two dimensions, but %zd", i, (Py_ssize_t)next->ndim);
      return 0;
    }

    PyTuple_SET_ITEM(retval, i, Py_BuildValue("O", next));

  }

  // hand a new reference to the caller; retval_ drops the local one
  *a = Py_BuildValue("O", retval);

  return 1;

}

/**
 * Python binding around bob::math::scatters().
 *
 * Keyword arguments:
 *   data - sequence/iterable of at least 2 2D float32/float64 arrays, all of
 *          the same data type (validated by BzTuple_Converter)
 *   sw   - optional 2D output array, same data type as `data'
 *   sb   - optional 2D output array, same data type as `data'
 *   m    - optional 1D output array, same data type as `data'
 *
 * Outputs not supplied by the caller are allocated here, sized from the
 * column count of data[0]. `sw' and `sb' get shape (ncols, ncols) and `m' gets
 * length ncols. Only these allocated outputs are returned, wrapped as numpy
 * arrays, in a tuple ordered (sw, sb, m). C++ exceptions raised by the
 * computation become RuntimeError. Returns 0 with a Python exception set on
 * failure.
 */
PyObject* py_scatters (PyObject*, PyObject* args, PyObject* kwds) {

  /* Parses input arguments in a single shot */
  static const char* const_kwlist[] = { "data", "sw", "sb", "m", 0 /* Sentinel */ };
  static char** kwlist = const_cast<char**>(const_kwlist);

  PyObject* data = 0;
  PyBlitzArrayObject* sw = 0;
  PyBlitzArrayObject* sb = 0;
  PyBlitzArrayObject* m = 0;

  if (!PyArg_ParseTupleAndKeywords(args, kwds, "O&|O&O&O&", kwlist,
        &BzTuple_Converter, &data,
        &PyBlitzArray_OutputConverter, &sw,
        &PyBlitzArray_OutputConverter, &sb,
        &PyBlitzArray_OutputConverter, &m
        )) return 0;

  //protects acquired resources through this scope
  auto data_ = make_safe(data);
  auto sw_ = make_xsafe(sw);
  auto sb_ = make_xsafe(sb);
  auto m_ = make_xsafe(m);

  // BzTuple_Converter guarantees a tuple of >= 2 same-typed 2D arrays
  PyBlitzArrayObject* first = (PyBlitzArrayObject*)PyTuple_GET_ITEM(data, 0);
  const Py_ssize_t nclasses = PyTuple_GET_SIZE(data);

  if (sw && (sw->ndim != 2 || (sw->type_num != first->type_num))) {
    PyErr_SetString(PyExc_TypeError, "output data matrix `sw' should be either a 32 or 64-bit float 2D array, matching the data type of `data'");
    return 0;
  }

  if (sb && (sb->ndim != 2 || (sb->type_num != first->type_num))) {
    PyErr_SetString(PyExc_TypeError, "output data matrix `sb' should be either a 32 or 64-bit float 2D array, matching the data type of `data'");
    return 0;
  }

  if (m && (m->ndim != 1 || (m->type_num != first->type_num))) {
    PyErr_SetString(PyExc_TypeError, "output data vector `m' should be either a 32 or 64-bit float 1D array, matching the data type of `data'");
    return 0;
  }

  // allocates data not passed by the user; on allocation failure an
  // exception is already set and NULL must not reach make_safe()
  bool user_sw = sw;
  if (!sw) {
    Py_ssize_t sshape[2] = {first->shape[1], first->shape[1]};
    sw = (PyBlitzArrayObject*)PyBlitzArray_SimpleNew(first->type_num, 2, sshape);
    if (!sw) return 0;
    sw_ = make_safe(sw);
  }

  bool user_sb = sb;
  if (!sb) {
    Py_ssize_t sshape[2] = {first->shape[1], first->shape[1]};
    sb = (PyBlitzArrayObject*)PyBlitzArray_SimpleNew(first->type_num, 2, sshape);
    if (!sb) return 0;
    sb_ = make_safe(sb);
  }

  bool user_m = m;
  if (!m) {
    m = (PyBlitzArrayObject*)PyBlitzArray_SimpleNew(first->type_num, 1, &first->shape[1]);
    if (!m) return 0;
    m_ = make_safe(m);
  }

  try {
    switch (first->type_num) {
      case NPY_FLOAT32:
        {
          std::vector<blitz::Array<float,2>> cxxdata;
          cxxdata.reserve(nclasses);
          for (Py_ssize_t i=0; i<nclasses; ++i) {
            cxxdata.push_back(*PyBlitzArrayCxx_AsBlitz<float,2>
                ((PyBlitzArrayObject*)PyTuple_GET_ITEM(data,i)));
          }
          // computed once, after every array has been collected
          bob::math::scatters(cxxdata,
              *PyBlitzArrayCxx_AsBlitz<float,2>(sw),
              *PyBlitzArrayCxx_AsBlitz<float,2>(sb),
              *PyBlitzArrayCxx_AsBlitz<float,1>(m)
              );
        }
        break;

      case NPY_FLOAT64:
        {
          std::vector<blitz::Array<double,2>> cxxdata;
          cxxdata.reserve(nclasses);
          for (Py_ssize_t i=0; i<nclasses; ++i) {
            cxxdata.push_back(*PyBlitzArrayCxx_AsBlitz<double,2>
                ((PyBlitzArrayObject*)PyTuple_GET_ITEM(data,i)));
          }
          // computed once, after every array has been collected
          bob::math::scatters(cxxdata,
              *PyBlitzArrayCxx_AsBlitz<double,2>(sw),
              *PyBlitzArrayCxx_AsBlitz<double,2>(sb),
              *PyBlitzArrayCxx_AsBlitz<double,1>(m)
              );
        }
        break;

      default:
        PyErr_Format(PyExc_TypeError, "scatters calculation currently not implemented for type '%s'", PyBlitzArray_TypenumAsString(first->type_num));
        return 0;
    }
  }
  catch (std::exception& e) {
    PyErr_SetString(PyExc_RuntimeError, e.what());
    return 0;
  }
  catch (...) {
    PyErr_SetString(PyExc_RuntimeError, "scatters calculation failed: unknown exception caught");
    return 0;
  }

  // only outputs allocated above are returned to the caller
  int returns = 3 - (user_sw + user_sb + user_m);

  PyObject* retval = PyTuple_New(returns);
  if (!retval) return 0;

  // fill from the back, so the tuple reads (sw, sb, m) when all were allocated
  if (!user_m) {
    PyObject* wrapped = PyBlitzArray_NUMPY_WRAP(Py_BuildValue("O", m));
    if (!wrapped) {
      Py_DECREF(retval); // tuple dealloc tolerates still-empty slots
      return 0;
    }
    PyTuple_SET_ITEM(retval, --returns, wrapped);
  }

  if (!user_sb) {
    PyObject* wrapped = PyBlitzArray_NUMPY_WRAP(Py_BuildValue("O", sb));
    if (!wrapped) {
      Py_DECREF(retval);
      return 0;
    }
    PyTuple_SET_ITEM(retval, --returns, wrapped);
  }

  if (!user_sw) {
    PyObject* wrapped = PyBlitzArray_NUMPY_WRAP(Py_BuildValue("O", sw));
    if (!wrapped) {
      Py_DECREF(retval);
      return 0;
    }
    PyTuple_SET_ITEM(retval, --returns, wrapped);
  }

  return retval;

}