Commit 1c3fdb10 authored by Amir MOHAMMADI's avatar Amir MOHAMMADI

improve logging for vat loss

parent a8bc6541
......@@ -108,20 +108,20 @@ class VATLoss:
architecture = partial(architecture, reuse=True)
with tf.variable_scope(tf.get_variable_scope(), reuse=True):
vat_loss = self.virtual_adversarial_loss(features, logits, architecture, mode)
tf.summary.scalar("vat_loss", vat_loss)
tf.summary.scalar("loss_VAT", vat_loss)
tf.add_to_collection(tf.GraphKeys.LOSSES, vat_loss)
if self.method == 'vat':
loss = vat_loss
elif self.method == 'vatent':
ent_loss = entropy_y_x(logits)
tf.summary.scalar("entropy_loss", ent_loss)
tf.summary.scalar("loss_entropy", ent_loss)
tf.add_to_collection(tf.GraphKeys.LOSSES, ent_loss)
loss = vat_loss + ent_loss
else:
raise ValueError
return loss
def virtual_adversarial_loss(self, features, logits, architecture, mode, name="vat_loss"):
def virtual_adversarial_loss(self, features, logits, architecture, mode, name="vat_loss_op"):
r_vadv = self.generate_virtual_adversarial_perturbation(features, logits, architecture, mode)
logit_p = tf.stop_gradient(logits)
adversarial_input = features + r_vadv
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment