Commit 13ae3919 authored by Amir MOHAMMADI

Add two functions to allow dynamic weighting of samples per batch

parent 214fad7a
1 merge request: !75 "A lot of new features"
import tensorflow as tf
def balanced_softmax_cross_entropy_loss_weights(labels, dtype):
  • We should set dtype to float32 by default. That is true for the vast majority of cases, no?

  • Author Owner

    Fine by me. They should always have the same type as the logits; that's why I did not make it optional.

"""Computes weights that normalizes your loss per class.
Labels must be a batch of one-hot encoded labels. The function takes labels and
computes the weights per batch. Weights will be smaller for classes that have more
samples in this batch. This is useful if you unbalanced classes in your dataset or
batch.
Parameters
----------
labels : tf.Tensor
Labels of your current input. The shape must be [batch_size, n_classes]. If your
labels are not one-hot encoded, you can use ``tf.one_hot`` to convert them first
before giving them to this function.
dtype : dtype
The dtype that the weights will have. It should be a floating-point dtype. It is
best to provide logits.dtype as input.
Returns
-------
tf.Tensor
Computed weights that will cancel your dataset imbalance per batch.
Examples
--------
>>> labels = array([[1, 0, 0],
... [1, 0, 0],
... [0, 0, 1],
... [0, 1, 0],
... [0, 0, 1],
... [1, 0, 0],
... [1, 0, 0],
... [0, 0, 1],
... [1, 0, 0],
... [1, 0, 0],
... [1, 0, 0],
... [1, 0, 0],
... [1, 0, 0],
... [1, 0, 0],
... [0, 1, 0],
... [1, 0, 0],
... [0, 1, 0],
... [1, 0, 0],
... [0, 0, 1],
... [0, 0, 1],
... [1, 0, 0],
... [0, 0, 1],
... [1, 0, 0],
... [1, 0, 0],
... [0, 1, 0],
... [1, 0, 0],
... [1, 0, 0],
... [1, 0, 0],
... [0, 1, 0],
... [1, 0, 0],
... [0, 0, 1],
... [1, 0, 0]], dtype=int32)
>>> tf.reduce_sum(labels, axis=0)
array([20, 5, 7], dtype=int32)
>>> balanced_softmax_cross_entropy_loss_weights(labels, dtype='float32')
array([0.53333336, 0.53333336, 1.5238096 , 2.1333334 , 1.5238096 ,
0.53333336, 0.53333336, 1.5238096 , 0.53333336, 0.53333336,
0.53333336, 0.53333336, 0.53333336, 0.53333336, 2.1333334 ,
0.53333336, 2.1333334 , 0.53333336, 1.5238096 , 1.5238096 ,
0.53333336, 1.5238096 , 0.53333336, 0.53333336, 2.1333334 ,
0.53333336, 0.53333336, 0.53333336, 2.1333334 , 0.53333336,
1.5238096 , 0.53333336], dtype=float32)
You would use it like this:
>>> weights = balanced_softmax_cross_entropy_loss_weights(labels, dtype=logits.dtype)
>>> loss = tf.losses.softmax_cross_entropy(onehot_labels=labels, logits=logits, weights=weights)
"""
shape = tf.cast(tf.shape(labels), dtype=dtype)
batch_size, n_classes = shape[0], shape[1]
# Number of samples of each class in this batch.
weights = tf.cast(tf.reduce_sum(labels, axis=0), dtype=dtype)
# Per-class weight: batch_size / (class count * n_classes).
weights = batch_size / weights / n_classes
# Map the per-class weights back to one weight per sample.
weights = tf.gather(weights, tf.argmax(labels, axis=1))
return weights
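# Written out, the rule above is
#   weight_for_class_c = batch_size / (count_of_c_in_this_batch * n_classes),
# so every class contributes equally to the batch loss. The lines below are only an
# illustrative NumPy sketch of that formula; the *_np names are hypothetical and not
# part of this module.
import numpy as np
labels_np = np.array([[1, 0, 0], [0, 1, 0], [1, 0, 0], [0, 0, 1]], dtype="int32")
batch_size_np, n_classes_np = labels_np.shape
class_counts = labels_np.sum(axis=0).astype("float32")       # [2., 1., 1.] samples per class
class_weights = batch_size_np / class_counts / n_classes_np  # [0.667, 1.333, 1.333]
per_sample_weights = class_weights[labels_np.argmax(axis=1)]  # one weight per sample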
def balanced_sigmoid_cross_entropy_loss_weights(labels, dtype):
"""Computes weights that normalizes your loss per class.
Labels must be a batch of binary labels. The function takes labels and
computes the weights per batch. Weights will be smaller for the class that have more
samples in this batch. This is useful if you unbalanced classes in your dataset or
batch.
Parameters
----------
labels : tf.Tensor
Labels of your current input. The shape must be [batch_size] and values must be
either 0 or 1.
dtype : dtype
The dtype that the weights will have. It should be a floating-point dtype. It is
best to provide logits.dtype as input.
Returns
-------
tf.Tensor
Computed weights that will cancel your dataset imbalance per batch.
Examples
--------
>>> labels = array([1, 1, 0, 0, 0, 1, 1, 0, 1, 1, 1, 1, 1, 1, 0, 1, 0, 1, 0, 0, 1, 0,
... 1, 1, 0, 1, 1, 1, 0, 1, 0, 1], dtype=int32)
>>> sum(labels), len(labels)
(20, 32)
>>> balanced_sigmoid_cross_entropy_loss_weights(labels, dtype='float32')
array([0.8 , 0.8 , 1.3333334, 1.3333334, 1.3333334, 0.8 ,
0.8 , 1.3333334, 0.8 , 0.8 , 0.8 , 0.8 ,
0.8 , 0.8 , 1.3333334, 0.8 , 1.3333334, 0.8 ,
1.3333334, 1.3333334, 0.8 , 1.3333334, 0.8 , 0.8 ,
1.3333334, 0.8 , 0.8 , 0.8 , 1.3333334, 0.8 ,
1.3333334, 0.8 ], dtype=float32)
You would use it like this:
>>> weights = balanced_sigmoid_cross_entropy_loss_weights(labels, dtype=logits.dtype)
>>> loss = tf.losses.sigmoid_cross_entropy(multi_class_labels=labels, logits=logits, weights=weights)
"""
labels = tf.cast(labels, dtype='int32')
batch_size = tf.cast(tf.shape(labels)[0], dtype=dtype)
# Number of positive samples (label == 1) in this batch.
weights = tf.cast(tf.reduce_sum(labels), dtype=dtype)
# Per-class weights for the negative and positive classes, respectively.
weights = tf.convert_to_tensor([batch_size - weights, weights])
weights = batch_size / weights / 2
# Map the two class weights back to one weight per sample.
weights = tf.gather(weights, labels)
return weights
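# In the binary case this means
#   weight_for_negatives = batch_size / (n_negatives * 2)
#   weight_for_positives = batch_size / (n_positives * 2),
# so each of the two classes ends up contributing half of the batch loss. A small
# NumPy illustration of that property; the *_np names are hypothetical and not part
# of this module.
import numpy as np
labels_np = np.array([1, 1, 1, 1, 1, 1, 0, 0], dtype="int32")         # 6 positives, 2 negatives
counts_np = np.array([(labels_np == 0).sum(), labels_np.sum()], dtype="float32")
weights_np = len(labels_np) / counts_np / 2                           # [2.0, 0.667]
per_sample_np = weights_np[labels_np]
# Both classes contribute the same total weight, batch_size / 2 == 4.0 here.
assert np.allclose(per_sample_np[labels_np == 1].sum(), per_sample_np[labels_np == 0].sum())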
  • Owner

    There are no tests for this one; should we add some?

    Is there any reference in the literature for this, @amohammadi?

    Thanks

  • Author Owner

    I learned this from @ageorge; I don't know if he found it somewhere.

    There are tests, but they are doctests, as you can see above.

  • Maintainer

    Well, the reference goes back to @onikisins; it seems to work well in practice for binary tasks.
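
A standalone test in the spirit of this thread could look roughly like the sketch below. It assumes eager execution, assumes balanced_sigmoid_cross_entropy_loss_weights is importable, and simply re-checks the docstring example of the sigmoid variant; the test name is hypothetical and not part of the commit.

import numpy as np
import tensorflow as tf

def test_balanced_sigmoid_cross_entropy_loss_weights():
    # Same imbalance as the docstring example: 20 positives, 12 negatives.
    labels_list = [1] * 20 + [0] * 12
    labels = tf.constant(labels_list, dtype=tf.int32)
    weights = balanced_sigmoid_cross_entropy_loss_weights(labels, dtype="float32")
    expected = np.where(np.array(labels_list) == 1, 32 / 20 / 2, 32 / 12 / 2)
    np.testing.assert_allclose(weights.numpy(), expected, rtol=1e-6)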
