Skip to content
Snippets Groups Projects

WIP: towards adapting the database folder to tf2

Closed Zohreh MOSTAANI requested to merge towards_tf2 into master
1 file
+ 15
− 13
Compare changes
  • Side-by-side
  • Inline
@@ -32,6 +32,8 @@ def normalize_tfrecords_path(output):
def bytes_feature(value):
    """Wrap ``value`` in a ``tf.train.Feature`` holding a one-element BytesList.

    Accepts either a plain bytes/string value or an eager ``tf.Tensor``
    containing one.
    """
    # An EagerTensor must be unwrapped to its numpy value first:
    # BytesList won't unpack a string from an EagerTensor.
    eager_tensor_type = type(tf.constant(0))
    if isinstance(value, eager_tensor_type):
        value = value.numpy()
    bytes_list = tf.train.BytesList(value=[value])
    return tf.train.Feature(bytes_list=bytes_list)
@@ -61,8 +63,8 @@ def dataset_to_tfrecord(dataset, output):
output, json_output = tfrecord_name_and_json_name(output)
# dump the structure so that we can read it back
meta = {
"output_types": repr(dataset.output_types),
"output_shapes": repr(dataset.output_shapes),
"output_types": repr(tf.compat.v1.data.get_output_types(dataset)),
"output_shapes": repr(tf.compat.v1.data.get_output_shapes(dataset)),
}
with open(json_output, "w") as f:
json.dump(meta, f)
@@ -77,9 +79,9 @@ def dataset_to_tfrecord(dataset, output):
return example_proto.SerializeToString()
def tf_serialize_example(*args):
args = tf.contrib.framework.nest.flatten(args)
args = [tf.serialize_tensor(f) for f in args]
tf_string = tf.py_func(serialize_example_pyfunction, args, tf.string)
args = tf.nest.flatten(args)
args = [tf.io.serialize_tensor(f) for f in args]
tf_string = tf.py_function(serialize_example_pyfunction, args, tf.string)
return tf.reshape(tf_string, ()) # The result is a scalar
dataset = dataset.map(tf_serialize_example)
@@ -107,7 +109,7 @@ def dataset_from_tfrecord(tfrecord, num_parallel_reads=None):
A dataset that contains the data from the TFRecord file.
"""
# these imports are needed so that eval can work
from tensorflow import TensorShape, Dimension
from tensorflow import TensorShape
if isinstance(tfrecord, str):
tfrecord = [tfrecord]
@@ -122,20 +124,20 @@ def dataset_from_tfrecord(tfrecord, num_parallel_reads=None):
meta = json.load(f)
for k, v in meta.items():
meta[k] = eval(v)
output_types = tf.contrib.framework.nest.flatten(meta["output_types"])
output_shapes = tf.contrib.framework.nest.flatten(meta["output_shapes"])
output_types = tf.nest.flatten(meta["output_types"])
output_shapes = tf.nest.flatten(meta["output_shapes"])
feature_description = {}
for i in range(len(output_types)):
key = f"feature{i}"
feature_description[key] = tf.FixedLenFeature([], tf.string)
feature_description[key] = tf.io.FixedLenFeature([], tf.string)
def _parse_function(example_proto):
# Parse the input tf.Example proto using the dictionary above.
args = tf.parse_single_example(example_proto, feature_description)
args = tf.contrib.framework.nest.flatten(args)
args = [tf.parse_tensor(v, t) for v, t in zip(args, output_types)]
args = tf.io.parse_single_example(example_proto, feature_description)
args = tf.nest.flatten(args)
args = [tf.io.parse_tensor(v, t) for v, t in zip(args, output_types)]
args = [tf.reshape(v, s) for v, s in zip(args, output_shapes)]
return tf.contrib.framework.nest.pack_sequence_as(meta["output_types"], args)
return tf.nest.pack_sequence_as(meta["output_types"], args)
return raw_dataset.map(_parse_function)
Loading