Skip to content
Snippets Groups Projects
Commit 23573ef0 authored by Tiago de Freitas Pereira's avatar Tiago de Freitas Pereira
Browse files

Setting the reproducible run_config

parent a15b4379
No related branches found
No related tags found
1 merge request!57Updates to the logits estimator
...@@ -38,6 +38,7 @@ def test_logitstrainer(): ...@@ -38,6 +38,7 @@ def test_logitstrainer():
# Trainer logits # Trainer logits
try: try:
embedding_validation = False embedding_validation = False
_, run_config,_,_,_ = reproducible.set_seed()
trainer = Logits( trainer = Logits(
model_dir=model_dir, model_dir=model_dir,
architecture=dummy, architecture=dummy,
...@@ -45,7 +46,8 @@ def test_logitstrainer(): ...@@ -45,7 +46,8 @@ def test_logitstrainer():
n_classes=10, n_classes=10,
loss_op=mean_cross_entropy_loss, loss_op=mean_cross_entropy_loss,
embedding_validation=embedding_validation, embedding_validation=embedding_validation,
validation_batch_size=validation_batch_size) validation_batch_size=validation_batch_size,
config=run_config)
run_logitstrainer_mnist(trainer, augmentation=True) run_logitstrainer_mnist(trainer, augmentation=True)
finally: finally:
try: try:
...@@ -59,6 +61,7 @@ def test_logitstrainer(): ...@@ -59,6 +61,7 @@ def test_logitstrainer():
def test_logitstrainer_embedding(): def test_logitstrainer_embedding():
try: try:
embedding_validation = True embedding_validation = True
_, run_config,_,_,_ = reproducible.set_seed()
trainer = Logits( trainer = Logits(
model_dir=model_dir, model_dir=model_dir,
architecture=dummy, architecture=dummy,
...@@ -66,8 +69,9 @@ def test_logitstrainer_embedding(): ...@@ -66,8 +69,9 @@ def test_logitstrainer_embedding():
n_classes=10, n_classes=10,
loss_op=mean_cross_entropy_loss, loss_op=mean_cross_entropy_loss,
embedding_validation=embedding_validation, embedding_validation=embedding_validation,
validation_batch_size=validation_batch_size) validation_batch_size=validation_batch_size,
config=run_config)
run_logitstrainer_mnist(trainer) run_logitstrainer_mnist(trainer)
finally: finally:
try: try:
...@@ -81,7 +85,7 @@ def test_logitstrainer_embedding(): ...@@ -81,7 +85,7 @@ def test_logitstrainer_embedding():
def test_logitstrainer_centerloss(): def test_logitstrainer_centerloss():
try: try:
embedding_validation = False embedding_validation = False
run_config = tf.estimator.RunConfig() _, run_config,_,_,_ = reproducible.set_seed()
run_config = run_config.replace(save_checkpoints_steps=1000) run_config = run_config.replace(save_checkpoints_steps=1000)
trainer = LogitsCenterLoss( trainer = LogitsCenterLoss(
model_dir=model_dir, model_dir=model_dir,
...@@ -118,6 +122,7 @@ def test_logitstrainer_centerloss(): ...@@ -118,6 +122,7 @@ def test_logitstrainer_centerloss():
def test_logitstrainer_centerloss_embedding(): def test_logitstrainer_centerloss_embedding():
try: try:
embedding_validation = True embedding_validation = True
_, run_config,_,_,_ = reproducible.set_seed()
trainer = LogitsCenterLoss( trainer = LogitsCenterLoss(
model_dir=model_dir, model_dir=model_dir,
architecture=dummy, architecture=dummy,
...@@ -125,7 +130,9 @@ def test_logitstrainer_centerloss_embedding(): ...@@ -125,7 +130,9 @@ def test_logitstrainer_centerloss_embedding():
n_classes=10, n_classes=10,
embedding_validation=embedding_validation, embedding_validation=embedding_validation,
validation_batch_size=validation_batch_size, validation_batch_size=validation_batch_size,
factor=0.01) factor=0.01,
config=run_config
)
run_logitstrainer_mnist(trainer) run_logitstrainer_mnist(trainer)
# Checking if the centers were updated # Checking if the centers were updated
...@@ -170,7 +177,7 @@ def run_logitstrainer_mnist(trainer, augmentation=False): ...@@ -170,7 +177,7 @@ def run_logitstrainer_mnist(trainer, augmentation=False):
data_type, data_type,
batch_size, batch_size,
random_flip=True, random_flip=True,
random_rotate=True, random_rotate=False,
epochs=epochs) epochs=epochs)
else: else:
return shuffle_data_and_labels( return shuffle_data_and_labels(
...@@ -196,7 +203,6 @@ def run_logitstrainer_mnist(trainer, augmentation=False): ...@@ -196,7 +203,6 @@ def run_logitstrainer_mnist(trainer, augmentation=False):
scaffold=tf.train.Scaffold(), scaffold=tf.train.Scaffold(),
summary_writer=tf.summary.FileWriter(model_dir)) summary_writer=tf.summary.FileWriter(model_dir))
] ]
trainer.train(input_fn, steps=steps, hooks=hooks) trainer.train(input_fn, steps=steps, hooks=hooks)
if not trainer.embedding_validation: if not trainer.embedding_validation:
acc = trainer.evaluate(input_fn_validation) acc = trainer.evaluate(input_fn_validation)
......
...@@ -81,6 +81,7 @@ def dummy_adapted(inputs, ...@@ -81,6 +81,7 @@ def dummy_adapted(inputs,
def test_logitstrainer(): def test_logitstrainer():
# Trainer logits # Trainer logits
try: try:
_, run_config,_,_,_ = reproducible.set_seed()
embedding_validation = False embedding_validation = False
trainer = Logits( trainer = Logits(
model_dir=model_dir, model_dir=model_dir,
...@@ -89,7 +90,9 @@ def test_logitstrainer(): ...@@ -89,7 +90,9 @@ def test_logitstrainer():
n_classes=10, n_classes=10,
loss_op=mean_cross_entropy_loss, loss_op=mean_cross_entropy_loss,
embedding_validation=embedding_validation, embedding_validation=embedding_validation,
validation_batch_size=validation_batch_size) validation_batch_size=validation_batch_size,
config=run_config
)
run_logitstrainer_mnist(trainer, augmentation=True) run_logitstrainer_mnist(trainer, augmentation=True)
del trainer del trainer
...@@ -110,7 +113,9 @@ def test_logitstrainer(): ...@@ -110,7 +113,9 @@ def test_logitstrainer():
loss_op=mean_cross_entropy_loss, loss_op=mean_cross_entropy_loss,
embedding_validation=embedding_validation, embedding_validation=embedding_validation,
validation_batch_size=validation_batch_size, validation_batch_size=validation_batch_size,
extra_checkpoint=extra_checkpoint) extra_checkpoint=extra_checkpoint,
config=run_config
)
run_logitstrainer_mnist(trainer, augmentation=True) run_logitstrainer_mnist(trainer, augmentation=True)
...@@ -129,7 +134,7 @@ def test_logitstrainer_center_loss(): ...@@ -129,7 +134,7 @@ def test_logitstrainer_center_loss():
# Trainer logits # Trainer logits
try: try:
embedding_validation = False embedding_validation = False
_, run_config,_,_,_ = reproducible.set_seed()
trainer = LogitsCenterLoss( trainer = LogitsCenterLoss(
model_dir=model_dir, model_dir=model_dir,
architecture=dummy, architecture=dummy,
...@@ -137,7 +142,9 @@ def test_logitstrainer_center_loss(): ...@@ -137,7 +142,9 @@ def test_logitstrainer_center_loss():
n_classes=10, n_classes=10,
embedding_validation=embedding_validation, embedding_validation=embedding_validation,
validation_batch_size=validation_batch_size, validation_batch_size=validation_batch_size,
apply_moving_averages=False) apply_moving_averages=False,
config=run_config
)
run_logitstrainer_mnist(trainer, augmentation=True) run_logitstrainer_mnist(trainer, augmentation=True)
del trainer del trainer
...@@ -158,7 +165,9 @@ def test_logitstrainer_center_loss(): ...@@ -158,7 +165,9 @@ def test_logitstrainer_center_loss():
embedding_validation=embedding_validation, embedding_validation=embedding_validation,
validation_batch_size=validation_batch_size, validation_batch_size=validation_batch_size,
extra_checkpoint=extra_checkpoint, extra_checkpoint=extra_checkpoint,
apply_moving_averages=False) apply_moving_averages=False,
config=run_config
)
run_logitstrainer_mnist(trainer, augmentation=True) run_logitstrainer_mnist(trainer, augmentation=True)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment