diff --git a/bob/learn/tensorflow/dataset/tfrecords.py b/bob/learn/tensorflow/dataset/tfrecords.py
index 45201b88c0b732de17d87c8510f158cf653c4af9..d035c533cde81c026d1079a05f8123ef20b1ae05 100644
--- a/bob/learn/tensorflow/dataset/tfrecords.py
+++ b/bob/learn/tensorflow/dataset/tfrecords.py
@@ -32,6 +32,8 @@ def normalize_tfrecords_path(output):
 
 
 def bytes_feature(value):
+    if isinstance(value, type(tf.constant(0))):
+        value = value.numpy()  # BytesList won't unpack a string from an EagerTensor.
     return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
 
 
@@ -61,8 +63,8 @@ def dataset_to_tfrecord(dataset, output):
     output, json_output = tfrecord_name_and_json_name(output)
     # dump the structure so that we can read it back
     meta = {
-        "output_types": repr(dataset.output_types),
-        "output_shapes": repr(dataset.output_shapes),
+        "output_types": repr(tf.compat.v1.data.get_output_types(dataset)),
+        "output_shapes": repr(tf.compat.v1.data.get_output_shapes(dataset)),
     }
     with open(json_output, "w") as f:
         json.dump(meta, f)
@@ -77,9 +79,9 @@ def dataset_to_tfrecord(dataset, output):
         return example_proto.SerializeToString()
 
     def tf_serialize_example(*args):
-        args = tf.contrib.framework.nest.flatten(args)
-        args = [tf.serialize_tensor(f) for f in args]
-        tf_string = tf.py_func(serialize_example_pyfunction, args, tf.string)
+        args = tf.nest.flatten(args)
+        args = [tf.io.serialize_tensor(f) for f in args]
+        tf_string = tf.py_function(serialize_example_pyfunction, args, tf.string)
         return tf.reshape(tf_string, ())  # The result is a scalar
 
     dataset = dataset.map(tf_serialize_example)
@@ -107,7 +109,7 @@ def dataset_from_tfrecord(tfrecord, num_parallel_reads=None):
         A dataset that contains the data from the TFRecord file.
     """
     # these imports are needed so that eval can work
-    from tensorflow import TensorShape, Dimension
+    from tensorflow import TensorShape
 
     if isinstance(tfrecord, str):
         tfrecord = [tfrecord]
@@ -122,20 +124,20 @@ def dataset_from_tfrecord(tfrecord, num_parallel_reads=None):
         meta = json.load(f)
     for k, v in meta.items():
         meta[k] = eval(v)
-    output_types = tf.contrib.framework.nest.flatten(meta["output_types"])
-    output_shapes = tf.contrib.framework.nest.flatten(meta["output_shapes"])
+    output_types = tf.nest.flatten(meta["output_types"])
+    output_shapes = tf.nest.flatten(meta["output_shapes"])
     feature_description = {}
     for i in range(len(output_types)):
         key = f"feature{i}"
-        feature_description[key] = tf.FixedLenFeature([], tf.string)
+        feature_description[key] = tf.io.FixedLenFeature([], tf.string)
 
     def _parse_function(example_proto):
         # Parse the input tf.Example proto using the dictionary above.
-        args = tf.parse_single_example(example_proto, feature_description)
-        args = tf.contrib.framework.nest.flatten(args)
-        args = [tf.parse_tensor(v, t) for v, t in zip(args, output_types)]
+        args = tf.io.parse_single_example(example_proto, feature_description)
+        args = tf.nest.flatten(args)
+        args = [tf.io.parse_tensor(v, t) for v, t in zip(args, output_types)]
         args = [tf.reshape(v, s) for v, s in zip(args, output_shapes)]
-        return tf.contrib.framework.nest.pack_sequence_as(meta["output_types"], args)
+        return tf.nest.pack_sequence_as(meta["output_types"], args)
 
     return raw_dataset.map(_parse_function)
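
For context, a minimal round-trip sketch of how the migrated helpers are meant to be used under TensorFlow 2.x eager execution. This is not part of the patch: the dataset contents and the `example.tfrecord` path are made up for illustration, and the sidecar JSON path is whatever `tfrecord_name_and_json_name` derives from the output name.

```python
# Round-trip sketch (illustrative only, not part of this patch).
# Assumes TensorFlow 2.x with eager execution and an installed
# bob.learn.tensorflow; "example.tfrecord" is a hypothetical path.
import tensorflow as tf

from bob.learn.tensorflow.dataset.tfrecords import (
    dataset_from_tfrecord,
    dataset_to_tfrecord,
)

# A two-element structure per sample, so the tf.nest
# flattening/packing paths in the patch are exercised.
dataset = tf.data.Dataset.from_tensor_slices(
    (tf.reshape(tf.range(12.0), (4, 3)), tf.constant([0, 1, 0, 1]))
)

# Writes the TFRecord plus a sidecar JSON holding the repr of
# the dataset's output_types and output_shapes.
dataset_to_tfrecord(dataset, "example.tfrecord")

# Reading restores the original (data, label) structure, dtypes, and shapes.
for data, label in dataset_from_tfrecord("example.tfrecord"):
    print(data.shape, label.numpy())
```

Note that the read path `eval`s the stored reprs, which is why `TensorShape` must stay importable there; only `Dimension` is dropped from the import, since TF 2.x removed it and `TensorShape` reprs no longer mention it.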