Skip to content
GitLab
Menu
Projects
Groups
Snippets
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
bob
bob.learn.tensorflow
Commits
23573ef0
Commit
23573ef0
authored
Jul 10, 2018
by
Tiago de Freitas Pereira
Browse files
Setting the reproducible run_config
parent
a15b4379
Changes
2
Hide whitespace changes
Inline
Side-by-side
bob/learn/tensorflow/test/test_estimator_onegraph.py
View file @
23573ef0
...
...
@@ -38,6 +38,7 @@ def test_logitstrainer():
# Trainer logits
try
:
embedding_validation
=
False
_
,
run_config
,
_
,
_
,
_
=
reproducible
.
set_seed
()
trainer
=
Logits
(
model_dir
=
model_dir
,
architecture
=
dummy
,
...
...
@@ -45,7 +46,8 @@ def test_logitstrainer():
n_classes
=
10
,
loss_op
=
mean_cross_entropy_loss
,
embedding_validation
=
embedding_validation
,
validation_batch_size
=
validation_batch_size
)
validation_batch_size
=
validation_batch_size
,
config
=
run_config
)
run_logitstrainer_mnist
(
trainer
,
augmentation
=
True
)
finally
:
try
:
...
...
@@ -59,6 +61,7 @@ def test_logitstrainer():
def
test_logitstrainer_embedding
():
try
:
embedding_validation
=
True
_
,
run_config
,
_
,
_
,
_
=
reproducible
.
set_seed
()
trainer
=
Logits
(
model_dir
=
model_dir
,
architecture
=
dummy
,
...
...
@@ -66,8 +69,9 @@ def test_logitstrainer_embedding():
n_classes
=
10
,
loss_op
=
mean_cross_entropy_loss
,
embedding_validation
=
embedding_validation
,
validation_batch_size
=
validation_batch_size
)
validation_batch_size
=
validation_batch_size
,
config
=
run_config
)
run_logitstrainer_mnist
(
trainer
)
finally
:
try
:
...
...
@@ -81,7 +85,7 @@ def test_logitstrainer_embedding():
def
test_logitstrainer_centerloss
():
try
:
embedding_validation
=
False
run_config
=
tf
.
estimator
.
RunConfig
()
_
,
run_config
,
_
,
_
,
_
=
reproducible
.
set_seed
()
run_config
=
run_config
.
replace
(
save_checkpoints_steps
=
1000
)
trainer
=
LogitsCenterLoss
(
model_dir
=
model_dir
,
...
...
@@ -118,6 +122,7 @@ def test_logitstrainer_centerloss():
def
test_logitstrainer_centerloss_embedding
():
try
:
embedding_validation
=
True
_
,
run_config
,
_
,
_
,
_
=
reproducible
.
set_seed
()
trainer
=
LogitsCenterLoss
(
model_dir
=
model_dir
,
architecture
=
dummy
,
...
...
@@ -125,7 +130,9 @@ def test_logitstrainer_centerloss_embedding():
n_classes
=
10
,
embedding_validation
=
embedding_validation
,
validation_batch_size
=
validation_batch_size
,
factor
=
0.01
)
factor
=
0.01
,
config
=
run_config
)
run_logitstrainer_mnist
(
trainer
)
# Checking if the centers were updated
...
...
@@ -170,7 +177,7 @@ def run_logitstrainer_mnist(trainer, augmentation=False):
data_type
,
batch_size
,
random_flip
=
True
,
random_rotate
=
Tru
e
,
random_rotate
=
Fals
e
,
epochs
=
epochs
)
else
:
return
shuffle_data_and_labels
(
...
...
@@ -196,7 +203,6 @@ def run_logitstrainer_mnist(trainer, augmentation=False):
scaffold
=
tf
.
train
.
Scaffold
(),
summary_writer
=
tf
.
summary
.
FileWriter
(
model_dir
))
]
trainer
.
train
(
input_fn
,
steps
=
steps
,
hooks
=
hooks
)
if
not
trainer
.
embedding_validation
:
acc
=
trainer
.
evaluate
(
input_fn_validation
)
...
...
bob/learn/tensorflow/test/test_estimator_transfer.py
View file @
23573ef0
...
...
@@ -81,6 +81,7 @@ def dummy_adapted(inputs,
def
test_logitstrainer
():
# Trainer logits
try
:
_
,
run_config
,
_
,
_
,
_
=
reproducible
.
set_seed
()
embedding_validation
=
False
trainer
=
Logits
(
model_dir
=
model_dir
,
...
...
@@ -89,7 +90,9 @@ def test_logitstrainer():
n_classes
=
10
,
loss_op
=
mean_cross_entropy_loss
,
embedding_validation
=
embedding_validation
,
validation_batch_size
=
validation_batch_size
)
validation_batch_size
=
validation_batch_size
,
config
=
run_config
)
run_logitstrainer_mnist
(
trainer
,
augmentation
=
True
)
del
trainer
...
...
@@ -110,7 +113,9 @@ def test_logitstrainer():
loss_op
=
mean_cross_entropy_loss
,
embedding_validation
=
embedding_validation
,
validation_batch_size
=
validation_batch_size
,
extra_checkpoint
=
extra_checkpoint
)
extra_checkpoint
=
extra_checkpoint
,
config
=
run_config
)
run_logitstrainer_mnist
(
trainer
,
augmentation
=
True
)
...
...
@@ -129,7 +134,7 @@ def test_logitstrainer_center_loss():
# Trainer logits
try
:
embedding_validation
=
False
_
,
run_config
,
_
,
_
,
_
=
reproducible
.
set_seed
()
trainer
=
LogitsCenterLoss
(
model_dir
=
model_dir
,
architecture
=
dummy
,
...
...
@@ -137,7 +142,9 @@ def test_logitstrainer_center_loss():
n_classes
=
10
,
embedding_validation
=
embedding_validation
,
validation_batch_size
=
validation_batch_size
,
apply_moving_averages
=
False
)
apply_moving_averages
=
False
,
config
=
run_config
)
run_logitstrainer_mnist
(
trainer
,
augmentation
=
True
)
del
trainer
...
...
@@ -158,7 +165,9 @@ def test_logitstrainer_center_loss():
embedding_validation
=
embedding_validation
,
validation_batch_size
=
validation_batch_size
,
extra_checkpoint
=
extra_checkpoint
,
apply_moving_averages
=
False
)
apply_moving_averages
=
False
,
config
=
run_config
)
run_logitstrainer_mnist
(
trainer
,
augmentation
=
True
)
...
...
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment