Commit 9433478b authored by Amir MOHAMMADI's avatar Amir MOHAMMADI
Browse files

Compute VAT loss only during training.

parent 3bbc0005
...@@ -103,6 +103,8 @@ class VATLoss: ...@@ -103,6 +103,8 @@ class VATLoss:
NotImplementedError NotImplementedError
If self.method is not ``vat`` or ``vatent``. If self.method is not ``vat`` or ``vatent``.
""" """
if mode != tf.estimator.ModeKeys.TRAIN:
return 0.
architecture = partial(architecture, reuse=True) architecture = partial(architecture, reuse=True)
with tf.variable_scope(tf.get_variable_scope(), reuse=True): with tf.variable_scope(tf.get_variable_scope(), reuse=True):
vat_loss = self.virtual_adversarial_loss(features, logits, architecture, mode) vat_loss = self.virtual_adversarial_loss(features, logits, architecture, mode)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment