Commit 65b7398f authored by Amir MOHAMMADI

Add GAN tools

parent 68116323
from . import spectral_normalization
from . import losses
import tensorflow as tf
def relativistic_discriminator_loss(
discriminator_real_outputs,
discriminator_gen_outputs,
label_smoothing=0.25,
real_weights=1.0,
generated_weights=1.0,
scope=None,
loss_collection=tf.GraphKeys.LOSSES,
reduction=tf.losses.Reduction.SUM_BY_NONZERO_WEIGHTS,
add_summaries=False,
):
"""Relativistic (average) loss
Args:
discriminator_real_outputs: Discriminator output on real data.
discriminator_gen_outputs: Discriminator output on generated data. Expected
to be in the range of (-inf, inf).
label_smoothing: The amount of smoothing for positive labels. This technique
is taken from `Improved Techniques for Training GANs`
(https://arxiv.org/abs/1606.03498). `0.0` means no smoothing.
real_weights: Optional `Tensor` whose rank is either 0, or the same rank as
`real_data`, and must be broadcastable to `real_data` (i.e., all
dimensions must be either `1`, or the same as the corresponding
dimension).
generated_weights: Same as `real_weights`, but for `generated_data`.
scope: The scope for the operations performed in computing the loss.
loss_collection: collection to which this loss will be added.
reduction: A `tf.compat.v1.losses.Reduction` to apply to loss.
add_summaries: Whether or not to add summaries for the loss.
Returns:
A loss Tensor. The shape depends on `reduction`.
"""
with tf.name_scope(
scope,
"discriminator_relativistic_loss",
(
discriminator_real_outputs,
discriminator_gen_outputs,
real_weights,
generated_weights,
label_smoothing,
),
) as scope:
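# Relativistic average logits: each side is scored relative to the mean
# discriminator output on the opposite side, following "The relativistic
# discriminator: a key element missing from standard GAN"
# (https://arxiv.org/abs/1807.00734).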
real_logit = discriminator_real_outputs - tf.reduce_mean(
discriminator_gen_outputs
)
fake_logit = discriminator_gen_outputs - tf.reduce_mean(
discriminator_real_outputs
)
loss_on_real = tf.losses.sigmoid_cross_entropy(
tf.ones_like(real_logit),
real_logit,
real_weights,
label_smoothing,
scope,
loss_collection=None,
reduction=reduction,
)
loss_on_generated = tf.losses.sigmoid_cross_entropy(
tf.zeros_like(fake_logit),
fake_logit,
generated_weights,
scope=scope,
loss_collection=None,
reduction=reduction,
)
loss = loss_on_real + loss_on_generated
tf.losses.add_loss(loss, loss_collection)
if add_summaries:
tf.summary.scalar("discriminator_gen_relativistic_loss", loss_on_generated)
tf.summary.scalar("discriminator_real_relativistic_loss", loss_on_real)
tf.summary.scalar("discriminator_relativistic_loss", loss)
return loss
def relativistic_generator_loss(
discriminator_real_outputs,
discriminator_gen_outputs,
label_smoothing=0.0,
real_weights=1.0,
generated_weights=1.0,
scope=None,
loss_collection=tf.GraphKeys.LOSSES,
reduction=tf.losses.Reduction.SUM_BY_NONZERO_WEIGHTS,
add_summaries=False,
confusion_labels=False,
):
"""Relativistic (average) loss
Args:
discriminator_real_outputs: Discriminator output on real data.
discriminator_gen_outputs: Discriminator output on generated data. Expected
to be in the range of (-inf, inf).
label_smoothing: The amount of smoothing for positive labels. This technique
is taken from `Improved Techniques for Training GANs`
(https://arxiv.org/abs/1606.03498). `0.0` means no smoothing.
real_weights: Optional `Tensor` whose rank is either 0, or the same rank as
`real_data`, and must be broadcastable to `real_data` (i.e., all
dimensions must be either `1`, or the same as the corresponding
dimension).
generated_weights: Same as `real_weights`, but for `generated_data`.
scope: The scope for the operations performed in computing the loss.
loss_collection: collection to which this loss will be added.
reduction: A `tf.compat.v1.losses.Reduction` to apply to loss.
add_summaries: Whether or not to add summaries for the loss.
confusion_labels: If `True`, label both real and generated logits with 0.5
(maximal discriminator confusion) instead of swapping the real/fake
labels.
Returns:
A loss Tensor. The shape depends on `reduction`.
"""
with tf.name_scope(
scope,
"generator_relativistic_loss",
(
discriminator_real_outputs,
discriminator_gen_outputs,
real_weights,
generated_weights,
label_smoothing,
),
) as scope:
real_logit = discriminator_real_outputs - tf.reduce_mean(
discriminator_gen_outputs
)
fake_logit = discriminator_gen_outputs - tf.reduce_mean(
discriminator_real_outputs
)
if confusion_labels:
real_labels = tf.ones_like(real_logit) / 2
fake_labels = tf.ones_like(fake_logit) / 2
else:
real_labels = tf.zeros_like(real_logit)
fake_labels = tf.ones_like(fake_logit)
loss_on_real = tf.losses.sigmoid_cross_entropy(
real_labels,
real_logit,
real_weights,
label_smoothing,
scope,
loss_collection=None,
reduction=reduction,
)
loss_on_generated = tf.losses.sigmoid_cross_entropy(
fake_labels,
fake_logit,
generated_weights,
scope=scope,
loss_collection=None,
reduction=reduction,
)
loss = loss_on_real + loss_on_generated
tf.losses.add_loss(loss, loss_collection)
if add_summaries:
tf.summary.scalar("generator_gen_relativistic_loss", loss_on_generated)
tf.summary.scalar("generator_real_relativistic_loss", loss_on_real)
tf.summary.scalar("generator_relativistic_loss", loss)
return loss
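# A minimal usage sketch for the two losses above (TF 1.x graph mode).
# `disc_fn`, `real_images`, and `fake_images` are illustrative placeholders,
# not part of this module:
#
#     with tf.variable_scope("discriminator"):
#         d_real = disc_fn(real_images)   # logits on real data
#     with tf.variable_scope("discriminator", reuse=True):
#         d_fake = disc_fn(fake_images)   # logits on generated data
#
#     d_loss = relativistic_discriminator_loss(d_real, d_fake, add_summaries=True)
#     g_loss = relativistic_generator_loss(d_real, d_fake, add_summaries=True)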
# Copied from: https://github.com/tensorflow/tensorflow/blob/c4f40aea1d4f916aa3dfeb79f024c495ac609106/tensorflow/contrib/gan/python/features/python/spectral_normalization_impl.py
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Keras-like layers and utilities that implement Spectral Normalization.
Based on "Spectral Normalization for Generative Adversarial Networks" by Miyato,
et al in ICLR 2018. https://openreview.net/pdf?id=B1QRgziT-
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import contextlib
import numbers
import re
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.keras.engine import base_layer_utils as keras_base_layer_utils
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn
from tensorflow.python.ops import variable_scope
from tensorflow.python.platform import tf_logging as logging
__all__ = [
'compute_spectral_norm', 'spectral_normalize', 'spectral_norm_regularizer',
'spectral_normalization_custom_getter', 'keras_spectral_normalization'
]
# tf.bfloat16 should work, but tf.matmul converts those to tf.float32 which then
# can't directly be assigned back to the tf.bfloat16 variable.
_OK_DTYPES_FOR_SPECTRAL_NORM = (dtypes.float16, dtypes.float32, dtypes.float64)
_PERSISTED_U_VARIABLE_SUFFIX = 'spectral_norm_u'
def compute_spectral_norm(w_tensor, power_iteration_rounds=1, name=None):
"""Estimates the largest singular value in the weight tensor.
Args:
w_tensor: The weight matrix whose spectral norm should be computed.
power_iteration_rounds: The number of iterations of the power method to
perform. A higher number yields a better approximation.
name: An optional scope name.
Returns:
The largest singular value (the spectral norm) of w.
"""
with variable_scope.variable_scope(name, 'spectral_norm'):
# The paper says to flatten convnet kernel weights from
# (C_out, C_in, KH, KW) to (C_out, C_in * KH * KW). But TensorFlow's Conv2D
# kernel weight shape is (KH, KW, C_in, C_out), so it should be reshaped to
# (KH * KW * C_in, C_out), and similarly for other layers that put output
# channels as last dimension.
# n.b. this means that w here is equivalent to w.T in the paper.
w = array_ops.reshape(w_tensor, (-1, w_tensor.get_shape()[-1]))
# Persisted approximation of first left singular vector of matrix `w`.
u_var = variable_scope.get_variable(
_PERSISTED_U_VARIABLE_SUFFIX,
shape=(w.shape[0], 1),
dtype=w.dtype,
initializer=init_ops.random_normal_initializer(),
trainable=False)
u = u_var
# Use power iteration method to approximate spectral norm.
for _ in range(power_iteration_rounds):
# `v` approximates the first right singular vector of matrix `w`.
v = nn.l2_normalize(math_ops.matmul(array_ops.transpose(w), u))
u = nn.l2_normalize(math_ops.matmul(w, v))
# Update persisted approximation.
with ops.control_dependencies([u_var.assign(u, name='update_u')]):
u = array_ops.identity(u)
u = array_ops.stop_gradient(u)
v = array_ops.stop_gradient(v)
# Largest singular value of `w`.
spectral_norm = math_ops.matmul(
math_ops.matmul(array_ops.transpose(u), w), v)
spectral_norm.shape.assert_is_fully_defined()
spectral_norm.shape.assert_is_compatible_with([1, 1])
return spectral_norm[0][0]
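# Illustrative sanity check (not part of the original file): the power
# iteration estimate should approach the exact largest singular value
# computed via SVD:
#
#     import numpy as np
#     w_np = np.random.randn(64, 16).astype(np.float32)
#     exact = np.linalg.svd(w_np, compute_uv=False)[0]
#     estimate = compute_spectral_norm(
#         ops.convert_to_tensor(w_np), power_iteration_rounds=20)
#     # After initializing variables and evaluating `estimate` in a session,
#     # it should be close to `exact`.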
def spectral_normalize(w, power_iteration_rounds=1, name=None):
"""Normalizes a weight matrix by its spectral norm.
Args:
w: The weight matrix to be normalized.
power_iteration_rounds: The number of iterations of the power method to
perform. A higher number yields a better approximation.
name: An optional scope name.
Returns:
A normalized weight matrix tensor.
"""
with variable_scope.variable_scope(name, 'spectral_normalize'):
w_normalized = w / compute_spectral_norm(
w, power_iteration_rounds=power_iteration_rounds)
return array_ops.reshape(w_normalized, w.get_shape())
def spectral_norm_regularizer(scale, power_iteration_rounds=1, scope=None):
"""Returns a functions that can be used to apply spectral norm regularization.
Small spectral norms enforce a small Lipschitz constant, which is necessary
for Wasserstein GANs.
Args:
scale: A scalar multiplier. 0.0 disables the regularizer.
power_iteration_rounds: The number of iterations of the power method to
perform. A higher number yields a better approximation.
scope: An optional scope name.
Returns:
A function with the signature `sn(weights)` that applies spectral norm
regularization.
Raises:
ValueError: If scale is negative or if scale is not a float.
"""
if isinstance(scale, numbers.Integral):
raise ValueError('scale cannot be an integer: %s' % scale)
if isinstance(scale, numbers.Real):
if scale < 0.0:
raise ValueError(
'Setting a scale less than 0 on a regularizer: %g' % scale)
if scale == 0.0:
logging.info('Scale of 0 disables regularizer.')
return lambda _: None
def sn(weights, name=None):
"""Applies spectral norm regularization to weights."""
with ops.name_scope(scope, 'SpectralNormRegularizer', [weights]) as name:
scale_t = ops.convert_to_tensor(
scale, dtype=weights.dtype.base_dtype, name='scale')
return math_ops.multiply(
scale_t,
compute_spectral_norm(
weights, power_iteration_rounds=power_iteration_rounds),
name=name)
return sn
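# Illustrative usage (the scale value here is arbitrary): the returned
# function plugs directly into a layer's kernel_regularizer, as
# ConvDiscriminator3 further below does with scale=1.0:
#
#     sn_reg = spectral_norm_regularizer(scale=1e-4)
#     conv = tf.keras.layers.Conv2D(64, 3, kernel_regularizer=sn_reg)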
def _default_name_filter(name):
"""A filter function to identify common names of weight variables.
Args:
name: The variable name.
Returns:
Whether `name` is a standard name for weight/kernel variables used in the
Keras, tf.layers, tf.contrib.layers, or tf.contrib.slim libraries.
"""
match = re.match(r'(.*\/)?(depthwise_|pointwise_)?(weights|kernel)$', name)
return match is not None
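# Examples of the default filter's behaviour:
#
#     _default_name_filter('disc/conv1/kernel')            # True
#     _default_name_filter('disc/conv1/depthwise_kernel')  # True
#     _default_name_filter('disc/conv1/bias')              # False
#     _default_name_filter('disc/batch_norm/gamma')        # False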
def spectral_normalization_custom_getter(name_filter=_default_name_filter,
power_iteration_rounds=1):
"""Custom getter that performs Spectral Normalization on a weight tensor.
Specifically it divides the weight tensor by its largest singular value. This
is intended to stabilize GAN training, by making the discriminator satisfy a
local 1-Lipschitz constraint.
Based on [Spectral Normalization for Generative Adversarial Networks][sn-gan].
[sn-gan]: https://openreview.net/forum?id=B1QRgziT-
To reproduce an SN-GAN, apply this custom_getter to every weight tensor of
your discriminator. The last dimension of the weight tensor must be the number
of output channels.
Apply this to layers by supplying this as the `custom_getter` of a
`tf.compat.v1.variable_scope`. For example:
with tf.compat.v1.variable_scope('discriminator',
custom_getter=spectral_normalization_custom_getter()):
net = discriminator_fn(net)
IMPORTANT: Keras does not respect the custom_getter supplied by the
VariableScope, so Keras users should use `keras_spectral_normalization`
instead of (or in addition to) this approach.
It is important to carefully select to which weights you want to apply
Spectral Normalization. In general you want to normalize the kernels of
convolution and dense layers, but you do not want to normalize biases. You
also want to avoid normalizing batch normalization (and similar) variables;
such layers play poorly with Spectral Normalization anyway, since their
learned gamma can cancel out the normalization of other layers. By default we supply a
filter that matches the kernel variable names of the dense and convolution
layers of the tf.layers, tf.contrib.layers, tf.keras and tf.contrib.slim
libraries. If you are using anything else you'll need a custom `name_filter`.
This custom getter internally creates a variable used to compute the spectral
norm by power iteration. It will update every time the variable is accessed,
which means the normalized discriminator weights may change slightly whilst
training the generator. Whilst unusual, this matches how the paper's authors
implement it, and in general additional rounds of power iteration can't hurt.
Args:
name_filter: Optionally, a method that takes a Variable name as input and
returns whether this Variable should be normalized.
power_iteration_rounds: The number of iterations of the power method to
perform per step. A higher number yields a better approximation of the
true spectral norm.
Returns:
A custom getter function that applies Spectral Normalization to all
Variables whose names match `name_filter`.
Raises:
ValueError: If name_filter is not callable.
"""
if not callable(name_filter):
raise ValueError('name_filter must be callable')
def _internal_getter(getter, name, *args, **kwargs):
"""A custom getter function that applies Spectral Normalization.
Args:
getter: The true getter to call.
name: Name of new/existing variable, in the same format as
tf.get_variable.
*args: Other positional arguments, in the same format as tf.get_variable.
**kwargs: Keyword arguments, in the same format as tf.get_variable.
Returns:
The return value of `getter(name, *args, **kwargs)`, spectrally
normalized.
Raises:
ValueError: If used incorrectly, or if `dtype` is not supported.
"""
if not name_filter(name):
return getter(name, *args, **kwargs)
if name.endswith(_PERSISTED_U_VARIABLE_SUFFIX):
raise ValueError(
'Cannot apply Spectral Normalization to internal variables created '
'for Spectral Normalization. Tried to normalize variable [%s]' %
name)
if kwargs['dtype'] not in _OK_DTYPES_FOR_SPECTRAL_NORM:
raise ValueError('Disallowed data type {}'.format(kwargs['dtype']))
# This layer's weight Variable/PartitionedVariable.
w_tensor = getter(name, *args, **kwargs)
if len(w_tensor.get_shape()) < 2:
raise ValueError(
'Spectral norm can only be applied to multi-dimensional tensors')
return spectral_normalize(
w_tensor,
power_iteration_rounds=power_iteration_rounds,
name=(name + '/spectral_normalize'))
return _internal_getter
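# Usage sketch (illustrative; `discriminator_fn` and `images` are
# placeholders). Every kernel/weights variable created inside the scope is
# divided by its estimated spectral norm at graph-construction time:
#
#     with tf.compat.v1.variable_scope(
#             'discriminator',
#             custom_getter=spectral_normalization_custom_getter()):
#         logits = discriminator_fn(images)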
@contextlib.contextmanager
def keras_spectral_normalization(name_filter=_default_name_filter,
power_iteration_rounds=1):
"""A context manager that enables Spectral Normalization for Keras.
Keras doesn't respect the `custom_getter` in the VariableScope, so this is a
bit of a hack to make things work.
Usage:
with keras_spectral_normalization():
net = discriminator_fn(net)
Args:
name_filter: Optionally, a method that takes a Variable name as input and
returns whether this Variable should be normalized.
power_iteration_rounds: The number of iterations of the power method to
perform per step. A higher number yields a better approximation of the
true spectral norm.
Yields:
A context manager that wraps the standard Keras variable creation method
with the `spectral_normalization_custom_getter`.
"""
original_make_variable = keras_base_layer_utils.make_variable
sn_getter = spectral_normalization_custom_getter(
name_filter=name_filter, power_iteration_rounds=power_iteration_rounds)
def make_variable_wrapper(name, *args, **kwargs):
return sn_getter(original_make_variable, name, *args, **kwargs)
keras_base_layer_utils.make_variable = make_variable_wrapper
try:
    yield
finally:
    # Restore the original variable factory even if the wrapped block raises.
    keras_base_layer_utils.make_variable = original_make_variable
import tensorflow as tf
from ..gan.spectral_normalization import spectral_norm_regularizer
from ..utils import gram_matrix
class ConvDiscriminator(tf.keras.Model):
"""A discriminator that can sit on top of DenseNet 161's transition 1 block.
The output of that block given 224x224x3 inputs is 14x14x384."""
def __init__(self, data_format="channels_last", n_classes=1, **kwargs):
super().__init__(**kwargs)
@@ -13,10 +15,10 @@ class ConvDiscriminator(tf.keras.Model):
self.sequential_layers = [
tf.keras.layers.Conv2D(200, 1, data_format=data_format),
tf.keras.layers.Activation("relu"),
tf.keras.layers.AveragePooling2D(3, 2, data_format=data_format),
tf.keras.layers.Conv2D(100, 1, data_format=data_format),
tf.keras.layers.Activation("relu"),
tf.keras.layers.AveragePooling2D(3, 2, data_format=data_format),
tf.keras.layers.Flatten(data_format=data_format),
tf.keras.layers.Dense(n_classes),
tf.keras.layers.Activation(act),
@@ -24,7 +26,10 @@ class ConvDiscriminator(tf.keras.Model):
def call(self, x, training=None):
for l in self.sequential_layers:
try:
x = l(x, training=training)
except TypeError:
x = l(x)
return x
@@ -66,5 +71,89 @@ class ConvDiscriminator2(tf.keras.Model):
def call(self, x, training=None):
for l in self.sequential_layers:
try:
x = l(x, training=training)
except TypeError:
x = l(x)
return x
class ConvDiscriminator3(tf.keras.Model):
"""A discriminator that takes images and tries its best.
Be careful, this one returns logits."""
def __init__(self, data_format="channels_last", n_classes=1, **kwargs):
super().__init__(**kwargs)
self.data_format = data_format
self.n_classes = n_classes
spectral_norm = spectral_norm_regularizer(scale=1.0)
conv2d_kw = {"kernel_regularizer": spectral_norm, "data_format": data_format}
self.sequential_layers = [
tf.keras.layers.Conv2D(64, 3, strides=1, **conv2d_kw),
tf.keras.layers.LeakyReLU(0.1),
tf.keras.layers.Conv2D(64, 4, strides=2, **conv2d_kw),
tf.keras.layers.LeakyReLU(0.1),
tf.keras.layers.Conv2D(128, 3, strides=1, **conv2d_kw),
tf.keras.layers.LeakyReLU(0.1),
tf.keras.layers.Conv2D(128, 4, strides=2, **conv2d_kw),
tf.keras.layers.LeakyReLU(0.1),
tf.keras.layers.Conv2D(256, 3, strides=1, **conv2d_kw),
tf.keras.layers.LeakyReLU(0.1),
tf.keras.layers.Conv2D(256, 4, strides=2, **conv2d_kw),
tf.keras.layers.LeakyReLU(0.1),
tf.keras.layers.Conv2D(512, 3, strides=1, **conv2d_kw),
tf.keras.layers.LeakyReLU(0.1),
tf.keras.layers.GlobalAveragePooling2D(data_format=data_format),
tf.keras.layers.Dense(n_classes),
]
def call(self, x, training=None):
for l in self.sequential_layers:
try:
x = l(x, training=training)
except TypeError:
x = l(x)
return x
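# Illustrative smoke test (the input size is arbitrary): for a batch of
# 64x64 RGB channels_last images, the stack of strided convolutions, global
# average pooling, and the final Dense layer yields one logit per example:
#
#     disc = ConvDiscriminator3()
#     logits = disc(tf.zeros([8, 64, 64, 3]), training=True)  # shape (8, 1)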
class DenseDiscriminator(tf.keras.Model):
"""A discriminator that takes vectors as input and tries its best.
Be careful, this one returns logits."""
def __init__(self, n_classes=1, **kwargs):
super().__init__(**kwargs)
self.n_classes = n_classes
self.sequential_layers = [
tf.keras.layers.Dense(1000),
tf.keras.layers.Activation("relu"),
tf.keras.layers.Dense(1000),
tf.keras.layers.Activation("relu"),
tf.keras.layers.Dense(n_classes),
]
def call(self, x, training=None):
for l in self.sequential_layers:
try:
x = l(x, training=training)
except TypeError:
x = l(x)
return x
class GramComparer1(tf.keras.Model):
"""A model to compare images based on their gram matrices."""
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.batchnorm = tf.keras.layers.BatchNormalization()
self.conv2d = tf.keras.layers.Conv2D(128, 7)
def call(self, x_1_2, training=None):
def _call(x):
x = self.batchnorm(x, training=training)
x = self.conv2d(x)
return gram_matrix(x)
gram1 = _call(x_1_2[..., :3])
gram2 = _call(x_1_2[..., 3:])
return -tf.reduce_mean((gram1 - gram2) ** 2, axis=[1, 2])[:, None]
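# Illustrative usage: the two images to compare are stacked along the
# channel axis (channels 0-2 hold the first image, 3-5 the second); the
# output is a (batch, 1) score where 0 means identical gram matrices and
# more negative means less similar:
#
#     comparer = GramComparer1()
#     pair = tf.concat([images_a, images_b], axis=-1)  # shape (N, H, W, 6)
#     scores = comparer(pair, training=False)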