diff --git a/bob/learn/tensorflow/dataset/generator.py b/bob/learn/tensorflow/dataset/generator.py
new file mode 100644
index 0000000000000000000000000000000000000000..4988f50c7a0191836760c4966b0943cca912c606
--- /dev/null
+++ b/bob/learn/tensorflow/dataset/generator.py
@@ -0,0 +1,90 @@
+import six
+import tensorflow as tf
+import logging
+
+logger = logging.getLogger(__name__)
+
+
+class Generator:
+    """A generator class which wraps samples so that they can
+    be used with tf.data.Dataset.from_generator
+
+    Attributes
+    ----------
+    epoch : int
+        The number of epochs that have been passed so far.
+    multiple_samples : :obj:`bool`, optional
+        If true, it assumes that the bio database's samples actually contain
+        multiple samples. This is useful for when you want to for example treat
+        video databases as image databases.
+    reader : :obj:`object`, optional
+        A callable with the signature of ``data, label, key = reader(sample)``
+        which takes a sample and loads it.
+    samples : [:obj:`object`]
+        A list of samples to be given to ``reader`` to load the data.
+    output_types : (object, object, object)
+        The types of the returned samples.
+    output_shapes : ``(tf.TensorShape, tf.TensorShape, tf.TensorShape)``
+        The shapes of the returned samples.
+    """
+
+    def __init__(self, samples, reader, multiple_samples=False, **kwargs):
+        super(Generator, self).__init__(**kwargs)
+        self.reader = reader
+        self.samples = list(samples)
+        self.multiple_samples = multiple_samples
+        self.epoch = 0
+
+        # load one data to get its type and shape
+        dlk = self.reader(self.samples[0])
+        if self.multiple_samples:
+            try:
+                dlk = dlk[0]
+            except TypeError:
+                # if the data is a generator
+                dlk = six.next(dlk)
+        dataset = tf.data.Dataset.from_tensors(dlk)
+        self._output_types = dataset.output_types
+        self._output_shapes = dataset.output_shapes
+
+        logger.info(
+            "Initializing a dataset with %d %s and %s types and %s shapes",
+            len(self.samples),
+            "multi-samples" if self.multiple_samples else "samples",
+            self.output_types,
+            self.output_shapes,
+        )
+
+    @property
+    def output_types(self):
+        return self._output_types
+
+    @property
+    def output_shapes(self):
+        return self._output_shapes
+
+    def __call__(self):
+        """A generator function that when called will yield the samples.
+
+        Yields
+        ------
+        (data, label, key) : tuple
+            A tuple containing the data, label, and the key.
+        """
+        for sample in self.samples:
+            dlk = self.reader(sample)
+            if self.multiple_samples:
+                for sub_dlk in dlk:
+                    yield sub_dlk
+            else:
+                yield dlk
+        self.epoch += 1
+        logger.info("Elapsed %d epoch(s)", self.epoch)
+
+
+def dataset_using_generator(*args, **kwargs):
+    generator = Generator(*args, **kwargs)
+    dataset = tf.data.Dataset.from_generator(
+        generator, generator.output_types, generator.output_shapes
+    )
+    return dataset