Commit 9433478b authored by Amir MOHAMMADI's avatar Amir MOHAMMADI
Browse files

Compute VAT loss only during training.

parent 3bbc0005
......@@ -103,6 +103,8 @@ class VATLoss:
NotImplementedError
If self.method is not ``vat`` or ``vatent``.
"""
if mode != tf.estimator.ModeKeys.TRAIN:
return 0.
architecture = partial(architecture, reuse=True)
with tf.variable_scope(tf.get_variable_scope(), reuse=True):
vat_loss = self.virtual_adversarial_loss(features, logits, architecture, mode)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment