Commit e303670e authored by Tiago de Freitas Pereira's avatar Tiago de Freitas Pereira
Browse files

Binding linearScoring

parent 0be70faf
......@@ -47,6 +47,10 @@ static int extract_gmmmachine_list(PyObject *list,
template <int N>
int extract_array_list(PyObject* list, std::vector<blitz::Array<double,N> >& vec)
{
if(list==0)
return 0;
for (int i=0; i<PyList_GET_SIZE(list); i++)
{
PyBlitzArrayObject* blitz_object;
......@@ -71,13 +75,13 @@ static auto linear_scoring = bob::extension::FunctionDoc(
0,
true
)
.add_prototype("models, ubm_mean, ubm_variance, test_stats, test_channelOffset, frame_length_normalisation", "output")
.add_prototype("models, ubm, test_stats, test_channelOffset, frame_length_normalisation", "output")
.add_parameter("models", "", "")
.add_parameter("ubm", "", "")
.add_parameter("test_stats", "", "")
.add_parameter("test_channelOffset", "", "")
.add_parameter("frame_length_normalisation", "bool", "")
.add_return("output","array_like<float,2>","Score");
.add_return("output","array_like<float,1>","Score");
static PyObject* PyBobLearnMisc_linear_scoring(PyObject*, PyObject* args, PyObject* kwargs) {
char** kwlist = linear_scoring.kwlist(0);
......@@ -85,51 +89,48 @@ static PyObject* PyBobLearnMisc_linear_scoring(PyObject*, PyObject* args, PyObje
//Cheking the number of arguments
int nargs = (args?PyTuple_Size(args):0) + (kwargs?PyDict_Size(kwargs):0);
switch(nargs){
//Read a list of GMM
case 5:{
if((nargs >= 3) && (nargs<=5)){
PyObject* gmm_list_o = 0;
PyBobLearnMiscGMMMachineObject* ubm = 0;
PyObject* stats_list_o = 0;
PyObject* channel_offset_list_o = 0;
PyObject* frame_length_normalisation = 0;
PyObject* gmm_list_o = 0;
PyBobLearnMiscGMMMachineObject* ubm = 0;
PyObject* stats_list_o = 0;
PyObject* channel_offset_list_o = 0;
PyObject* frame_length_normalisation = Py_False;
if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O!O!O!O!O!", kwlist, &PyList_Type, &gmm_list_o,
if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O!O!O!|O!O!", kwlist, &PyList_Type, &gmm_list_o,
&PyBobLearnMiscGMMMachine_Type, &ubm,
&PyList_Type, &stats_list_o,
&PyList_Type, &channel_offset_list_o,
&PyBool_Type, frame_length_normalisation)){
linear_scoring.print_usage();
Py_RETURN_NONE;
}
std::vector<boost::shared_ptr<const bob::learn::misc::GMMStats> > stats_list;
if(extract_gmmstats_list(stats_list_o ,stats_list)!=0)
Py_RETURN_NONE;
&PyBool_Type, &frame_length_normalisation)){
linear_scoring.print_usage();
Py_RETURN_NONE;
}
std::vector<boost::shared_ptr<const bob::learn::misc::GMMMachine> > gmm_list;
if(extract_gmmmachine_list(gmm_list_o ,gmm_list)!=0)
Py_RETURN_NONE;
std::vector<boost::shared_ptr<const bob::learn::misc::GMMStats> > stats_list;
if(extract_gmmstats_list(stats_list_o ,stats_list)!=0)
Py_RETURN_NONE;
std::vector<blitz::Array<double,2> > channel_offset_list;
if(extract_array_list(channel_offset_list_o ,channel_offset_list)!=0)
Py_RETURN_NONE;
std::vector<boost::shared_ptr<const bob::learn::misc::GMMMachine> > gmm_list;
if(extract_gmmmachine_list(gmm_list_o ,gmm_list)!=0)
Py_RETURN_NONE;
std::vector<blitz::Array<double,1> > channel_offset_list;
if(extract_array_list(channel_offset_list_o ,channel_offset_list)!=0)
Py_RETURN_NONE;
blitz::Array<double, 2> scores = blitz::Array<double, 2>(gmm_list.size(), stats_list.size());
blitz::Array<double, 2> scores = blitz::Array<double, 2>(gmm_list.size(), stats_list.size());
if(channel_offset_list.size()==0)
bob::learn::misc::linearScoring(gmm_list, *ubm->cxx, stats_list, f(frame_length_normalisation),scores);
else
bob::learn::misc::linearScoring(gmm_list, *ubm->cxx, stats_list, channel_offset_list, f(frame_length_normalisation),scores);
return PyBlitzArrayCxx_AsConstNumpy(scores);
}
default:{
PyErr_Format(PyExc_RuntimeError, "number of arguments mismatch - linear_scoring requires 5 or 6 arguments, but you provided %d (see help)", nargs);
linear_scoring.print_usage();
Py_RETURN_NONE;
}
return PyBlitzArrayCxx_AsConstNumpy(scores);
}
else{
PyErr_Format(PyExc_RuntimeError, "number of arguments mismatch - linear_scoring requires 5 or 6 arguments, but you provided %d (see help)", nargs);
linear_scoring.print_usage();
Py_RETURN_NONE;
}
/*
......
......@@ -32,6 +32,13 @@ static PyMethodDef module_methods[] = {
METH_VARARGS|METH_KEYWORDS,
z_norm.doc()
},
{
linear_scoring.name(),
(PyCFunction)PyBobLearnMisc_linear_scoring,
METH_VARARGS|METH_KEYWORDS,
linear_scoring.doc()
},
{0}//Sentinel
};
......
......@@ -64,10 +64,10 @@ def test_LinearScoring():
# 1/b/ Without test_channelOffset, with frame-length normalisation
scores = linear_scoring([model1, model2], ubm, [stats1, stats2, stats3], [], True)
assert (abs(scores - ref_scores_01) < 1e-7).all()
scores = linear_scoring([model1, model2], ubm, [stats1, stats2, stats3], (), True)
assert (abs(scores - ref_scores_01) < 1e-7).all()
scores = linear_scoring([model1, model2], ubm, [stats1, stats2, stats3], None, True)
#scores = linear_scoring([model1, model2], ubm, [stats1, stats2, stats3], (), True)
assert (abs(scores - ref_scores_01) < 1e-7).all()
#scores = linear_scoring([model1, model2], ubm, [stats1, stats2, stats3], None, True)
#assert (abs(scores - ref_scores_01) < 1e-7).all()
# 1/c/ With test_channelOffset, without frame-length normalisation
scores = linear_scoring([model1, model2], ubm, [stats1, stats2, stats3], test_channeloffset)
......@@ -77,7 +77,7 @@ def test_LinearScoring():
scores = linear_scoring([model1, model2], ubm, [stats1, stats2, stats3], test_channeloffset, True)
assert (abs(scores - ref_scores_11) < 1e-7).all()
"""
# 2/ Use mean/variance supervectors
# 2/a/ Without test_channelOffset, without frame-length normalisation
scores = linear_scoring([model1.mean_supervector, model2.mean_supervector], ubm.mean_supervector, ubm.variance_supervector, [stats1, stats2, stats3])
......@@ -123,3 +123,4 @@ def test_LinearScoring():
assert abs(score - ref_scores_11[1,1]) < 1e-7
score = linear_scoring(model2.mean_supervector, ubm.mean_supervector, ubm.variance_supervector, stats3, test_channeloffset[2], True)
assert abs(score - ref_scores_11[1,2]) < 1e-7
"""
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment