Commit 18e34e61 authored by Amir MOHAMMADI

Add virtual adversarial training loss

parent 0b0055bf
Merge request !75: A lot of new features
@@ -2,6 +2,8 @@ from .BaseLoss import mean_cross_entropy_loss, mean_cross_entropy_center_loss
 from .ContrastiveLoss import contrastive_loss
 from .TripletLoss import triplet_loss, triplet_average_loss, triplet_fisher_loss
 from .StyleLoss import linear_gram_style_loss, content_loss, denoising_loss
+from .vat import VATLoss
+from .utils import *
 # gets sphinx autodoc done right - don't remove it
...
# Adapted from https://github.com/takerum/vat_tf. Its license:
#
# MIT License
#
# Copyright (c) 2017 Takeru Miyato
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
import tensorflow as tf
from functools import partial
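
# Scale d to unit L2 norm per batch example; dividing by the per-example
# max absolute value first keeps the sum of squares from overflowing.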
def get_normalized_vector(d):
d /= (1e-12 + tf.reduce_max(tf.abs(d), list(range(1, len(d.get_shape()))), keepdims=True))
d /= tf.sqrt(1e-6 + tf.reduce_sum(tf.pow(d, 2.0), list(range(1, len(d.get_shape()))), keepdims=True))
return d
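
# Numerically stable log-softmax: subtract the per-row maximum before
# exponentiating.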
def logsoftmax(x):
xdev = x - tf.reduce_max(x, 1, keepdims=True)
lsm = xdev - tf.log(tf.reduce_sum(tf.exp(xdev), 1, keepdims=True))
return lsm
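
# Mean KL divergence KL(q || p) over the batch, with both distributions
# given as logits.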
def kl_divergence_with_logit(q_logit, p_logit):
q = tf.nn.softmax(q_logit)
qlogq = tf.reduce_mean(tf.reduce_sum(q * logsoftmax(q_logit), 1))
qlogp = tf.reduce_mean(tf.reduce_sum(q * logsoftmax(p_logit), 1))
return qlogq - qlogp
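
# Conditional entropy of the model's predictions; added as an extra
# regularization term by the 'vatent' method below.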
def entropy_y_x(logit):
p = tf.nn.softmax(logit)
return -tf.reduce_mean(tf.reduce_sum(p * logsoftmax(logit), 1))
class VATLoss:
"""A class to hold parameters for Virtual Adversarial Training (VAT) Loss
and perform it.
Attributes
----------
epsilon : float
norm length for (virtual) adversarial training
method : str
The method for calculating the loss: ``vatent`` for VAT loss + entropy
and ``vat`` for only VAT loss.
num_power_iterations : int
the number of power iterations
xi : float
small constant for finite difference
"""
def __init__(self, epsilon=8.0, xi=1e-6, num_power_iterations=1, method='vatent', **kwargs):
super(VATLoss, self).__init__(**kwargs)
self.epsilon = epsilon
self.xi = xi
self.num_power_iterations = num_power_iterations
self.method = method
def __call__(self, features, logits, architecture, mode):
"""Computes the VAT loss for unlabeled features.
If you are doing semi-supervised learning, only pass the unlabeled
features and their logits here.
Parameters
----------
features : object
Tensor representing the (unlabeled) features
logits : object
Tensor representing the logits of (unlabeled) features.
architecture : object
A callable that constructs the model. It should accept ``mode`` and
``reuse`` as keyword arguments. The features will be given as the
first input.
mode : str
One of tf.estimator.ModeKeys.{TRAIN,EVAL} strings.
Returns
-------
object
The loss.
Raises
------
NotImplementedError
If self.method is not ``vat`` or ``vatent``.
"""
architecture = partial(architecture, reuse=True)
with tf.variable_scope(tf.get_variable_scope(), reuse=True):
vat_loss = self.virtual_adversarial_loss(features, logits, architecture, mode)
tf.summary.scalar("vat_loss", vat_loss)
tf.add_to_collection(tf.GraphKeys.LOSSES, vat_loss)
if self.method == 'vat':
loss = vat_loss
elif self.method == 'vatent':
ent_loss = entropy_y_x(logits)
tf.summary.scalar("entropy_loss", ent_loss)
tf.add_to_collection(tf.GraphKeys.LOSSES, ent_loss)
loss = vat_loss + ent_loss
        else:
            raise NotImplementedError(
                "method must be 'vat' or 'vatent', not {}".format(self.method)
            )
return loss
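
    # The VAT loss is the KL divergence between the predictions on the clean
    # input and the predictions on the adversarially perturbed input.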
def virtual_adversarial_loss(self, features, logits, architecture, mode, name="vat_loss"):
r_vadv = self.generate_virtual_adversarial_perturbation(features, logits, architecture, mode)
logit_p = tf.stop_gradient(logits)
adversarial_input = features + r_vadv
tf.summary.image("Adversarial_Image", adversarial_input)
logit_m = architecture(adversarial_input, mode=mode)[0]
loss = kl_divergence_with_logit(logit_p, logit_m)
return tf.identity(loss, name=name)
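
    # Approximate the most loss-sensitive direction by power iteration:
    # repeatedly differentiate the KL divergence with respect to a small
    # random perturbation d and renormalize.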
def generate_virtual_adversarial_perturbation(self, features, logits, architecture, mode):
d = tf.random_normal(shape=tf.shape(features))
for _ in range(self.num_power_iterations):
d = self.xi * get_normalized_vector(d)
logit_p = logits
logit_m = architecture(features + d, mode=mode)[0]
dist = kl_divergence_with_logit(logit_p, logit_m)
grad = tf.gradients(dist, [d], aggregation_method=2)[0]
d = tf.stop_gradient(grad)
return self.epsilon * get_normalized_vector(d)
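
Below is a minimal, hypothetical sketch of how ``VATLoss`` could be wired
into a ``tf.estimator`` model function for semi-supervised training. The
``architecture`` callable, the layer sizes, and the optimizer are
illustrative assumptions, not part of this commit; the features are assumed
to be 4-D image batches because ``VATLoss`` writes an image summary.

import tensorflow as tf

def architecture(inputs, mode=tf.estimator.ModeKeys.TRAIN, reuse=False):
    # Hypothetical model. It must accept ``mode`` and ``reuse`` keyword
    # arguments and return the logits as its first output, as VATLoss expects.
    with tf.variable_scope("model", reuse=reuse):
        net = tf.layers.flatten(inputs)
        logits = tf.layers.dense(net, 10)
    return logits, None

def model_fn(features, labels, mode):
    logits = architecture(features, mode=mode)[0]
    # Supervised term; in a semi-supervised setup, compute this only on the
    # labeled examples and the VAT term only on the unlabeled ones.
    supervised_loss = tf.losses.sparse_softmax_cross_entropy(labels, logits)
    vat = VATLoss(epsilon=8.0, xi=1e-6, num_power_iterations=1, method="vatent")
    loss = supervised_loss + vat(features, logits, architecture, mode)
    optimizer = tf.train.GradientDescentOptimizer(1e-3)
    train_op = optimizer.minimize(
        loss, global_step=tf.train.get_or_create_global_step())
    return tf.estimator.EstimatorSpec(mode=mode, loss=loss, train_op=train_op)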