diff --git a/bob/learn/libsvm/test_trainer.py b/bob/learn/libsvm/test_trainer.py index 060326caa34d463886732932f52116e7c6a4971b..74f4b4d045671e3e448c54f6c739e0d0c8ff38dc 100644 --- a/bob/learn/libsvm/test_trainer.py +++ b/bob/learn/libsvm/test_trainer.py @@ -214,4 +214,21 @@ def test_training_one_class(): curr_scores = numpy.array(curr_scores) prev_scores = numpy.array(prev_scores) assert numpy.all(abs(curr_scores - prev_scores) < 1e-8) - + + +def test_successive_training(): + + # Tests successive training works: i.e., training a couple of machines one + # after the other. + + numpy.random.seed(10) + + for i in range(2): + pos = numpy.random.normal(0., 1, size=(100, 2)) + neg = numpy.random.normal(1., 1, size=(100, 2)) + data = [pos, neg] + + trainer = Trainer() + trainer.kernel_type = 'LINEAR' + trainer.cost = 1 + trainer.train(data)