Commit 88a5f8b5 authored by Amir MOHAMMADI's avatar Amir MOHAMMADI
Browse files

Fix and enable tests

parent 103b120e
Pipeline #13440 failed with stages
in 15 minutes and 25 seconds
......@@ -36,7 +36,7 @@ The configuration files should have the following objects totally::
samples : a list of all samples that you want to write in the tfrecords
file. Whatever is inside this list is passed to the reader.
reader : a function with the signature of
-    ``data, label, key = reader(sample)`` which takes a sample and
+    ``data, label, key = reader(sample)`` which takes a sample and
returns the loaded data, the label of the data, and a key which
is unique for every sample.
......@@ -91,7 +91,6 @@ from __future__ import print_function
import random
# import pkg_resources so that bob imports work properly:
import pkg_resources
-import six
import tensorflow as tf
from bob.io.base import create_directories_safe
from bob.bio.base.utils import read_config_file
......
......@@ -205,7 +205,7 @@ def main(argv=None):
try:
pred_buffer = defaultdict(list)
for i, pred in enumerate(predictions):
-            key = pred['keys']
+            key = pred['key']
prob = pred.get('probabilities', pred.get('embeddings'))
pred_buffer[key].append(prob)
if i == 0:
......
......@@ -105,7 +105,7 @@ def main(argv=None):
try:
pred_buffer = defaultdict(list)
for i, pred in enumerate(predictions):
-            key = pred['keys']
+            key = pred['key']
prob = pred.get('probabilities', pred.get('embeddings'))
pred_buffer[key].append(prob)
if i == 0:
......
import os
from bob.bio.base.test.dummy.database import database
from bob.bio.base.test.dummy.preprocessor import preprocessor
from bob.bio.base.utils import read_original_data
-groups = 'dev'
+groups = ['dev']
-files = database.all_files(groups=groups)
+samples = database.all_files(groups=groups)
output = os.path.join('TEST_DIR', 'dev.tfrecords')
CLIENT_IDS = (str(f.client_id) for f in database.all_files(groups=groups))
CLIENT_IDS = list(set(CLIENT_IDS))
......@@ -15,8 +18,8 @@ def file_to_label(f):
def reader(biofile):
    """Load one database sample for tfrecord writing.

    Returns a ``(data, label, key)`` tuple: the raw data read from the
    database's original files, the label computed by ``file_to_label``,
    and a string key that is unique per sample.
    """
    data = read_original_data(
        biofile, database.original_directory, database.original_extension)
    label = file_to_label(biofile)
    # str() guards against non-string path objects; tfrecord keys must be
    # plain strings.
    key = str(biofile.path)
    return (data, label, key)
......@@ -4,7 +4,6 @@ import pkg_resources
import tempfile
from bob.learn.tensorflow.script.db_to_tfrecords import main as tfrecords
-from bob.bio.base.script.verify import main as verify
regenerate_reference = False
......@@ -21,9 +20,7 @@ def test_verify_and_tfrecords():
parameters = [config_path]
try:
-        #verify(parameters)
-        #tfrecords(parameters)
-        pass
+        tfrecords(parameters)
# TODO: test if tfrecords are equal
# tfrecords_path = os.path.join(test_dir, 'sub_directory', 'dev.tfrecords')
......
......@@ -7,7 +7,6 @@ logging.getLogger("tensorflow").setLevel(logging.WARNING)
from bob.io.base.test_utils import datafile
from bob.learn.tensorflow.script.db_to_tfrecords import main as tfrecords
-from bob.bio.base.script.verify import main as verify
from bob.learn.tensorflow.script.train_generic import main as train_generic
from bob.learn.tensorflow.script.eval_generic import main as eval_generic
......@@ -44,6 +43,9 @@ def architecture(images):
def model_fn(features, labels, mode, params, config):
+    key = features['key']
+    features = features['data']
logits = architecture(features)
predictions = {
......@@ -51,7 +53,8 @@ def model_fn(features, labels, mode, params, config):
"classes": tf.argmax(input=logits, axis=1),
# Add `softmax_tensor` to the graph. It is used for PREDICT and by the
# `logging_hook`.
-        "probabilities": tf.nn.softmax(logits, name="softmax_tensor")
+        "probabilities": tf.nn.softmax(logits, name="softmax_tensor"),
+        "key": key,
}
if mode == tf.estimator.ModeKeys.PREDICT:
return tf.estimator.EstimatorSpec(mode=mode, predictions=predictions)
......@@ -82,9 +85,8 @@ def _create_tfrecord(test_dir):
config_path = os.path.join(test_dir, 'tfrecordconfig.py')
with open(dummy_tfrecord_config) as f, open(config_path, 'w') as f2:
f2.write(f.read().replace('TEST_DIR', test_dir))
-    #verify([config_path])
tfrecords([config_path])
-    return os.path.join(test_dir, 'sub_directory', 'dev.tfrecords')
+    return os.path.join(test_dir, 'dev.tfrecords')
def _create_checkpoint(tmpdir, model_dir, dummy_tfrecord):
......@@ -112,21 +114,21 @@ def test_eval_once():
eval_dir = os.path.join(model_dir, 'eval')
print('\nCreating a dummy tfrecord')
-        #dummy_tfrecord = _create_tfrecord(tmpdir)
+        dummy_tfrecord = _create_tfrecord(tmpdir)
print('Training a dummy network')
-        #_create_checkpoint(tmpdir, model_dir, dummy_tfrecord)
+        _create_checkpoint(tmpdir, model_dir, dummy_tfrecord)
print('Evaluating a dummy network')
-        #_eval(tmpdir, model_dir, dummy_tfrecord)
+        _eval(tmpdir, model_dir, dummy_tfrecord)
-        #evaluated_path = os.path.join(eval_dir, 'evaluated')
-        #assert os.path.exists(evaluated_path), evaluated_path
-        #with open(evaluated_path) as f:
-        #    doc = f.read()
+        evaluated_path = os.path.join(eval_dir, 'evaluated')
+        assert os.path.exists(evaluated_path), evaluated_path
+        with open(evaluated_path) as f:
+            doc = f.read()
-        # assert '1' in doc, doc
-        # assert '100' in doc, doc
+        assert '1' in doc, doc
+        assert '100' in doc, doc
finally:
try:
shutil.rmtree(tmpdir)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment