Skip to content
Snippets Groups Projects
Commit 01ac7f66 authored by Tiago de Freitas Pereira's avatar Tiago de Freitas Pereira
Browse files

Added tests for the new losses

parent 89ea404b
No related branches found
No related tags found
1 merge request!75A lot of new features
#!/usr/bin/env python
# vim: set fileencoding=utf-8 :
# @author: Amir Mohammadi <amir.mohammadi@idiap.ch>
import tensorflow as tf
def balanced_softmax_cross_entropy_loss_weights(labels, dtype):
def balanced_softmax_cross_entropy_loss_weights(labels, dtype="float32"):
"""Computes weights that normalizes your loss per class.
Labels must be a batch of one-hot encoded labels. The function takes labels and
......@@ -82,7 +86,7 @@ def balanced_softmax_cross_entropy_loss_weights(labels, dtype):
return weights
def balanced_sigmoid_cross_entropy_loss_weights(labels, dtype):
def balanced_sigmoid_cross_entropy_loss_weights(labels, dtype="float32"):
"""Computes weights that normalizes your loss per class.
Labels must be a batch of binary labels. The function takes labels and
......
#!/usr/bin/env python
# vim: set fileencoding=utf-8 :
# @author: Tiago de Freitas Pereira <tiago.pereira@idiap.ch>
import tensorflow as tf
import numpy
from bob.learn.tensorflow.loss import balanced_softmax_cross_entropy_loss_weights,\
balanced_sigmoid_cross_entropy_loss_weights
def test_balanced_softmax_cross_entropy_loss_weights():
    """Compare balanced softmax loss weights against precomputed values.

    The batch holds 20 samples of class 0, 5 of class 1 and 7 of class 2,
    so the expected per-sample weights are 32/(3*20), 32/(3*5) and 32/(3*7)
    respectively (given here as float32 literals).
    """
    # One-hot labels expressed as class indices, expanded via numpy.eye.
    class_indices = [0, 0, 2, 1, 2, 0, 0, 2, 0, 0, 0, 0, 0, 0, 1, 0,
                     1, 0, 2, 2, 0, 2, 0, 0, 1, 0, 0, 0, 1, 0, 2, 0]
    labels = numpy.eye(3, dtype="int32")[class_indices]

    expected_weights = numpy.array(
        [0.53333336, 0.53333336, 1.5238096, 2.1333334,
         1.5238096, 0.53333336, 0.53333336, 1.5238096,
         0.53333336, 0.53333336, 0.53333336, 0.53333336,
         0.53333336, 0.53333336, 2.1333334, 0.53333336,
         2.1333334, 0.53333336, 1.5238096, 1.5238096,
         0.53333336, 1.5238096, 0.53333336, 0.53333336,
         2.1333334, 0.53333336, 0.53333336, 0.53333336,
         2.1333334, 0.53333336, 1.5238096, 0.53333336],
        dtype="float32")

    # TF1-style graph execution: build the op, then evaluate in a session.
    with tf.Session() as sess:
        computed = sess.run(balanced_softmax_cross_entropy_loss_weights(labels))

    assert numpy.allclose(computed, expected_weights)
def test_balanced_sigmoid_cross_entropy_loss_weights():
    """Compare balanced sigmoid loss weights against precomputed values.

    The batch has 20 positive and 12 negative samples; every positive gets
    weight 0.8 and every negative 1.3333334 (float32 precomputed values).
    """
    labels = numpy.array(
        [1, 1, 0, 0, 0, 1, 1, 0, 1, 1, 1, 1, 1, 1, 0, 1,
         0, 1, 0, 0, 1, 0, 1, 1, 0, 1, 1, 1, 0, 1, 0, 1],
        dtype="int32")

    # Expected weights follow the label pattern directly: 0.8 for positives,
    # 1.3333334 for negatives.
    expected_weights = numpy.where(labels == 1, 0.8, 1.3333334).astype("float32")

    # TF1-style graph execution: build the op, then evaluate in a session.
    with tf.Session() as sess:
        computed = sess.run(
            balanced_sigmoid_cross_entropy_loss_weights(labels, dtype="float32"))

    assert numpy.allclose(computed, expected_weights)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment