Giving TFRecord support to pytorch data loaders

parent 9cb8e769
Pipeline #51882 failed with stage
in 13 minutes and 55 seconds
include README.rst buildout.cfg develop.cfg COPYING version.txt requirements.txt
include README.rst buildout.cfg develop.cfg COPYING version.txt requirements.txt ./bob/learn/pytorch/test/data/tfrecord_test.tfrecords
recursive-include doc *.py *.rst
from .tfrecord import TFRecordDataset, idiap_iterator
This diff is collapsed.
#!/usr/bin/env python
# encoding: utf-8
import torch
from import IterableDataset
import os
import struct
import numpy as np
from . import example_pb2
import glob
import functools
def decode_tfrecord(filename, start_offset=None, end_offset=None):
Yield bytes that corresponds to one line of a tfrecord
filename: str
TFRecord file name
start_offset: int
end_offset: int
if not os.path.exists(filename):
raise ValueError(f"File `{filename}` not found!")
tf_record_file = open(filename, "rb")
length_bytes = bytearray(8)
crc_bytes = bytearray(4)
datum_bytes = bytearray(1024 * 1024)
if start_offset is not None:
if end_offset is None:
end_offset = os.path.getsize(filename)
while tf_record_file.tell() < end_offset:
if tf_record_file.readinto(length_bytes) != 8:
raise RuntimeError("Failed to read the record size.")
if tf_record_file.readinto(crc_bytes) != 4:
raise RuntimeError("Failed to read the start token.")
(length,) = struct.unpack("<Q", length_bytes)
if length > len(datum_bytes):
datum_bytes = datum_bytes.zfill(int(length * 1.5))
datum_bytes_view = memoryview(datum_bytes)[:length]
if tf_record_file.readinto(datum_bytes_view) != length:
raise RuntimeError("Failed to read the record.")
if tf_record_file.readinto(crc_bytes) != 4:
raise RuntimeError("Failed to read the end token.")
yield datum_bytes_view
def idiap_iterator(filename, shape=(126, 126, 3)):
Returns an iterator that reads tfrecords written for `bob.learn.tensorflow`, whose format is the following:
feature = {
"data": bytes_list,
"label": _int64_list,
"key": bytes_list,
filename: str
Name of the tfrecord
iterator = decode_tfrecord(filename)
for f in iterator:
example = example_pb2.Example()
data = example.features.feature["data"].bytes_list.value[0]
data = np.frombuffer(data, dtype=np.uint8)
data = np.reshape(data, shape)
key = example.features.feature["key"].bytes_list.value[0]
label = example.features.feature["label"].int64_list.value[0]
yield {"data": data, "key": key, "label": label}
class TFRecordDataset(IterableDataset):
Generic DataSet that reads multiple tfrecord files into a `IterableDataset`
path: list
List of tfrecord paths
shape: tuple
Shape of the data stored in the tfrecord
Pointer to a function that yields a sample from the TFRecord
seed: int
Seed to the pseudo random number generator used to pick one of the `len(path)` iterators
def __init__(self, path, shape=(126, 126, 3), iterator_fn=idiap_iterator, seed=777):
super(TFRecordDataset, self).__init__()
self.path = glob.glob(path)
self.seed = seed
self.shape = shape
self.iterator_fn = iterator_fn
def __iter__(self):
# Creating one generator per tfrecord file
iterators = [
functools.partial(self.iterator_fn, filename=p, shape=self.shape)()
for p in self.path
# Iterating until the iterators are done
while iterators:
# Iterating over the iterators and picking one at each iteration
# Don't know how this is going to work in a multi-worker batching
it = np.random.choice(len(iterators))
yield next(iterators[it])
except StopIteration:
# If one of the iterators are finished, delete it from the list
if len(iterators) > 0:
del iterators[it]
......@@ -11,6 +11,8 @@ from import Dataset
import numpy
import torch
from torch.autograd import Variable
from bob.learn.pytorch.datasets.tfrecord import TFRecordDataset
import pkg_resources
def test_architectures():
......@@ -620,3 +622,17 @@ def test_extractors():
data = numpy.random.rand(8, 224, 224).astype("uint8")
output = extractor(data)
assert output.shape[0] == 1
def test_tfrecord_dataloader():
path = pkg_resources.resource_filename("bob.learn.pytorch.test", "data/*.tfrecords")
dataloader =, batch_size=32)
# Asserting 20 samples
batches = [f for f in dataloader]
assert len(batches) == 1
assert batches[0]["data"].numpy().shape == (20, 126, 126, 3)
Markdown is supported
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment