db_to_tfrecords_config.py 491 Bytes
Newer Older
1
import tensorflow as tf
Amir MOHAMMADI's avatar
Amir MOHAMMADI committed
2

3
4
5
6
7
from bob.learn.tensorflow.data import dataset_using_generator

# Handle to the Keras MNIST dataset module.
mnist = tf.keras.datasets.mnist

# Keep only the first pair returned by load_data(); the second pair is
# unpacked into the throwaway name `_` and is not used below.
(x_train, y_train), (_, _) = mnist.load_data()
Amir MOHAMMADI's avatar
Amir MOHAMMADI committed
8
# Restrict both arrays to their first ten entries so the config stays small.
x_train = x_train[:10]
y_train = y_train[:10]
9
# Each sample is an (index, x, y) triple, where the index runs over
# range(len(x_train)) and is later turned into the sample key by `reader`.
# NOTE(review): zip() returns a one-shot iterator, so `samples` can only be
# consumed once -- confirm that dataset_using_generator does not need to
# iterate it again (e.g. for a second epoch).
# NOTE(review): the indices come from tf.keras.backend.arange, so they are
# presumably TensorFlow values rather than Python ints -- confirm that
# str() of such a value gives the key format you expect.
samples = zip(tf.keras.backend.arange(len(x_train)), x_train, y_train)
10
11
12
13
14
15
16
17
18
19


def reader(sample):
    """Turn one raw sample into a ``(features, label)`` pair.

    Parameters
    ----------
    sample
        An indexable object whose item 0 is the sample identifier, item 1
        the data and item 2 the label.

    Returns
    -------
    tuple
        ``({"data": data, "key": key}, label)`` where ``key`` is the
        identifier converted with ``str()`` and encoded as UTF-8 bytes.
    """
    identifier, payload, target = sample[0], sample[1], sample[2]
    features = {
        "data": payload,
        # The key is stored as bytes: the textual form of the identifier.
        "key": str(identifier).encode("utf-8"),
    }
    return features, target


# Build the dataset by handing the raw samples and the per-sample `reader`
# to bob.learn.tensorflow's generator helper.
# NOTE(review): presumably this yields a tf.data.Dataset of
# ({"data", "key"}, label) elements -- confirm against the helper's docs.
dataset = dataset_using_generator(samples, reader)