Skip to content
Snippets Groups Projects
Commit 5ba976b8 authored by Amir MOHAMMADI's avatar Amir MOHAMMADI
Browse files

Fix tests on moving average

parent bd586c1c
No related branches found
No related tags found
1 merge request!68Several changes
Pipeline #24651 failed
......@@ -221,15 +221,8 @@ def test_moving_average_trainer():
# define a fixed input data
# train the same network with the same initialization
# evaluate it
# did it change? no -> good
# train it again with moving average
# Is it different from no moving average? yes -> good
no_moving_average = {'accuracy': 0.128, 'loss': 5.3208413, 'global_step': 188}
with_moving_average = {'accuracy': 0.14,
'loss': 2.3772557, 'global_step': 188}
resume_moving_average = {'accuracy': 0.172,
'loss': 2.2985911, 'global_step': 376}
# train and evaluate it again with moving average
# Accuracy should be lower when moving average is on
try:
# Creating tf records for mnist
......@@ -258,8 +251,7 @@ def test_moving_average_trainer():
validation_batch_size,
epochs=1)
from bob.learn.tensorflow.network.SimpleCNN import (
new_architecture as architecture)
from bob.learn.tensorflow.network.Dummy import dummy as architecture
run_config = reproducible.set_seed(183, 183)[1]
run_config = run_config.replace(save_checkpoints_steps=2000)
......@@ -275,29 +267,32 @@ def test_moving_average_trainer():
apply_moving_averages=apply_moving_averages,
)
def _evaluate(estimator, oracle, delete=True):
def _evaluate(estimator, delete=True):
try:
estimator.train(input_fn)
evaluations = estimator.evaluate(input_fn_validation)
finally:
if delete:
shutil.rmtree(estimator.model_dir, ignore_errors=True)
for k in ('accuracy', 'loss', 'global_step'):
assert numpy.allclose(evaluations[k], oracle[k]), \
(k, evaluations, oracle)
return evaluations
estimator = _estimator(False)
_evaluate(estimator, no_moving_average, delete=True)
evaluations = _evaluate(estimator, delete=True)
no_moving_average_acc = evaluations['accuracy']
# same as above with moving average
estimator = _estimator(True)
_evaluate(estimator, with_moving_average, delete=False)
evaluations = _evaluate(estimator, delete=False)
with_moving_average_acc = evaluations['accuracy']
assert no_moving_average_acc > with_moving_average_acc, \
(no_moving_average_acc, with_moving_average_acc)
# Can it resume training?
del estimator
tf.reset_default_graph()
estimator = _estimator(True)
_evaluate(estimator, resume_moving_average, delete=True)
_evaluate(estimator, delete=True)
finally:
try:
......
......@@ -38,7 +38,8 @@ def test_logitstrainer_images():
n_classes=10,
loss_op=mean_cross_entropy_loss,
embedding_validation=embedding_validation,
validation_batch_size=validation_batch_size)
validation_batch_size=validation_batch_size,
apply_moving_averages=False)
run_logitstrainer_images(trainer)
finally:
try:
......@@ -94,12 +95,8 @@ def run_logitstrainer_images(trainer):
trainer.train(input_fn, steps=steps, hooks=hooks)
if not trainer.embedding_validation:
acc = trainer.evaluate(input_fn_validation)
assert acc['accuracy'] > 0.30
else:
acc = trainer.evaluate(input_fn_validation)
assert acc['accuracy'] > 0.30
acc = trainer.evaluate(input_fn_validation)
assert acc['accuracy'] > 0.30, acc['accuracy']
# Cleaning up
tf.reset_default_graph()
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment