From 26fd7fcc3b4eab2e78bd17b316797d8a89c9077e Mon Sep 17 00:00:00 2001
From: Zohreh MOSTAANI <zohreh.mostaani@idiap.ch>
Date: Wed, 8 Jul 2020 15:02:41 +0200
Subject: [PATCH] adapt part of the code for saving tf datasets as tfrecords
 and loading them from tfrecords to tf2

---
 bob/learn/tensorflow/dataset/tfrecords.py | 28 ++++++++++++-----------
 1 file changed, 15 insertions(+), 13 deletions(-)

diff --git a/bob/learn/tensorflow/dataset/tfrecords.py b/bob/learn/tensorflow/dataset/tfrecords.py
index 45201b88..d035c533 100644
--- a/bob/learn/tensorflow/dataset/tfrecords.py
+++ b/bob/learn/tensorflow/dataset/tfrecords.py
@@ -32,6 +32,8 @@ def normalize_tfrecords_path(output):
 
 
 def bytes_feature(value):
+    if isinstance(value, type(tf.constant(0))):
+        value = value.numpy() # BytesList won't unpack a string from an EagerTensor.
     return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
 
 
@@ -61,8 +63,8 @@ def dataset_to_tfrecord(dataset, output):
     output, json_output = tfrecord_name_and_json_name(output)
     # dump the structure so that we can read it back
     meta = {
-        "output_types": repr(dataset.output_types),
-        "output_shapes": repr(dataset.output_shapes),
+        "output_types": repr(tf.compat.v1.data.get_output_types(dataset)),
+        "output_shapes": repr(tf.compat.v1.data.get_output_shapes(dataset)),
     }
     with open(json_output, "w") as f:
         json.dump(meta, f)
@@ -77,9 +79,9 @@ def dataset_to_tfrecord(dataset, output):
         return example_proto.SerializeToString()
 
     def tf_serialize_example(*args):
-        args = tf.contrib.framework.nest.flatten(args)
-        args = [tf.serialize_tensor(f) for f in args]
-        tf_string = tf.py_func(serialize_example_pyfunction, args, tf.string)
+        args = tf.nest.flatten(args)
+        args = [tf.io.serialize_tensor(f) for f in args]
+        tf_string = tf.py_function(serialize_example_pyfunction, args, tf.string)
         return tf.reshape(tf_string, ())  # The result is a scalar
 
     dataset = dataset.map(tf_serialize_example)
@@ -107,7 +109,7 @@ def dataset_from_tfrecord(tfrecord, num_parallel_reads=None):
         A dataset that contains the data from the TFRecord file.
     """
     # these imports are needed so that eval can work
-    from tensorflow import TensorShape, Dimension
+    from tensorflow import TensorShape
 
     if isinstance(tfrecord, str):
         tfrecord = [tfrecord]
@@ -122,20 +124,20 @@ def dataset_from_tfrecord(tfrecord, num_parallel_reads=None):
         meta = json.load(f)
     for k, v in meta.items():
         meta[k] = eval(v)
-    output_types = tf.contrib.framework.nest.flatten(meta["output_types"])
-    output_shapes = tf.contrib.framework.nest.flatten(meta["output_shapes"])
+    output_types = tf.nest.flatten(meta["output_types"])
+    output_shapes = tf.nest.flatten(meta["output_shapes"])
     feature_description = {}
     for i in range(len(output_types)):
         key = f"feature{i}"
-        feature_description[key] = tf.FixedLenFeature([], tf.string)
+        feature_description[key] = tf.io.FixedLenFeature([], tf.string)
 
     def _parse_function(example_proto):
         # Parse the input tf.Example proto using the dictionary above.
-        args = tf.parse_single_example(example_proto, feature_description)
-        args = tf.contrib.framework.nest.flatten(args)
-        args = [tf.parse_tensor(v, t) for v, t in zip(args, output_types)]
+        args = tf.io.parse_single_example(example_proto, feature_description)
+        args = tf.nest.flatten(args)
+        args = [tf.io.parse_tensor(v, t) for v, t in zip(args, output_types)]
         args = [tf.reshape(v, s) for v, s in zip(args, output_shapes)]
-        return tf.contrib.framework.nest.pack_sequence_as(meta["output_types"], args)
+        return tf.nest.pack_sequence_as(meta["output_types"], args)
 
     return raw_dataset.map(_parse_function)
 
-- 
GitLab