# losses.py — relativistic (average) GAN losses for the TensorFlow 1.x graph API.
import tensorflow as tf


def relativistic_discriminator_loss(
    discriminator_real_outputs,
    discriminator_gen_outputs,
    label_smoothing=0.25,
    real_weights=1.0,
    generated_weights=1.0,
    scope=None,
    loss_collection=tf.GraphKeys.LOSSES,
    reduction=tf.losses.Reduction.SUM_BY_NONZERO_WEIGHTS,
    add_summaries=False,
):
    """Relativistic average discriminator loss.

  Each discriminator output is re-centered by the mean output on the opposite
  batch before the usual sigmoid cross-entropy is applied: real samples are
  pushed above the average fake score, and fake samples below the average real
  score.

  Args:
    discriminator_real_outputs: Discriminator output on real data.
    discriminator_gen_outputs: Discriminator output on generated data. Expected
      to be in the range of (-inf, inf).
    label_smoothing: The amount of smoothing for positive labels. This technique
      is taken from `Improved Techniques for Training GANs`
      (https://arxiv.org/abs/1606.03498). `0.0` means no smoothing. Applied
      only to the real (positive-label) term, i.e. one-sided smoothing.
    real_weights: Optional `Tensor` whose rank is either 0, or the same rank as
      `real_data`, and must be broadcastable to `real_data` (i.e., all
      dimensions must be either `1`, or the same as the corresponding
      dimension).
    generated_weights: Same as `real_weights`, but for `generated_data`.
    scope: The scope for the operations performed in computing the loss.
    loss_collection: collection to which this loss will be added.
    reduction: A `tf.compat.v1.losses.Reduction` to apply to loss.
    add_summaries: Whether or not to add summaries for the loss.

  Returns:
    A loss Tensor. The shape depends on `reduction`.
  """
    with tf.name_scope(
        scope,
        "discriminator_relativistic_loss",
        (
            discriminator_real_outputs,
            discriminator_gen_outputs,
            real_weights,
            generated_weights,
            label_smoothing,
        ),
    ) as scope:

        # Relativistic (average) logits: each side is compared against the
        # batch mean of the other side.
        mean_gen_output = tf.reduce_mean(discriminator_gen_outputs)
        mean_real_output = tf.reduce_mean(discriminator_real_outputs)
        relativistic_real = discriminator_real_outputs - mean_gen_output
        relativistic_fake = discriminator_gen_outputs - mean_real_output

        # Real samples should score above the average fake (label 1, smoothed).
        loss_on_real = tf.losses.sigmoid_cross_entropy(
            multi_class_labels=tf.ones_like(relativistic_real),
            logits=relativistic_real,
            weights=real_weights,
            label_smoothing=label_smoothing,
            scope=scope,
            loss_collection=None,
            reduction=reduction,
        )
        # Fake samples should score below the average real (label 0, no
        # smoothing — one-sided smoothing touches only positive labels).
        loss_on_generated = tf.losses.sigmoid_cross_entropy(
            multi_class_labels=tf.zeros_like(relativistic_fake),
            logits=relativistic_fake,
            weights=generated_weights,
            scope=scope,
            loss_collection=None,
            reduction=reduction,
        )

        loss = loss_on_real + loss_on_generated
        # Register the combined loss (the two partial losses were kept out of
        # any collection above).
        tf.losses.add_loss(loss, loss_collection)

        if add_summaries:
            tf.summary.scalar("discriminator_gen_relativistic_loss", loss_on_generated)
            tf.summary.scalar("discriminator_real_relativistic_loss", loss_on_real)
            tf.summary.scalar("discriminator_relativistic_loss", loss)

    return loss


def relativistic_generator_loss(
    discriminator_real_outputs,
    discriminator_gen_outputs,
    label_smoothing=0.0,
    real_weights=1.0,
    generated_weights=1.0,
    scope=None,
    loss_collection=tf.GraphKeys.LOSSES,
    reduction=tf.losses.Reduction.SUM_BY_NONZERO_WEIGHTS,
    add_summaries=False,
    confusion_labels=False,
):
    """Relativistic average generator loss.

  The generator tries to invert the discriminator's relativistic decision:
  generated samples should score above the average real sample, and real
  samples below the average generated sample (or both should be pushed to
  maximal confusion when `confusion_labels` is set).

  Args:
    discriminator_real_outputs: Discriminator output on real data.
    discriminator_gen_outputs: Discriminator output on generated data. Expected
      to be in the range of (-inf, inf).
    label_smoothing: The amount of smoothing for positive labels. This technique
      is taken from `Improved Techniques for Training GANs`
      (https://arxiv.org/abs/1606.03498). `0.0` means no smoothing. Applied
      only to the generated (positive-label) term, i.e. one-sided smoothing.
    real_weights: Optional `Tensor` whose rank is either 0, or the same rank as
      `real_data`, and must be broadcastable to `real_data` (i.e., all
      dimensions must be either `1`, or the same as the corresponding
      dimension).
    generated_weights: Same as `real_weights`, but for `generated_data`.
    scope: The scope for the operations performed in computing the loss.
    loss_collection: collection to which this loss will be added.
    reduction: A `tf.compat.v1.losses.Reduction` to apply to loss.
    add_summaries: Whether or not to add summaries for the loss.
    confusion_labels: If `True`, both terms target `0.5` (maximal discriminator
      confusion) instead of the flipped hard labels (real -> 0, generated -> 1).

  Returns:
    A loss Tensor. The shape depends on `reduction`.
  """
    with tf.name_scope(
        scope,
        "generator_relativistic_loss",
        (
            discriminator_real_outputs,
            discriminator_gen_outputs,
            real_weights,
            generated_weights,
            label_smoothing,
        ),
    ) as scope:

        # Relativistic (average) logits: each side is compared against the
        # batch mean of the other side.
        real_logit = discriminator_real_outputs - tf.reduce_mean(
            discriminator_gen_outputs
        )
        fake_logit = discriminator_gen_outputs - tf.reduce_mean(
            discriminator_real_outputs
        )

        if confusion_labels:
            # Target maximal confusion: sigmoid(logit) == 0.5 on both sides.
            real_labels = tf.ones_like(real_logit) / 2
            fake_labels = tf.ones_like(fake_logit) / 2
        else:
            # Generator flips the discriminator targets: real -> 0, fake -> 1.
            real_labels = tf.zeros_like(real_logit)
            fake_labels = tf.ones_like(fake_logit)

        loss_on_real = tf.losses.sigmoid_cross_entropy(
            real_labels,
            real_logit,
            real_weights,
            scope=scope,
            loss_collection=None,
            reduction=reduction,
        )
        # BUGFIX: `label_smoothing` was previously applied to the real term,
        # whose default labels are zeros — but smoothing is defined for
        # positive labels (see docstring). For the generator the positive side
        # is the generated term, so smoothing belongs here. Behavior is
        # unchanged for the default `label_smoothing=0.0`.
        loss_on_generated = tf.losses.sigmoid_cross_entropy(
            fake_labels,
            fake_logit,
            generated_weights,
            label_smoothing,
            scope,
            loss_collection=None,
            reduction=reduction,
        )

        loss = loss_on_real + loss_on_generated
        # Only the combined loss goes into the collection; the partial losses
        # above were deliberately kept out (loss_collection=None).
        tf.losses.add_loss(loss, loss_collection)

        if add_summaries:
            tf.summary.scalar("generator_gen_relativistic_loss", loss_on_generated)
            tf.summary.scalar("generator_real_relativistic_loss", loss_on_real)
            tf.summary.scalar("generator_relativistic_loss", loss)

    return loss