diff --git a/bob/pad/base/algorithm/SVM.py b/bob/pad/base/algorithm/SVM.py index 9d72554079238d09c1b9298938b21c8300e26f60..f76dea97013af5a52dd482b693b0dfd8002ab54f 100644 --- a/bob/pad/base/algorithm/SVM.py +++ b/bob/pad/base/algorithm/SVM.py @@ -277,6 +277,13 @@ class SVM(Algorithm): machine = trainer.train(data) # train the machine + + # TODO: find a proper way to handle this - Guillaume HEUSCH, 08-08-2018 + if machine.shape[0] != data[0].shape[1]: + data[0] += 1 + data[1] += 1 + machine = trainer.train(data) # train the machine + precision_cv = self.comp_prediction_precision( machine, np.copy(real_cv), np.copy(attack_cv)) @@ -370,6 +377,12 @@ class SVM(Algorithm): machine = trainer.train(data) # train the machine + # TODO: find a proper way to handle this - Guillaume HEUSCH, 08-08-2018 + if machine.shape[0] != data[0].shape[1]: + data[0] += 1 + data[1] += 1 + machine = trainer.train(data) # train the machine + if mean_std_norm_flag: machine.input_subtract = features_mean # subtract the mean of train data machine.input_divide = features_std # divide by std of train data