Commit 23573ef0 authored by Tiago de Freitas Pereira's avatar Tiago de Freitas Pereira
Browse files

Setting the reproducible run_config

parent a15b4379
......@@ -38,6 +38,7 @@ def test_logitstrainer():
# Trainer logits
try:
embedding_validation = False
_, run_config,_,_,_ = reproducible.set_seed()
trainer = Logits(
model_dir=model_dir,
architecture=dummy,
......@@ -45,7 +46,8 @@ def test_logitstrainer():
n_classes=10,
loss_op=mean_cross_entropy_loss,
embedding_validation=embedding_validation,
validation_batch_size=validation_batch_size)
validation_batch_size=validation_batch_size,
config=run_config)
run_logitstrainer_mnist(trainer, augmentation=True)
finally:
try:
......@@ -59,6 +61,7 @@ def test_logitstrainer():
def test_logitstrainer_embedding():
try:
embedding_validation = True
_, run_config,_,_,_ = reproducible.set_seed()
trainer = Logits(
model_dir=model_dir,
architecture=dummy,
......@@ -66,7 +69,8 @@ def test_logitstrainer_embedding():
n_classes=10,
loss_op=mean_cross_entropy_loss,
embedding_validation=embedding_validation,
validation_batch_size=validation_batch_size)
validation_batch_size=validation_batch_size,
config=run_config)
run_logitstrainer_mnist(trainer)
finally:
......@@ -81,7 +85,7 @@ def test_logitstrainer_embedding():
def test_logitstrainer_centerloss():
try:
embedding_validation = False
run_config = tf.estimator.RunConfig()
_, run_config,_,_,_ = reproducible.set_seed()
run_config = run_config.replace(save_checkpoints_steps=1000)
trainer = LogitsCenterLoss(
model_dir=model_dir,
......@@ -118,6 +122,7 @@ def test_logitstrainer_centerloss():
def test_logitstrainer_centerloss_embedding():
try:
embedding_validation = True
_, run_config,_,_,_ = reproducible.set_seed()
trainer = LogitsCenterLoss(
model_dir=model_dir,
architecture=dummy,
......@@ -125,7 +130,9 @@ def test_logitstrainer_centerloss_embedding():
n_classes=10,
embedding_validation=embedding_validation,
validation_batch_size=validation_batch_size,
factor=0.01)
factor=0.01,
config=run_config
)
run_logitstrainer_mnist(trainer)
# Checking if the centers were updated
......@@ -170,7 +177,7 @@ def run_logitstrainer_mnist(trainer, augmentation=False):
data_type,
batch_size,
random_flip=True,
random_rotate=True,
random_rotate=False,
epochs=epochs)
else:
return shuffle_data_and_labels(
......@@ -196,7 +203,6 @@ def run_logitstrainer_mnist(trainer, augmentation=False):
scaffold=tf.train.Scaffold(),
summary_writer=tf.summary.FileWriter(model_dir))
]
trainer.train(input_fn, steps=steps, hooks=hooks)
if not trainer.embedding_validation:
acc = trainer.evaluate(input_fn_validation)
......
......@@ -81,6 +81,7 @@ def dummy_adapted(inputs,
def test_logitstrainer():
# Trainer logits
try:
_, run_config,_,_,_ = reproducible.set_seed()
embedding_validation = False
trainer = Logits(
model_dir=model_dir,
......@@ -89,7 +90,9 @@ def test_logitstrainer():
n_classes=10,
loss_op=mean_cross_entropy_loss,
embedding_validation=embedding_validation,
validation_batch_size=validation_batch_size)
validation_batch_size=validation_batch_size,
config=run_config
)
run_logitstrainer_mnist(trainer, augmentation=True)
del trainer
......@@ -110,7 +113,9 @@ def test_logitstrainer():
loss_op=mean_cross_entropy_loss,
embedding_validation=embedding_validation,
validation_batch_size=validation_batch_size,
extra_checkpoint=extra_checkpoint)
extra_checkpoint=extra_checkpoint,
config=run_config
)
run_logitstrainer_mnist(trainer, augmentation=True)
......@@ -129,7 +134,7 @@ def test_logitstrainer_center_loss():
# Trainer logits
try:
embedding_validation = False
_, run_config,_,_,_ = reproducible.set_seed()
trainer = LogitsCenterLoss(
model_dir=model_dir,
architecture=dummy,
......@@ -137,7 +142,9 @@ def test_logitstrainer_center_loss():
n_classes=10,
embedding_validation=embedding_validation,
validation_batch_size=validation_batch_size,
apply_moving_averages=False)
apply_moving_averages=False,
config=run_config
)
run_logitstrainer_mnist(trainer, augmentation=True)
del trainer
......@@ -158,7 +165,9 @@ def test_logitstrainer_center_loss():
embedding_validation=embedding_validation,
validation_batch_size=validation_batch_size,
extra_checkpoint=extra_checkpoint,
apply_moving_averages=False)
apply_moving_averages=False,
config=run_config
)
run_logitstrainer_mnist(trainer, augmentation=True)
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment