Commit 26fd7fcc authored by Zohreh MOSTAANI

Adapted the parts of the code that save tf databases as TFRecords and load them back from TFRecords to TF2.
parent 1e40a68b
1 merge request: !84 WIP: towards adapting the database folder to tf2
Pipeline #41084 failed
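In essence, the write path of dataset_to_tfrecord trades the removed TF1 entry points for their TF2 equivalents: tf.contrib.framework.nest becomes tf.nest, tf.serialize_tensor becomes tf.io.serialize_tensor, tf.py_func becomes tf.py_function, and bytes_feature now unwraps EagerTensors before building a BytesList. Below is a minimal, self-contained sketch of that write path after the change; the body of serialize_example_pyfunction, the toy dataset, and the output path are illustrative assumptions, not part of this commit.

import tensorflow as tf

def bytes_feature(value):
    # In TF2, tf.io.serialize_tensor yields an EagerTensor; BytesList needs raw bytes.
    if isinstance(value, type(tf.constant(0))):
        value = value.numpy()
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

def serialize_example_pyfunction(*args):
    # Assumed helper body: one "featureN" entry per flattened dataset component.
    features = {f"feature{i}": bytes_feature(v) for i, v in enumerate(args)}
    example = tf.train.Example(features=tf.train.Features(feature=features))
    return example.SerializeToString()

def tf_serialize_example(*args):
    args = tf.nest.flatten(args)                      # was tf.contrib.framework.nest.flatten
    args = [tf.io.serialize_tensor(a) for a in args]  # was tf.serialize_tensor
    tf_string = tf.py_function(serialize_example_pyfunction, args, tf.string)  # was tf.py_func
    return tf.reshape(tf_string, ())  # one scalar serialized string per example

# Hypothetical toy dataset with a (features, label) structure.
dataset = tf.data.Dataset.from_tensor_slices((tf.random.uniform((4, 2)), tf.range(4)))
serialized = dataset.map(tf_serialize_example)
tf.data.experimental.TFRecordWriter("/tmp/example.tfrecords").write(serialized)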
@@ -32,6 +32,8 @@ def normalize_tfrecords_path(output):
 def bytes_feature(value):
+    if isinstance(value, type(tf.constant(0))):
+        value = value.numpy()  # BytesList won't unpack a string from an EagerTensor.
     return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
@@ -61,8 +63,8 @@ def dataset_to_tfrecord(dataset, output):
     output, json_output = tfrecord_name_and_json_name(output)
     # dump the structure so that we can read it back
     meta = {
-        "output_types": repr(dataset.output_types),
-        "output_shapes": repr(dataset.output_shapes),
+        "output_types": repr(tf.compat.v1.data.get_output_types(dataset)),
+        "output_shapes": repr(tf.compat.v1.data.get_output_shapes(dataset)),
     }
     with open(json_output, "w") as f:
         json.dump(meta, f)
@@ -77,9 +79,9 @@ def dataset_to_tfrecord(dataset, output):
         return example_proto.SerializeToString()

     def tf_serialize_example(*args):
-        args = tf.contrib.framework.nest.flatten(args)
-        args = [tf.serialize_tensor(f) for f in args]
-        tf_string = tf.py_func(serialize_example_pyfunction, args, tf.string)
+        args = tf.nest.flatten(args)
+        args = [tf.io.serialize_tensor(f) for f in args]
+        tf_string = tf.py_function(serialize_example_pyfunction, args, tf.string)
         return tf.reshape(tf_string, ())  # The result is a scalar

     dataset = dataset.map(tf_serialize_example)
@@ -107,7 +109,7 @@ def dataset_from_tfrecord(tfrecord, num_parallel_reads=None):
         A dataset that contains the data from the TFRecord file.
     """
     # these imports are needed so that eval can work
-    from tensorflow import TensorShape, Dimension
+    from tensorflow import TensorShape

     if isinstance(tfrecord, str):
         tfrecord = [tfrecord]
@@ -122,20 +124,20 @@ def dataset_from_tfrecord(tfrecord, num_parallel_reads=None):
         meta = json.load(f)
     for k, v in meta.items():
         meta[k] = eval(v)

-    output_types = tf.contrib.framework.nest.flatten(meta["output_types"])
-    output_shapes = tf.contrib.framework.nest.flatten(meta["output_shapes"])
+    output_types = tf.nest.flatten(meta["output_types"])
+    output_shapes = tf.nest.flatten(meta["output_shapes"])
     feature_description = {}
     for i in range(len(output_types)):
         key = f"feature{i}"
-        feature_description[key] = tf.FixedLenFeature([], tf.string)
+        feature_description[key] = tf.io.FixedLenFeature([], tf.string)

     def _parse_function(example_proto):
         # Parse the input tf.Example proto using the dictionary above.
-        args = tf.parse_single_example(example_proto, feature_description)
-        args = tf.contrib.framework.nest.flatten(args)
-        args = [tf.parse_tensor(v, t) for v, t in zip(args, output_types)]
+        args = tf.io.parse_single_example(example_proto, feature_description)
+        args = tf.nest.flatten(args)
+        args = [tf.io.parse_tensor(v, t) for v, t in zip(args, output_types)]
         args = [tf.reshape(v, s) for v, s in zip(args, output_shapes)]
-        return tf.contrib.framework.nest.pack_sequence_as(meta["output_types"], args)
+        return tf.nest.pack_sequence_as(meta["output_types"], args)

     return raw_dataset.map(_parse_function)
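The read path mirrors this with the TF2 names (tf.io.FixedLenFeature, tf.io.parse_single_example, tf.io.parse_tensor, and tf.nest for flatten/pack_sequence_as). A rough sketch that reads back the file written in the sketch above; the hard-coded output_types and output_shapes stand in for the JSON metadata that dataset_from_tfrecord actually reloads with eval.

import tensorflow as tf

# Assumed structure of the stored elements; in the real code this comes from the
# JSON file written next to the .tfrecords file.
output_types = (tf.float32, tf.int32)
output_shapes = (tf.TensorShape([2]), tf.TensorShape([]))

flat_types = tf.nest.flatten(output_types)
flat_shapes = tf.nest.flatten(output_shapes)
feature_description = {
    f"feature{i}": tf.io.FixedLenFeature([], tf.string) for i in range(len(flat_types))
}

def _parse_function(example_proto):
    # Undo the serialization done on the write path, one tensor per "featureN" key.
    args = tf.io.parse_single_example(example_proto, feature_description)
    args = tf.nest.flatten(args)
    args = [tf.io.parse_tensor(v, t) for v, t in zip(args, flat_types)]
    args = [tf.reshape(v, s) for v, s in zip(args, flat_shapes)]
    return tf.nest.pack_sequence_as(output_types, args)

restored = tf.data.TFRecordDataset(["/tmp/example.tfrecords"]).map(_parse_function)
for x, y in restored:  # eager iteration works directly in TF2
    print(x.shape, y.numpy())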