Skip to content
Snippets Groups Projects
Commit 6769b588 authored by Amir MOHAMMADI's avatar Amir MOHAMMADI
Browse files

improve logging for vat loss

parent 5473a333
Branches
Tags
No related merge requests found
...@@ -108,20 +108,20 @@ class VATLoss: ...@@ -108,20 +108,20 @@ class VATLoss:
architecture = partial(architecture, reuse=True) architecture = partial(architecture, reuse=True)
with tf.variable_scope(tf.get_variable_scope(), reuse=True): with tf.variable_scope(tf.get_variable_scope(), reuse=True):
vat_loss = self.virtual_adversarial_loss(features, logits, architecture, mode) vat_loss = self.virtual_adversarial_loss(features, logits, architecture, mode)
tf.summary.scalar("vat_loss", vat_loss) tf.summary.scalar("loss_VAT", vat_loss)
tf.add_to_collection(tf.GraphKeys.LOSSES, vat_loss) tf.add_to_collection(tf.GraphKeys.LOSSES, vat_loss)
if self.method == 'vat': if self.method == 'vat':
loss = vat_loss loss = vat_loss
elif self.method == 'vatent': elif self.method == 'vatent':
ent_loss = entropy_y_x(logits) ent_loss = entropy_y_x(logits)
tf.summary.scalar("entropy_loss", ent_loss) tf.summary.scalar("loss_entropy", ent_loss)
tf.add_to_collection(tf.GraphKeys.LOSSES, ent_loss) tf.add_to_collection(tf.GraphKeys.LOSSES, ent_loss)
loss = vat_loss + ent_loss loss = vat_loss + ent_loss
else: else:
raise ValueError raise ValueError
return loss return loss
def virtual_adversarial_loss(self, features, logits, architecture, mode, name="vat_loss"): def virtual_adversarial_loss(self, features, logits, architecture, mode, name="vat_loss_op"):
r_vadv = self.generate_virtual_adversarial_perturbation(features, logits, architecture, mode) r_vadv = self.generate_virtual_adversarial_perturbation(features, logits, architecture, mode)
logit_p = tf.stop_gradient(logits) logit_p = tf.stop_gradient(logits)
adversarial_input = features + r_vadv adversarial_input = features + r_vadv
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment