Skip to content
Snippets Groups Projects
Commit 1c3fdb10 authored by Amir MOHAMMADI's avatar Amir MOHAMMADI
Browse files

improve logging for vat loss

parent a8bc6541
No related branches found
No related tags found
No related merge requests found
......@@ -108,20 +108,20 @@ class VATLoss:
architecture = partial(architecture, reuse=True)
with tf.variable_scope(tf.get_variable_scope(), reuse=True):
vat_loss = self.virtual_adversarial_loss(features, logits, architecture, mode)
tf.summary.scalar("vat_loss", vat_loss)
tf.summary.scalar("loss_VAT", vat_loss)
tf.add_to_collection(tf.GraphKeys.LOSSES, vat_loss)
if self.method == 'vat':
loss = vat_loss
elif self.method == 'vatent':
ent_loss = entropy_y_x(logits)
tf.summary.scalar("entropy_loss", ent_loss)
tf.summary.scalar("loss_entropy", ent_loss)
tf.add_to_collection(tf.GraphKeys.LOSSES, ent_loss)
loss = vat_loss + ent_loss
else:
raise ValueError
return loss
def virtual_adversarial_loss(self, features, logits, architecture, mode, name="vat_loss"):
def virtual_adversarial_loss(self, features, logits, architecture, mode, name="vat_loss_op"):
r_vadv = self.generate_virtual_adversarial_perturbation(features, logits, architecture, mode)
logit_p = tf.stop_gradient(logits)
adversarial_input = features + r_vadv
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment