diff --git a/bob/learn/em/gmm_machine.cpp b/bob/learn/em/gmm_machine.cpp index 3d960b9e62e444e690cf4a439a92cc6a4ed11cf3..8e2d7caa45348676ee25860b57467e19a819ccec 100644 --- a/bob/learn/em/gmm_machine.cpp +++ b/bob/learn/em/gmm_machine.cpp @@ -766,11 +766,10 @@ static PyObject* PyBobLearnEMGMMMachine_accStatistics(PyBobLearnEMGMMMachineObje //protects acquired resources through this scope auto input_ = make_safe(input); - blitz::Array<double,2> blitz_test = *PyBlitzArrayCxx_AsBlitz<double,2>(input); - if (blitz_test.extent(1)==0) + if (input->ndim == 1) self->cxx->accStatistics(*PyBlitzArrayCxx_AsBlitz<double,1>(input), *stats->cxx); else - self->cxx->accStatistics(blitz_test, *stats->cxx); + self->cxx->accStatistics(*PyBlitzArrayCxx_AsBlitz<double,2>(input), *stats->cxx); BOB_CATCH_MEMBER("cannot accumulate the statistics", 0) @@ -803,11 +802,10 @@ static PyObject* PyBobLearnEMGMMMachine_accStatistics_(PyBobLearnEMGMMMachineObj //protects acquired resources through this scope auto input_ = make_safe(input); - blitz::Array<double,2> blitz_test = *PyBlitzArrayCxx_AsBlitz<double,2>(input); - if (blitz_test.extent(1)==0) + if (input->ndim==1) self->cxx->accStatistics_(*PyBlitzArrayCxx_AsBlitz<double,1>(input), *stats->cxx); else - self->cxx->accStatistics_(blitz_test, *stats->cxx); + self->cxx->accStatistics_(*PyBlitzArrayCxx_AsBlitz<double,2>(input), *stats->cxx); BOB_CATCH_MEMBER("cannot accumulate the statistics", 0) Py_RETURN_NONE;