Commit 9921e122 authored by Amir MOHAMMADI's avatar Amir MOHAMMADI

Add dataset_to_tfrecord and dataset_from_tfrecord

parent 09edcdf7
This diff is collapsed.
from bob.bio.base.test.dummy.database import database
from bob.bio.base.utils import read_original_data
from bob.learn.tensorflow.dataset.generator import dataset_using_generator
groups = ['dev']
groups = ["dev"]
samples = database.all_files(groups=groups)
......@@ -15,8 +16,13 @@ def file_to_label(f):
def reader(biofile):
data = read_original_data(biofile, database.original_directory,
database.original_extension)
data = read_original_data(
biofile, database.original_directory, database.original_extension
)
label = file_to_label(biofile)
key = str(biofile.path).encode("utf-8")
return (data, label, key)
dataset = dataset_using_generator(samples, reader)
datasets = [dataset]
......@@ -2,11 +2,16 @@ import os
import shutil
import pkg_resources
import tempfile
import tensorflow as tf
import numpy as np
from click.testing import CliRunner
from bob.io.base import create_directories_safe
from bob.learn.tensorflow.script.db_to_tfrecords import (
db_to_tfrecords, describe_tf_record)
db_to_tfrecords, describe_tf_record, datasets_to_tfrecords)
from bob.learn.tensorflow.utils import load_mnist, create_mnist_tfrecord
from bob.extension.scripts.click_helper import assert_click_runner_result
from bob.extension.config import load
from bob.learn.tensorflow.dataset.tfrecords import dataset_from_tfrecord
regenerate_reference = False
......@@ -14,6 +19,31 @@ dummy_config = pkg_resources.resource_filename(
'bob.learn.tensorflow', 'test/data/db_to_tfrecords_config.py')
def compare_datasets(ds1, ds2, sess=None):
if tf.executing_eagerly():
for values1, values2 in zip(ds1, ds2):
values1 = tf.contrib.framework.nest.flatten(values1)
values2 = tf.contrib.framework.nest.flatten(values2)
for v1, v2 in zip(values1, values2):
if not tf.reduce_all(tf.math.equal(v1, v2)):
return False
else:
ds1 = ds1.make_one_shot_iterator().get_next()
ds2 = ds2.make_one_shot_iterator().get_next()
while True:
try:
values1, values2 = sess.run([ds1, ds2])
except tf.errors.OutOfRangeError:
break
values1 = tf.contrib.framework.nest.flatten(values1)
values2 = tf.contrib.framework.nest.flatten(values2)
for v1, v2 in zip(values1, values2):
v1, v2 = np.asarray(v1), np.asarray(v2)
if not np.all(v1 == v2):
return False
return True
def test_db_to_tfrecords():
test_dir = tempfile.mkdtemp(prefix='bobtest_')
output_path = os.path.join(test_dir, 'dev.tfrecords')
......@@ -71,3 +101,19 @@ def test_tfrecord_counter():
finally:
shutil.rmtree(os.path.dirname(tfrecord_train))
def test_datasets_to_tfrecords():
runner = CliRunner()
with runner.isolated_filesystem():
output_path = './test'
args = (dummy_config, '--outputs', output_path)
result = runner.invoke(
datasets_to_tfrecords, args=args, standalone_mode=False)
assert_click_runner_result(result)
# read back the tfrecod
with tf.Session() as sess:
dataset2 = dataset_from_tfrecord(output_path)
dataset1 = load(
[dummy_config], attribute_name='dataset', entry_point_group='bob')
assert compare_datasets(dataset1, dataset2, sess)
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment