Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
bob
bob.learn.tensorflow
Commits
88a5f8b5
Commit
88a5f8b5
authored
Oct 24, 2017
by
Amir MOHAMMADI
Browse files
Fix and enable tests
parent
103b120e
Pipeline
#13440
failed with stages
in 15 minutes and 25 seconds
Changes
6
Pipelines
1
Hide whitespace changes
Inline
Side-by-side
bob/learn/tensorflow/script/db_to_tfrecords.py
View file @
88a5f8b5
...
...
@@ -36,7 +36,7 @@ The configuration files should have the following objects totally::
samples : a list of all samples that you want to write in the tfrecords
file. Whatever is inside this list is passed to the reader.
reader : a function with the signature of
`data, label, key = reader(sample)` which takes a sample and
`
`data, label, key = reader(sample)`
`
which takes a sample and
returns the loaded data, the label of the data, and a key which
is unique for every sample.
...
...
@@ -91,7 +91,6 @@ from __future__ import print_function
import
random
# import pkg_resources so that bob imports work properly:
import
pkg_resources
import
six
import
tensorflow
as
tf
from
bob.io.base
import
create_directories_safe
from
bob.bio.base.utils
import
read_config_file
...
...
bob/learn/tensorflow/script/predict_bio.py
View file @
88a5f8b5
...
...
@@ -205,7 +205,7 @@ def main(argv=None):
try
:
pred_buffer
=
defaultdict
(
list
)
for
i
,
pred
in
enumerate
(
predictions
):
key
=
pred
[
'key
s
'
]
key
=
pred
[
'key'
]
prob
=
pred
.
get
(
'probabilities'
,
pred
.
get
(
'embeddings'
))
pred_buffer
[
key
].
append
(
prob
)
if
i
==
0
:
...
...
bob/learn/tensorflow/script/predict_generic.py
View file @
88a5f8b5
...
...
@@ -105,7 +105,7 @@ def main(argv=None):
try
:
pred_buffer
=
defaultdict
(
list
)
for
i
,
pred
in
enumerate
(
predictions
):
key
=
pred
[
'key
s
'
]
key
=
pred
[
'key'
]
prob
=
pred
.
get
(
'probabilities'
,
pred
.
get
(
'embeddings'
))
pred_buffer
[
key
].
append
(
prob
)
if
i
==
0
:
...
...
bob/learn/tensorflow/test/data/dummy_verify_config.py
View file @
88a5f8b5
import
os
from
bob.bio.base.test.dummy.database
import
database
from
bob.bio.base.
test.dummy.preprocessor
import
p
re
processor
from
bob.bio.base.
utils
import
re
ad_original_data
groups
=
'dev'
groups
=
[
'dev'
]
files
=
database
.
all_files
(
groups
=
groups
)
samples
=
database
.
all_files
(
groups
=
groups
)
output
=
os
.
path
.
join
(
'TEST_DIR'
,
'dev.tfrecords'
)
CLIENT_IDS
=
(
str
(
f
.
client_id
)
for
f
in
database
.
all_files
(
groups
=
groups
))
CLIENT_IDS
=
list
(
set
(
CLIENT_IDS
))
...
...
@@ -15,8 +18,8 @@ def file_to_label(f):
def
reader
(
biofile
):
data
=
preprocessor
.
read_original_data
(
data
=
read_original_data
(
biofile
,
database
.
original_directory
,
database
.
original_extension
)
label
=
file_to_label
(
biofile
)
key
=
biofile
.
path
key
=
str
(
biofile
.
path
)
return
(
data
,
label
,
key
)
bob/learn/tensorflow/test/test_db_to_tfrecords.py
View file @
88a5f8b5
...
...
@@ -4,7 +4,6 @@ import pkg_resources
import
tempfile
from
bob.learn.tensorflow.script.db_to_tfrecords
import
main
as
tfrecords
from
bob.bio.base.script.verify
import
main
as
verify
regenerate_reference
=
False
...
...
@@ -21,9 +20,7 @@ def test_verify_and_tfrecords():
parameters
=
[
config_path
]
try
:
#verify(parameters)
#tfrecords(parameters)
pass
tfrecords
(
parameters
)
# TODO: test if tfrecords are equal
# tfrecords_path = os.path.join(test_dir, 'sub_directory', 'dev.tfrecords')
...
...
bob/learn/tensorflow/test/test_estimator_scripts.py
View file @
88a5f8b5
...
...
@@ -7,7 +7,6 @@ logging.getLogger("tensorflow").setLevel(logging.WARNING)
from
bob.io.base.test_utils
import
datafile
from
bob.learn.tensorflow.script.db_to_tfrecords
import
main
as
tfrecords
from
bob.bio.base.script.verify
import
main
as
verify
from
bob.learn.tensorflow.script.train_generic
import
main
as
train_generic
from
bob.learn.tensorflow.script.eval_generic
import
main
as
eval_generic
...
...
@@ -44,6 +43,9 @@ def architecture(images):
def model_fn(features, labels, mode, params, config):
key = features['key']
features = features['data']
logits = architecture(features)
predictions = {
...
...
@@ -51,7 +53,8 @@ def model_fn(features, labels, mode, params, config):
"classes": tf.argmax(input=logits, axis=1),
# Add `softmax_tensor` to the graph. It is used for PREDICT and by the
# `logging_hook`.
"probabilities": tf.nn.softmax(logits, name="softmax_tensor")
"probabilities": tf.nn.softmax(logits, name="softmax_tensor"),
"key": key,
}
if mode == tf.estimator.ModeKeys.PREDICT:
return tf.estimator.EstimatorSpec(mode=mode, predictions=predictions)
...
...
@@ -82,9 +85,8 @@ def _create_tfrecord(test_dir):
config_path
=
os
.
path
.
join
(
test_dir
,
'tfrecordconfig.py'
)
with
open
(
dummy_tfrecord_config
)
as
f
,
open
(
config_path
,
'w'
)
as
f2
:
f2
.
write
(
f
.
read
().
replace
(
'TEST_DIR'
,
test_dir
))
#verify([config_path])
tfrecords
([
config_path
])
return
os
.
path
.
join
(
test_dir
,
'sub_directory'
,
'dev.tfrecords'
)
return
os
.
path
.
join
(
test_dir
,
'dev.tfrecords'
)
def
_create_checkpoint
(
tmpdir
,
model_dir
,
dummy_tfrecord
):
...
...
@@ -112,21 +114,21 @@ def test_eval_once():
eval_dir
=
os
.
path
.
join
(
model_dir
,
'eval'
)
print
(
'
\n
Creating a dummy tfrecord'
)
#
dummy_tfrecord = _create_tfrecord(tmpdir)
dummy_tfrecord
=
_create_tfrecord
(
tmpdir
)
print
(
'Training a dummy network'
)
#
_create_checkpoint(tmpdir, model_dir, dummy_tfrecord)
_create_checkpoint
(
tmpdir
,
model_dir
,
dummy_tfrecord
)
print
(
'Evaluating a dummy network'
)
#
_eval(tmpdir, model_dir, dummy_tfrecord)
_eval
(
tmpdir
,
model_dir
,
dummy_tfrecord
)
#
evaluated_path = os.path.join(eval_dir, 'evaluated')
#
assert os.path.exists(evaluated_path), evaluated_path
#
with open(evaluated_path) as f:
#
doc = f.read()
evaluated_path
=
os
.
path
.
join
(
eval_dir
,
'evaluated'
)
assert
os
.
path
.
exists
(
evaluated_path
),
evaluated_path
with
open
(
evaluated_path
)
as
f
:
doc
=
f
.
read
()
#
assert '1' in doc, doc
#
assert '100' in doc, doc
assert
'1'
in
doc
,
doc
assert
'100'
in
doc
,
doc
finally
:
try
:
shutil
.
rmtree
(
tmpdir
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment