From 3d80742de363a97ac41dba92f8a0d57b717d1d8c Mon Sep 17 00:00:00 2001
From: Amir MOHAMMADI <amir.mohammadi@idiap.ch>
Date: Mon, 27 May 2019 10:25:18 +0200
Subject: [PATCH] Add a parallel read option to dataset_from_tfrecords

---
 bob/learn/tensorflow/dataset/tfrecords.py | 9 +++++++--
 1 file changed, 7 insertions(+), 2 deletions(-)

diff --git a/bob/learn/tensorflow/dataset/tfrecords.py b/bob/learn/tensorflow/dataset/tfrecords.py
index bfbbe10b..e1acd76b 100644
--- a/bob/learn/tensorflow/dataset/tfrecords.py
+++ b/bob/learn/tensorflow/dataset/tfrecords.py
@@ -87,7 +87,7 @@ def dataset_to_tfrecord(dataset, output):
     return writer.write(dataset)
 
 
-def dataset_from_tfrecord(tfrecord):
+def dataset_from_tfrecord(tfrecord, num_parallel_reads=None):
     """Reads TFRecords and returns a dataset.
     The TFRecord file must have been created using the :any:`dataset_to_tfrecord`
     function.
@@ -97,6 +97,9 @@ def dataset_from_tfrecord(tfrecord):
     tfrecord : str or list
         Path to the TFRecord file. Pass a list if you are sure several tfrecords need
         the same map function.
+    num_parallel_reads: (Optional.)
+        A `tf.int64` scalar representing the number of files to read in parallel.
+        Defaults to reading files sequentially.
 
     Returns
     -------
@@ -111,7 +114,9 @@ def dataset_from_tfrecord(tfrecord):
     tfrecord = [tfrecord_name_and_json_name(path) for path in tfrecord]
     json_output = tfrecord[0][1]
     tfrecord = [path[0] for path in tfrecord]
-    raw_dataset = tf.data.TFRecordDataset(tfrecord)
+    raw_dataset = tf.data.TFRecordDataset(
+        tfrecord, num_parallel_reads=num_parallel_reads
+    )
 
     with open(json_output) as f:
         meta = json.load(f)
-- 
GitLab