diff --git a/bob/learn/tensorflow/dataset/tfrecords.py b/bob/learn/tensorflow/dataset/tfrecords.py
index bfbbe10b97b9831c3a7a900599e9bde45bcbec7d..e1acd76b23ae07f66320df381c4982f15df6f429 100644
--- a/bob/learn/tensorflow/dataset/tfrecords.py
+++ b/bob/learn/tensorflow/dataset/tfrecords.py
@@ -87,7 +87,7 @@ def dataset_to_tfrecord(dataset, output):
return writer.write(dataset)


-def dataset_from_tfrecord(tfrecord):
+def dataset_from_tfrecord(tfrecord, num_parallel_reads=None):
"""Reads TFRecords and returns a dataset.
The TFRecord file must have been created using the :any:`dataset_to_tfrecord`
function.
@@ -97,6 +97,9 @@ def dataset_from_tfrecord(tfrecord):
tfrecord : str or list
Path to the TFRecord file. Pass a list if you are sure several tfrecords need
the same map function.
+ num_parallel_reads : int, optional
+ Number of files to read in parallel, passed through to
+ ``tf.data.TFRecordDataset``. Defaults to reading files sequentially.

Returns
-------
@@ -111,7 +114,9 @@ def dataset_from_tfrecord(tfrecord):
tfrecord = [tfrecord_name_and_json_name(path) for path in tfrecord]
json_output = tfrecord[0][1]
tfrecord = [path[0] for path in tfrecord]
- raw_dataset = tf.data.TFRecordDataset(tfrecord)
+ raw_dataset = tf.data.TFRecordDataset(
+ tfrecord, num_parallel_reads=num_parallel_reads
+ )
with open(json_output) as f:
meta = json.load(f)
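
Below is a minimal usage sketch of the new argument (not part of the patch). The shard names are hypothetical; each file is assumed to have been written with `dataset_to_tfrecord`, so the JSON metadata file that `dataset_from_tfrecord` reads back sits next to it.

```python
from bob.learn.tensorflow.dataset.tfrecords import dataset_from_tfrecord

# Hypothetical shards, each previously written with dataset_to_tfrecord.
shards = ["train-0.tfrecord", "train-1.tfrecord"]

# Default behaviour is unchanged: num_parallel_reads=None reads files sequentially.
ds_sequential = dataset_from_tfrecord(shards)

# Read up to 4 files in parallel; records from different files come back
# interleaved, which typically helps I/O-bound pipelines.
ds_parallel = dataset_from_tfrecord(shards, num_parallel_reads=4)
```

Since the argument is passed straight through to `tf.data.TFRecordDataset`, leaving it at its default of `None` keeps the existing sequential behaviour.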