From 1c3fdb100556c823260755dc83e0a3ad21b10b66 Mon Sep 17 00:00:00 2001 From: Amir MOHAMMADI <amir.mohammadi@idiap.ch> Date: Thu, 6 Jun 2019 17:25:20 +0200 Subject: [PATCH] improve logging for vat loss --- bob/learn/tensorflow/loss/vat.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/bob/learn/tensorflow/loss/vat.py b/bob/learn/tensorflow/loss/vat.py index db1bc02f..b48f4f89 100644 --- a/bob/learn/tensorflow/loss/vat.py +++ b/bob/learn/tensorflow/loss/vat.py @@ -108,20 +108,20 @@ class VATLoss: architecture = partial(architecture, reuse=True) with tf.variable_scope(tf.get_variable_scope(), reuse=True): vat_loss = self.virtual_adversarial_loss(features, logits, architecture, mode) - tf.summary.scalar("vat_loss", vat_loss) + tf.summary.scalar("loss_VAT", vat_loss) tf.add_to_collection(tf.GraphKeys.LOSSES, vat_loss) if self.method == 'vat': loss = vat_loss elif self.method == 'vatent': ent_loss = entropy_y_x(logits) - tf.summary.scalar("entropy_loss", ent_loss) + tf.summary.scalar("loss_entropy", ent_loss) tf.add_to_collection(tf.GraphKeys.LOSSES, ent_loss) loss = vat_loss + ent_loss else: raise ValueError return loss - def virtual_adversarial_loss(self, features, logits, architecture, mode, name="vat_loss"): + def virtual_adversarial_loss(self, features, logits, architecture, mode, name="vat_loss_op"): r_vadv = self.generate_virtual_adversarial_perturbation(features, logits, architecture, mode) logit_p = tf.stop_gradient(logits) adversarial_input = features + r_vadv -- GitLab