Commit 38de7bdc authored by Amir MOHAMMADI

Add GAN tools

parent cb9744bc
from . import spectral_normalization
from . import losses
import tensorflow as tf
def relativistic_discriminator_loss(
discriminator_real_outputs,
discriminator_gen_outputs,
label_smoothing=0.25,
real_weights=1.0,
generated_weights=1.0,
scope=None,
loss_collection=tf.GraphKeys.LOSSES,
reduction=tf.losses.Reduction.SUM_BY_NONZERO_WEIGHTS,
add_summaries=False,
):
"""Relativistic (average) loss
Args:
discriminator_real_outputs: Discriminator output on real data.
discriminator_gen_outputs: Discriminator output on generated data. Expected
to be in the range of (-inf, inf).
label_smoothing: The amount of smoothing for positive labels. This technique
is taken from `Improved Techniques for Training GANs`
(https://arxiv.org/abs/1606.03498). `0.0` means no smoothing.
real_weights: Optional `Tensor` whose rank is either 0, or the same rank as
`real_data`, and must be broadcastable to `real_data` (i.e., all
dimensions must be either `1`, or the same as the corresponding
dimension).
generated_weights: Same as `real_weights`, but for `generated_data`.
scope: The scope for the operations performed in computing the loss.
loss_collection: collection to which this loss will be added.
reduction: A `tf.compat.v1.losses.Reduction` to apply to loss.
add_summaries: Whether or not to add summaries for the loss.
Returns:
A loss Tensor. The shape depends on `reduction`.
"""
with tf.name_scope(
scope,
"discriminator_relativistic_loss",
(
discriminator_real_outputs,
discriminator_gen_outputs,
real_weights,
generated_weights,
label_smoothing,
),
) as scope:
real_logit = discriminator_real_outputs - tf.reduce_mean(
discriminator_gen_outputs
)
fake_logit = discriminator_gen_outputs - tf.reduce_mean(
discriminator_real_outputs
)
loss_on_real = tf.losses.sigmoid_cross_entropy(
tf.ones_like(real_logit),
real_logit,
real_weights,
label_smoothing,
scope,
loss_collection=None,
reduction=reduction,
)
loss_on_generated = tf.losses.sigmoid_cross_entropy(
tf.zeros_like(fake_logit),
fake_logit,
generated_weights,
scope=scope,
loss_collection=None,
reduction=reduction,
)
loss = loss_on_real + loss_on_generated
tf.losses.add_loss(loss, loss_collection)
if add_summaries:
tf.summary.scalar("discriminator_gen_relativistic_loss", loss_on_generated)
tf.summary.scalar("discriminator_real_relativistic_loss", loss_on_real)
tf.summary.scalar("discriminator_relativistic_loss", loss)
return loss
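# Illustrative usage sketch, not part of the original commit: `d_real` and
# `d_fake` below are hypothetical placeholder logits standing in for
# discriminator outputs on a batch of real and generated samples.
def _example_relativistic_discriminator_loss():
    d_real = tf.constant([[2.0], [1.5]])  # logits on real data, shape (2, 1)
    d_fake = tf.constant([[-1.0], [0.5]])  # logits on generated data
    return relativistic_discriminator_loss(d_real, d_fake)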
def relativistic_generator_loss(
discriminator_real_outputs,
discriminator_gen_outputs,
label_smoothing=0.0,
real_weights=1.0,
generated_weights=1.0,
scope=None,
loss_collection=tf.GraphKeys.LOSSES,
reduction=tf.losses.Reduction.SUM_BY_NONZERO_WEIGHTS,
add_summaries=False,
confusion_labels=False,
):
"""Relativistic (average) loss
Args:
discriminator_real_outputs: Discriminator output on real data.
discriminator_gen_outputs: Discriminator output on generated data. Expected
to be in the range of (-inf, inf).
label_smoothing: The amount of smoothing for positive labels. This technique
is taken from `Improved Techniques for Training GANs`
(https://arxiv.org/abs/1606.03498). `0.0` means no smoothing.
real_weights: Optional `Tensor` whose rank is either 0, or the same rank as
`real_data`, and must be broadcastable to `real_data` (i.e., all
dimensions must be either `1`, or the same as the corresponding
dimension).
generated_weights: Same as `real_weights`, but for `generated_data`.
scope: The scope for the operations performed in computing the loss.
loss_collection: collection to which this loss will be added.
reduction: A `tf.compat.v1.losses.Reduction` to apply to loss.
    add_summaries: Whether or not to add summaries for the loss.
    confusion_labels: If `True`, use `0.5` as the target for both the real
      and the generated relativistic logits (pushing the discriminator
      towards maximal confusion) instead of the default targets of `0` for
      real and `1` for generated.
Returns:
A loss Tensor. The shape depends on `reduction`.
"""
with tf.name_scope(
scope,
"generator_relativistic_loss",
(
discriminator_real_outputs,
discriminator_gen_outputs,
real_weights,
generated_weights,
label_smoothing,
),
) as scope:
real_logit = discriminator_real_outputs - tf.reduce_mean(
discriminator_gen_outputs
)
fake_logit = discriminator_gen_outputs - tf.reduce_mean(
discriminator_real_outputs
)
if confusion_labels:
real_labels = tf.ones_like(real_logit) / 2
fake_labels = tf.ones_like(fake_logit) / 2
else:
real_labels = tf.zeros_like(real_logit)
fake_labels = tf.ones_like(fake_logit)
loss_on_real = tf.losses.sigmoid_cross_entropy(
real_labels,
real_logit,
real_weights,
label_smoothing,
scope,
loss_collection=None,
reduction=reduction,
)
loss_on_generated = tf.losses.sigmoid_cross_entropy(
fake_labels,
fake_logit,
generated_weights,
scope=scope,
loss_collection=None,
reduction=reduction,
)
loss = loss_on_real + loss_on_generated
tf.losses.add_loss(loss, loss_collection)
if add_summaries:
tf.summary.scalar("generator_gen_relativistic_loss", loss_on_generated)
tf.summary.scalar("generator_real_relativistic_loss", loss_on_real)
tf.summary.scalar("generator_relativistic_loss", loss)
return loss
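# Illustrative usage sketch, not part of the original commit: with
# `confusion_labels=True` both relativistic logits are pulled towards 0.5
# rather than the default targets of 0 (real) and 1 (generated).
def _example_relativistic_generator_loss():
    d_real = tf.constant([[2.0], [1.5]])  # hypothetical logits on real data
    d_fake = tf.constant([[-1.0], [0.5]])  # hypothetical logits on fakes
    return relativistic_generator_loss(d_real, d_fake, confusion_labels=True)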
This diff is collapsed.
import tensorflow as tf
from ..gan.spectral_normalization import spectral_norm_regularizer
from ..utils import gram_matrix
class ConvDiscriminator(tf.keras.Model):
"""A discriminator that can sit on top of DenseNet 161's transition 1 block.
    The output of that block given 224x224x3 inputs is 14x14x384."""
def __init__(self, data_format="channels_last", n_classes=1, **kwargs):
super().__init__(**kwargs)
@@ -13,10 +15,10 @@ class ConvDiscriminator(tf.keras.Model):
self.sequential_layers = [
tf.keras.layers.Conv2D(200, 1, data_format=data_format),
tf.keras.layers.Activation("relu"),
            tf.keras.layers.AveragePooling2D(3, 2, data_format=data_format),
tf.keras.layers.Conv2D(100, 1, data_format=data_format),
tf.keras.layers.Activation("relu"),
            tf.keras.layers.AveragePooling2D(3, 2, data_format=data_format),
tf.keras.layers.Flatten(data_format=data_format),
tf.keras.layers.Dense(n_classes),
tf.keras.layers.Activation(act),
@@ -24,7 +26,10 @@ class ConvDiscriminator(tf.keras.Model):
def call(self, x, training=None):
for l in self.sequential_layers:
            # Forward `training` to layers that accept it (e.g. BatchNorm);
            # fall back for layers whose call() takes no `training` argument.
            try:
                x = l(x, training=training)
            except TypeError:
                x = l(x)
return x
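# Illustrative usage sketch, not part of the original commit: the input shape
# matches the 14x14x384 DenseNet transition-1 features from the docstring.
def _example_conv_discriminator():
    disc = ConvDiscriminator()
    features = tf.ones([2, 14, 14, 384])  # hypothetical feature batch
    return disc(features)  # shape (2, n_classes)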
@@ -66,5 +71,89 @@ class ConvDiscriminator2(tf.keras.Model):
def call(self, x, training=None):
for l in self.sequential_layers:
            # Forward `training` to layers that accept it (e.g. BatchNorm);
            # fall back for layers whose call() takes no `training` argument.
            try:
                x = l(x, training=training)
            except TypeError:
                x = l(x)
return x
class ConvDiscriminator3(tf.keras.Model):
"""A discriminator that takes images and tries its best.
Be careful, this one returns logits."""
def __init__(self, data_format="channels_last", n_classes=1, **kwargs):
super().__init__(**kwargs)
self.data_format = data_format
self.n_classes = n_classes
spectral_norm = spectral_norm_regularizer(scale=1.0)
conv2d_kw = {"kernel_regularizer": spectral_norm, "data_format": data_format}
self.sequential_layers = [
tf.keras.layers.Conv2D(64, 3, strides=1, **conv2d_kw),
tf.keras.layers.LeakyReLU(0.1),
tf.keras.layers.Conv2D(64, 4, strides=2, **conv2d_kw),
tf.keras.layers.LeakyReLU(0.1),
tf.keras.layers.Conv2D(128, 3, strides=1, **conv2d_kw),
tf.keras.layers.LeakyReLU(0.1),
tf.keras.layers.Conv2D(128, 4, strides=2, **conv2d_kw),
tf.keras.layers.LeakyReLU(0.1),
tf.keras.layers.Conv2D(256, 3, strides=1, **conv2d_kw),
tf.keras.layers.LeakyReLU(0.1),
tf.keras.layers.Conv2D(256, 4, strides=2, **conv2d_kw),
tf.keras.layers.LeakyReLU(0.1),
tf.keras.layers.Conv2D(512, 3, strides=1, **conv2d_kw),
tf.keras.layers.LeakyReLU(0.1),
tf.keras.layers.GlobalAveragePooling2D(data_format=data_format),
tf.keras.layers.Dense(n_classes),
]
def call(self, x, training=None):
for l in self.sequential_layers:
try:
x = l(x, training=training)
except TypeError:
x = l(x)
return x
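# Illustrative usage sketch, not part of the original commit: the spectral
# norm penalties attached to each Conv2D kernel are collected in the model's
# `losses` property and can be added to the discriminator objective.
def _example_conv_discriminator3():
    disc = ConvDiscriminator3()
    logits = disc(tf.ones([2, 64, 64, 3]))  # raw logits, shape (2, 1)
    penalty = tf.add_n(disc.losses)  # sum of spectral-norm regularizers
    return logits, penalty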
class DenseDiscriminator(tf.keras.Model):
"""A discriminator that takes vectors as input and tries its best.
Be careful, this one returns logits."""
def __init__(self, n_classes=1, **kwargs):
super().__init__(**kwargs)
self.n_classes = n_classes
self.sequential_layers = [
tf.keras.layers.Dense(1000),
tf.keras.layers.Activation("relu"),
tf.keras.layers.Dense(1000),
tf.keras.layers.Activation("relu"),
tf.keras.layers.Dense(n_classes),
]
def call(self, x, training=None):
for l in self.sequential_layers:
try:
x = l(x, training=training)
except TypeError:
x = l(x)
return x
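# Illustrative usage sketch, not part of the original commit: feature vectors
# in, one logit per sample out.
def _example_dense_discriminator():
    vectors = tf.ones([4, 128])  # hypothetical 128-d input vectors
    return DenseDiscriminator()(vectors)  # logits, shape (4, 1)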
class GramComparer1(tf.keras.Model):
"""A model to compare images based on their gram matrices."""
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.batchnorm = tf.keras.layers.BatchNormalization()
self.conv2d = tf.keras.layers.Conv2D(128, 7)
def call(self, x_1_2, training=None):
def _call(x):
x = self.batchnorm(x, training=training)
x = self.conv2d(x)
return gram_matrix(x)
gram1 = _call(x_1_2[..., :3])
gram2 = _call(x_1_2[..., 3:])
return -tf.reduce_mean((gram1 - gram2) ** 2, axis=[1, 2])[:, None]
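# Illustrative usage sketch, not part of the original commit: the model takes
# the two images to compare concatenated along the channel axis (3 + 3 = 6
# channels) and returns the negated mean squared distance between their gram
# matrices, so larger values mean more similar feature statistics.
def _example_gram_comparer():
    comparer = GramComparer1()
    pair = tf.ones([2, 32, 32, 6])  # two hypothetical RGB images, stacked
    return comparer(pair, training=False)  # shape (2, 1)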