diff --git a/bob/learn/tensorflow/loss/__init__.py b/bob/learn/tensorflow/loss/__init__.py
index 17947ea14336862a00238021238975d58ec84895..f35a0cebf7afb894c8ef5671b4517827f06ab9df 100644
--- a/bob/learn/tensorflow/loss/__init__.py
+++ b/bob/learn/tensorflow/loss/__init__.py
@@ -2,6 +2,8 @@ from .BaseLoss import mean_cross_entropy_loss, mean_cross_entropy_center_loss
 from .ContrastiveLoss import contrastive_loss
 from .TripletLoss import triplet_loss, triplet_average_loss, triplet_fisher_loss
 from .StyleLoss import linear_gram_style_loss, content_loss, denoising_loss
+from .vat import VATLoss
+from .utils import *
 
 # gets sphinx autodoc done right - don't remove it
diff --git a/bob/learn/tensorflow/loss/vat.py b/bob/learn/tensorflow/loss/vat.py
new file mode 100644
index 0000000000000000000000000000000000000000..0d414f9e867da461f411242ddac572419c98faa5
--- /dev/null
+++ b/bob/learn/tensorflow/loss/vat.py
@@ -0,0 +1,158 @@
+# Adapted from https://github.com/takerum/vat_tf whose license follows:
+#
+# MIT License
+#
+# Copyright (c) 2017 Takeru Miyato
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+
+
+import tensorflow as tf
+from functools import partial
+
+
+def get_normalized_vector(d):
+    # Normalize each sample in the batch to unit L2 norm. Dividing by the
+    # maximum absolute value first keeps the sum of squares from overflowing.
+    d /= (1e-12 + tf.reduce_max(tf.abs(d), list(range(1, len(d.get_shape()))), keepdims=True))
+    d /= tf.sqrt(1e-6 + tf.reduce_sum(tf.pow(d, 2.0), list(range(1, len(d.get_shape()))), keepdims=True))
+    return d
+
+
+def logsoftmax(x):
+    # Numerically stable log-softmax over the class axis (axis 1).
+    xdev = x - tf.reduce_max(x, 1, keepdims=True)
+    lsm = xdev - tf.log(tf.reduce_sum(tf.exp(xdev), 1, keepdims=True))
+    return lsm
+
+
+def kl_divergence_with_logit(q_logit, p_logit):
+    # KL(q || p) averaged over the batch, with both distributions given as
+    # logits.
+    q = tf.nn.softmax(q_logit)
+    qlogq = tf.reduce_mean(tf.reduce_sum(q * logsoftmax(q_logit), 1))
+    qlogp = tf.reduce_mean(tf.reduce_sum(q * logsoftmax(p_logit), 1))
+    return qlogq - qlogp
+
+
+def entropy_y_x(logit):
+    # Entropy of the predicted class distribution, averaged over the batch.
+    p = tf.nn.softmax(logit)
+    return -tf.reduce_mean(tf.reduce_sum(p * logsoftmax(logit), 1))
+
+
+class VATLoss:
+    """A class that holds the parameters of the Virtual Adversarial Training
+    (VAT) loss and computes it.
+
+    Attributes
+    ----------
+    epsilon : float
+        The norm length of the (virtual) adversarial perturbation.
+    method : str
+        The method for calculating the loss: ``vatent`` for the VAT loss
+        plus entropy and ``vat`` for the VAT loss only.
+    num_power_iterations : int
+        The number of power iterations used to approximate the adversarial
+        direction.
+    xi : float
+        A small constant used in the finite difference approximation.
+    """
+
+    def __init__(self, epsilon=8.0, xi=1e-6, num_power_iterations=1, method='vatent', **kwargs):
+        super(VATLoss, self).__init__(**kwargs)
+        self.epsilon = epsilon
+        self.xi = xi
+        self.num_power_iterations = num_power_iterations
+        self.method = method
+
+    def __call__(self, features, logits, architecture, mode):
+        """Computes the VAT loss for unlabeled features.
+
+        If you are doing semi-supervised learning, only pass the unlabeled
+        features and their logits here.
+
+        Parameters
+        ----------
+        features : object
+            Tensor representing the (unlabeled) features.
+        logits : object
+            Tensor representing the logits of the (unlabeled) features.
+        architecture : object
+            A callable that constructs the model. It should accept ``mode``
+            and ``reuse`` as keyword arguments. The features will be given as
+            the first input.
+        mode : str
+            One of the tf.estimator.ModeKeys.{TRAIN,EVAL} strings.
+
+        Returns
+        -------
+        object
+            The loss.
+
+        Raises
+        ------
+        NotImplementedError
+            If self.method is not ``vat`` or ``vatent``.
+        """
+        architecture = partial(architecture, reuse=True)
+        with tf.variable_scope(tf.get_variable_scope(), reuse=True):
+            vat_loss = self.virtual_adversarial_loss(features, logits, architecture, mode)
+            tf.summary.scalar("vat_loss", vat_loss)
+            tf.add_to_collection(tf.GraphKeys.LOSSES, vat_loss)
+            if self.method == 'vat':
+                loss = vat_loss
+            elif self.method == 'vatent':
+                ent_loss = entropy_y_x(logits)
+                tf.summary.scalar("entropy_loss", ent_loss)
+                tf.add_to_collection(tf.GraphKeys.LOSSES, ent_loss)
+                loss = vat_loss + ent_loss
+            else:
+                raise NotImplementedError(
+                    "The method `{}` is not implemented. Use `vat` or "
+                    "`vatent`.".format(self.method))
+        return loss
+
+    def virtual_adversarial_loss(self, features, logits, architecture, mode, name="vat_loss"):
+        # Perturb the input with the virtual adversarial direction and
+        # penalize the KL divergence between the predictions on the clean
+        # and the perturbed inputs.
+        r_vadv = self.generate_virtual_adversarial_perturbation(features, logits, architecture, mode)
+        logit_p = tf.stop_gradient(logits)
+        adversarial_input = features + r_vadv
+        tf.summary.image("Adversarial_Image", adversarial_input)
+        logit_m = architecture(adversarial_input, mode=mode)[0]
+        loss = kl_divergence_with_logit(logit_p, logit_m)
+        return tf.identity(loss, name=name)
+
+    def generate_virtual_adversarial_perturbation(self, features, logits, architecture, mode):
+        # Approximate the adversarial direction with power iterations:
+        # repeatedly take the gradient of the KL divergence with respect
+        # to a small random perturbation of the input.
+        d = tf.random_normal(shape=tf.shape(features))
+
+        for _ in range(self.num_power_iterations):
+            d = self.xi * get_normalized_vector(d)
+            logit_p = logits
+            logit_m = architecture(features + d, mode=mode)[0]
+            dist = kl_divergence_with_logit(logit_p, logit_m)
+            grad = tf.gradients(dist, [d], aggregation_method=2)[0]
+            d = tf.stop_gradient(grad)
+
+        return self.epsilon * get_normalized_vector(d)
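
For reference, below is a minimal sketch of how VATLoss might be wired into a tf.estimator model function for semi-supervised training. The `architecture` network, the input layout (a dict with `labeled` and `unlabeled` batches), and the optimizer settings are illustrative assumptions, not part of this patch; only the VATLoss construction and call follow the API added above.

import tensorflow as tf
from bob.learn.tensorflow.loss import VATLoss


def architecture(inputs, mode, reuse=False):
    # Hypothetical network; like the architectures in this package, it
    # must accept `mode` and `reuse` and return a tuple whose first
    # element is the logits tensor.
    with tf.variable_scope("Network", reuse=reuse):
        net = tf.layers.flatten(inputs)
        logits = tf.layers.dense(net, 10, name="logits")
    return logits, {}


def model_fn(features, labels, mode):
    # Assumed input layout: one labeled and one unlabeled batch per step,
    # e.g. produced by a custom input_fn.
    labeled, unlabeled = features["labeled"], features["unlabeled"]

    logits = architecture(labeled, mode=mode)[0]
    supervised_loss = tf.losses.sparse_softmax_cross_entropy(
        labels=labels, logits=logits)

    # Logits of the unlabeled batch, reusing the same variables.
    unlabeled_logits = architecture(unlabeled, mode=mode, reuse=True)[0]

    # The VAT loss only sees the unlabeled data and its logits.
    vat = VATLoss(epsilon=8.0, xi=1e-6, num_power_iterations=1,
                  method='vatent')
    unsupervised_loss = vat(unlabeled, unlabeled_logits, architecture, mode)

    loss = supervised_loss + unsupervised_loss
    train_op = tf.train.AdamOptimizer(1e-3).minimize(
        loss, global_step=tf.train.get_or_create_global_step())
    return tf.estimator.EstimatorSpec(mode=mode, loss=loss, train_op=train_op)

Note that VATLoss also registers its terms in tf.GraphKeys.LOSSES, so a model_fn that aggregates that collection with tf.losses.get_total_loss() should not additionally sum the returned loss by hand as done above, or the VAT terms would be counted twice.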