Skip to content
Snippets Groups Projects
Commit 9d6fa7d9 authored by Pavel KORSHUNOV's avatar Pavel KORSHUNOV
Browse files

datasets converted to tensorflow 2.0

parent 1fd8ec79
Branches
Tags
No related merge requests found
Pipeline #36119 failed
import tensorflow as tf
import tensorflow.compat.v1 as tf
import numpy
import os
import bob.io.base
......
import six
import tensorflow as tf
import tensorflow.compat.v1 as tf
import logging
logger = logging.getLogger(__name__)
......@@ -45,8 +45,8 @@ class Generator:
dlk = six.next(dlk)
# Creating a "fake" dataset just to get the types and shapes
dataset = tf.data.Dataset.from_tensors(dlk)
self._output_types = dataset.output_types
self._output_shapes = dataset.output_shapes
self._output_types = tf.data.get_output_types(dataset)
self._output_shapes = tf.data.get_output_shapes(dataset)
logger.info(
"Initializing a dataset with %d %s and %s types and %s shapes",
......
......@@ -10,7 +10,7 @@ import logging
import os
import sys
import tensorflow as tf
import tensorflow.compat.v1 as tf
from . import append_image_augmentation, DEFAULT_FEATURE
......@@ -61,8 +61,8 @@ def dataset_to_tfrecord(dataset, output):
output, json_output = tfrecord_name_and_json_name(output)
# dump the structure so that we can read it back
meta = {
"output_types": repr(dataset.output_types),
"output_shapes": repr(dataset.output_shapes),
"output_types": repr(tf.data.get_output_types(dataset)),
"output_shapes": repr(tf.data.get_output_shapes(dataset)),
}
with open(json_output, "w") as f:
json.dump(meta, f)
......@@ -77,9 +77,9 @@ def dataset_to_tfrecord(dataset, output):
return example_proto.SerializeToString()
def tf_serialize_example(*args):
    """Serialize one dataset element into a scalar ``tf.string`` tensor.

    Flattens the (possibly nested) element, serializes each component
    tensor to bytes, and delegates to ``serialize_example_pyfunction``
    to build the ``tf.train.Example`` proto bytes.

    Parameters
    ----------
    *args
        The components of one dataset element (arbitrary nesting).

    Returns
    -------
    A scalar ``tf.string`` tensor holding the serialized Example proto.
    """
    # tf.nest replaces tf.contrib.framework.nest, which was removed in TF 2.x.
    args = tf.nest.flatten(args)
    args = [tf.serialize_tensor(f) for f in args]
    # tf.numpy_function replaces the deprecated tf.py_func (TF 2.x).
    tf_string = tf.numpy_function(serialize_example_pyfunction, args, tf.string)
    return tf.reshape(tf_string, ())  # The result is a scalar
dataset = dataset.map(tf_serialize_example)
......@@ -87,7 +87,7 @@ def dataset_to_tfrecord(dataset, output):
return writer.write(dataset)
def dataset_from_tfrecord(tfrecord):
def dataset_from_tfrecord(tfrecord, num_parallel_reads=None):
"""Reads TFRecords and returns a dataset.
The TFRecord file must have been created using the :any:`dataset_to_tfrecord`
function.
......@@ -97,6 +97,9 @@ def dataset_from_tfrecord(tfrecord):
tfrecord : str or list
Path to the TFRecord file. Pass a list if you are sure several tfrecords need
the same map function.
num_parallel_reads: (Optional.)
A `tf.int64` scalar representing the number of files to read in parallel.
Defaults to reading files sequentially.
Returns
-------
......@@ -104,21 +107,23 @@ def dataset_from_tfrecord(tfrecord):
A dataset that contains the data from the TFRecord file.
"""
# these imports are needed so that eval can work
from tensorflow import TensorShape, Dimension
from tensorflow.compat.v1 import TensorShape, Dimension
if isinstance(tfrecord, str):
tfrecord = [tfrecord]
tfrecord = [tfrecord_name_and_json_name(path) for path in tfrecord]
json_output = tfrecord[0][1]
tfrecord = [path[0] for path in tfrecord]
raw_dataset = tf.data.TFRecordDataset(tfrecord)
raw_dataset = tf.data.TFRecordDataset(
tfrecord, num_parallel_reads=num_parallel_reads
)
with open(json_output) as f:
meta = json.load(f)
for k, v in meta.items():
meta[k] = eval(v)
output_types = tf.contrib.framework.nest.flatten(meta["output_types"])
output_shapes = tf.contrib.framework.nest.flatten(meta["output_shapes"])
output_types = tf.nest.flatten(meta["output_types"])
output_shapes = tf.nest.flatten(meta["output_shapes"])
feature_description = {}
for i in range(len(output_types)):
key = f"feature{i}"
......@@ -127,10 +132,10 @@ def dataset_from_tfrecord(tfrecord):
def _parse_function(example_proto):
    """Parse one serialized ``tf.train.Example`` back into the original element.

    Inverse of the serialization done in :any:`dataset_to_tfrecord`: parses
    the proto with ``feature_description``, restores each component's dtype
    and shape, and re-packs the flat list into the nested structure recorded
    in ``meta["output_types"]``.

    Parameters
    ----------
    example_proto
        A scalar ``tf.string`` tensor with a serialized Example proto.

    Returns
    -------
    The dataset element, with the structure, dtypes and shapes recorded
    in the JSON metadata file.
    """
    # Parse the input tf.Example proto using the dictionary above.
    args = tf.parse_single_example(example_proto, feature_description)
    # tf.nest replaces tf.contrib.framework.nest, which was removed in TF 2.x.
    args = tf.nest.flatten(args)
    args = [tf.parse_tensor(v, t) for v, t in zip(args, output_types)]
    args = [tf.reshape(v, s) for v, s in zip(args, output_shapes)]
    return tf.nest.pack_sequence_as(meta["output_types"], args)
return raw_dataset.map(_parse_function)
......
......@@ -9,7 +9,7 @@ import os
import random
import tempfile
import click
import tensorflow as tf
import tensorflow.compat.v1 as tf
from bob.io.base import create_directories_safe, HDF5File
from bob.extension.scripts.click_helper import (
verbosity_option,
......
0% — Loading…
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment