Skip to content
Snippets Groups Projects
Commit 865666ad authored by Manuel Günther's avatar Manuel Günther
Browse files

Fixed wrong input for training in Guide; added check to detect that in future.

parent 1305a457
No related branches found
No related tags found
No related merge requests found
......@@ -172,6 +172,8 @@ void bob::learn::mlp::BackProp::train(bob::learn::mlp::Machine& machine,
}
bob::core::array::assertSameDimensionLength(getBatchSize(), input.extent(0));
bob::core::array::assertSameDimensionLength(getBatchSize(), target.extent(0));
bob::core::array::assertSameDimensionLength(machine.inputSize(), input.extent(1));
bob::core::array::assertSameDimensionLength(machine.outputSize(), target.extent(1));
train_(machine, input, target);
}
......
......@@ -216,6 +216,8 @@ void bob::learn::mlp::RProp::train(bob::learn::mlp::Machine& machine,
}
bob::core::array::assertSameDimensionLength(getBatchSize(), input.extent(0));
bob::core::array::assertSameDimensionLength(getBatchSize(), target.extent(0));
bob::core::array::assertSameDimensionLength(machine.inputSize(), input.extent(1));
bob::core::array::assertSameDimensionLength(machine.outputSize(), target.extent(1));
train_(machine, input, target);
}
......
......@@ -116,7 +116,7 @@ available MLP trainers in two different 2D `NumPy`_ arrays, one for the input
.. doctest::
:options: +NORMALIZE_WHITESPACE
>>> d0 = numpy.array([[.3, .7]]) # input
>>> d0 = numpy.array([[.3, .7, .5]]) # input
>>> t0 = numpy.array([[.0]]) # target
The class used to train a MLP [1]_ with backpropagation [2]_ is
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment