Commit 9921e122 authored by Amir MOHAMMADI

Add dataset_to_tfrecord and dataset_from_tfrecord

parent 09edcdf7
@@ -4,126 +4,90 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import logging
import os
import random
import tempfile
import os
import sys
import logging
import click
import tensorflow as tf
from bob.io.base import create_directories_safe
from bob.io.base import create_directories_safe, HDF5File
from bob.extension.scripts.click_helper import (
verbosity_option, ConfigCommand, ResourceOption, log_parameters)
import numpy
from bob.learn.tensorflow.dataset.tfrecords import describe_tf_record
verbosity_option,
ConfigCommand,
ResourceOption,
log_parameters,
)
from bob.learn.tensorflow.dataset.tfrecords import (
describe_tf_record,
write_a_sample,
normalize_tfrecords_path,
tfrecord_name_and_json_name,
dataset_to_tfrecord,
)
from bob.learn.tensorflow.utils import bytes2human
logger = logging.getLogger(__name__)
def bytes_feature(value):
return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
def int64_feature(value):
return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
def write_a_sample(writer, data, label, key, feature=None,
size_estimate=False):
if feature is None:
feature = {
'data': bytes_feature(data.tostring()),
'label': int64_feature(label),
'key': bytes_feature(key)
}
example = tf.train.Example(features=tf.train.Features(feature=feature))
example = example.SerializeToString()
if not size_estimate:
writer.write(example)
return sys.getsizeof(example)
def _bytes2human(n, format='%(value).1f %(symbol)s', symbols='customary'):
"""Convert n bytes into a human readable string based on format.
From: https://code.activestate.com/recipes/578019-bytes-to-human-human-to-
bytes-converter/
Author: Giampaolo Rodola' <g.rodola [AT] gmail [DOT] com>
License: MIT
symbols can be either "customary", "customary_ext", "iec" or "iec_ext",
see: http://goo.gl/kTQMs
"""
SYMBOLS = {
'customary': ('B', 'K', 'M', 'G', 'T', 'P', 'E', 'Z', 'Y'),
'customary_ext': ('byte', 'kilo', 'mega', 'giga', 'tera', 'peta',
'exa', 'zetta', 'iotta'),
'iec': ('Bi', 'Ki', 'Mi', 'Gi', 'Ti', 'Pi', 'Ei', 'Zi', 'Yi'),
'iec_ext': ('byte', 'kibi', 'mebi', 'gibi', 'tebi', 'pebi', 'exbi',
'zebi', 'yobi'),
}
n = int(n)
if n < 0:
raise ValueError("n < 0")
symbols = SYMBOLS[symbols]
prefix = {}
for i, s in enumerate(symbols[1:]):
prefix[s] = 1 << (i + 1) * 10
for symbol in reversed(symbols[1:]):
if n >= prefix[symbol]:
value = float(n) / prefix[symbol]
return format % locals()
return format % dict(symbol=symbols[0], value=n)
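# The helper above was moved to ``bob.learn.tensorflow.utils`` and is now
# imported at the top of this file as ``bytes2human``. A minimal usage sketch,
# assuming its semantics match the removed ``_bytes2human``:
from bob.learn.tensorflow.utils import bytes2human

print(bytes2human(10 ** 6))  # prints '976.6 K': 10**6 / 1024 with customary symbols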
@click.command(
entry_point_group='bob.learn.tensorflow.config', cls=ConfigCommand)
@click.command(entry_point_group="bob.learn.tensorflow.config", cls=ConfigCommand)
@click.option(
'--samples',
"--samples",
required=True,
cls=ResourceOption,
help='A list of all samples that you want to write in the '
'tfrecords file. Whatever is inside this list is passed to '
'the reader.')
help="A list of all samples that you want to write in the "
"tfrecords file. Whatever is inside this list is passed to "
"the reader.",
)
@click.option(
'--reader',
"--reader",
required=True,
cls=ResourceOption,
help='a function with the signature of ``data, label, key = '
'reader(sample)`` which takes a sample and returns the '
'loaded data, the label of the data, and a key which is '
'unique for every sample.')
help="a function with the signature of ``data, label, key = "
"reader(sample)`` which takes a sample and returns the "
"loaded data, the label of the data, and a key which is "
"unique for every sample.",
)
@click.option(
'--output',
'-o',
required=True,
cls=ResourceOption,
help='Name of the output file.')
"--output", "-o", required=True, cls=ResourceOption, help="Name of the output file."
)
@click.option(
'--shuffle',
"--shuffle",
is_flag=True,
cls=ResourceOption,
help='If provided, it will shuffle the samples.')
help="If provided, it will shuffle the samples.",
)
@click.option(
'--allow-failures',
"--allow-failures",
is_flag=True,
cls=ResourceOption,
help='If provided, the samples which fail to load are ignored.')
help="If provided, the samples which fail to load are ignored.",
)
@click.option(
'--multiple-samples',
"--multiple-samples",
is_flag=True,
cls=ResourceOption,
help='If provided, it means that the data provided by reader contains '
'multiple samples with same label and path.')
help="If provided, it means that the data provided by reader contains "
"multiple samples with same label and path.",
)
@click.option(
'--size-estimate',
"--size-estimate",
is_flag=True,
cls=ResourceOption,
help='If given, will print the estimated file size instead of creating '
'the final tfrecord file.')
help="If given, will print the estimated file size instead of creating "
"the final tfrecord file.",
)
@verbosity_option(cls=ResourceOption)
def db_to_tfrecords(samples, reader, output, shuffle, allow_failures,
multiple_samples, size_estimate, **kwargs):
def db_to_tfrecords(
samples,
reader,
output,
shuffle,
allow_failures,
multiple_samples,
size_estimate,
**kwargs,
):
"""Converts Bio and PAD datasets to TFRecords file formats.
The best way to use this script is to send it to the io-big queue if you
@@ -173,14 +137,13 @@ def db_to_tfrecords(samples, reader, output, shuffle, allow_failures,
key = biofile.path
return (data, label, key)
"""
log_parameters(logger, ignore=('samples', ))
log_parameters(logger, ignore=("samples",))
logger.debug("len(samples): %d", len(samples))
if size_estimate:
output = tempfile.NamedTemporaryFile(suffix='.tfrecords').name
output = tempfile.NamedTemporaryFile(suffix=".tfrecords").name
if not output.endswith(".tfrecords"):
output += ".tfrecords"
output = normalize_tfrecords_path(output)
if not size_estimate:
logger.info("Writing samples to `{}'".format(output))
@@ -196,7 +159,7 @@ def db_to_tfrecords(samples, reader, output, shuffle, allow_failures,
logger.info("Shuffling the samples before writing ...")
random.shuffle(samples)
for i, sample in enumerate(samples):
logger.info('Processing file %d out of %d', i + 1, n_samples)
logger.info("Processing file %d out of %d", i + 1, n_samples)
data, label, key = reader(sample)
@@ -205,55 +168,43 @@ def db_to_tfrecords(samples, reader, output, shuffle, allow_failures,
logger.debug("... Skipping `{0}`.".format(sample))
continue
else:
raise RuntimeError(
"Reading failed for `{0}`".format(sample))
raise RuntimeError("Reading failed for `{0}`".format(sample))
if multiple_samples:
for sample in data:
total_size += write_a_sample(
writer,
sample,
label,
key,
size_estimate=size_estimate)
writer, sample, label, key, size_estimate=size_estimate
)
sample_count += 1
else:
total_size += write_a_sample(
writer, data, label, key, size_estimate=size_estimate)
writer, data, label, key, size_estimate=size_estimate
)
sample_count += 1
if not size_estimate:
click.echo(
"Wrote {} samples into the tfrecords file.".format(sample_count))
click.echo("Wrote {} samples into the tfrecords file.".format(sample_count))
else:
# delete the empty tfrecords file
try:
os.remove(output)
except Exception:
pass
click.echo("The total size of the tfrecords file will be roughly "
"{} bytes".format(_bytes2human(total_size)))
click.echo(
"The total size of the tfrecords file will be roughly "
"{} bytes".format(bytes2human(total_size))
)
@click.command()
@click.argument(
'tf-record-path',
nargs=1)
@click.argument(
'shape',
type=int,
nargs=-1
)
@click.argument("tf-record-path", nargs=1)
@click.argument("shape", type=int, nargs=-1)
@click.option(
'--batch-size',
help='Batch size',
show_default=True,
required=True,
default=1000
"--batch-size", help="Batch size", show_default=True, required=True, default=1000
)
@verbosity_option(cls=ResourceOption)
def describe_tfrecord(tf_record_path, shape, batch_size, **kwargs):
'''
"""
Very often you have a tf-record file, or a set of them, and you have no
idea how many samples you have there. Even worse, you have no idea how many
classes you have.
@@ -262,9 +213,58 @@ def describe_tfrecord(tf_record_path, shape, batch_size, **kwargs):
$ %(prog)s <tf-record-path> 182 182 3
'''
"""
n_samples, n_labels = describe_tf_record(tf_record_path, shape, batch_size)
click.echo("#############################################")
click.echo("Number of samples {0}".format(n_samples))
click.echo("Number of labels {0}".format(n_labels))
click.echo("#############################################")
@click.command(entry_point_group="bob.learn.tensorflow.config", cls=ConfigCommand)
@click.option(
"--dataset",
required=True,
cls=ResourceOption,
entry_point_group="bob.learn.tensorflow.dataset",
help="A tf.data.Dataset to be used.",
)
@click.option(
"--output", "-o", required=True, cls=ResourceOption, help="Name of the output file."
)
@click.option(
"--force",
"-f",
is_flag=True,
cls=ResourceOption,
help="Whether to overwrite existing files.",
)
@verbosity_option(cls=ResourceOption)
def datasets_to_tfrecords(dataset, output, force, **kwargs):
"""Converts tensorflow datasets into TFRecords.
Each invocation converts one ``dataset`` and writes it into one ``output``.
To convert several datasets, define ``datasets`` and ``outputs`` lists in your
config file and select one pair per run.
You can convert the written TFRecord files back to datasets using
:any:`bob.learn.tensorflow.dataset.tfrecords.dataset_from_tfrecord`.
To use this script with SGE, change your dataset and output based on the SGE_TASK_ID
environment variable in your config file.
"""
log_parameters(logger)
output, json_output = tfrecord_name_and_json_name(output)
if not force and os.path.isfile(output):
click.echo("Output file already exists: {}".format(output))
return
click.echo("Writing tfrecod to: {}".format(output))
with tf.Session() as sess:
os.makedirs(os.path.dirname(output), exist_ok=True)
try:
sess.run(dataset_to_tfrecord(dataset, output))
except Exception:
click.echo("Something failed. Deleting unfinished files.")
os.remove(output)
os.remove(json_output)
raise
click.echo("Successfully wrote all files.")
from bob.bio.base.test.dummy.database import database
from bob.bio.base.utils import read_original_data
from bob.learn.tensorflow.dataset.generator import dataset_using_generator
groups = ['dev']
groups = ["dev"]
samples = database.all_files(groups=groups)
@@ -15,8 +16,13 @@ def file_to_label(f):
def reader(biofile):
data = read_original_data(biofile, database.original_directory,
database.original_extension)
data = read_original_data(
biofile, database.original_directory, database.original_extension
)
label = file_to_label(biofile)
key = str(biofile.path).encode("utf-8")
return (data, label, key)
dataset = dataset_using_generator(samples, reader)
datasets = [dataset]
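# Hypothetical continuation of such a config for the SGE pattern mentioned in
# the ``datasets_to_tfrecords`` docstring: keep equal-length ``datasets`` and
# ``outputs`` lists and pick one pair per task (names here are illustrative).
import os

outputs = ["./tfrecords/dev"]  # one hypothetical output path per dataset
task_id = int(os.environ.get("SGE_TASK_ID", "1")) - 1  # SGE task ids are 1-based
dataset = datasets[task_id]
output = outputs[task_id]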
@@ -2,11 +2,16 @@ import os
import shutil
import pkg_resources
import tempfile
import tensorflow as tf
import numpy as np
from click.testing import CliRunner
from bob.io.base import create_directories_safe
from bob.learn.tensorflow.script.db_to_tfrecords import (
db_to_tfrecords, describe_tf_record)
db_to_tfrecords, describe_tf_record, datasets_to_tfrecords)
from bob.learn.tensorflow.utils import load_mnist, create_mnist_tfrecord
from bob.extension.scripts.click_helper import assert_click_runner_result
from bob.extension.config import load
from bob.learn.tensorflow.dataset.tfrecords import dataset_from_tfrecord
regenerate_reference = False
@@ -14,6 +19,31 @@ dummy_config = pkg_resources.resource_filename(
'bob.learn.tensorflow', 'test/data/db_to_tfrecords_config.py')
def compare_datasets(ds1, ds2, sess=None):
    """Returns True if the two datasets yield equal values in the same order."""
if tf.executing_eagerly():
for values1, values2 in zip(ds1, ds2):
values1 = tf.contrib.framework.nest.flatten(values1)
values2 = tf.contrib.framework.nest.flatten(values2)
for v1, v2 in zip(values1, values2):
if not tf.reduce_all(tf.math.equal(v1, v2)):
return False
else:
ds1 = ds1.make_one_shot_iterator().get_next()
ds2 = ds2.make_one_shot_iterator().get_next()
while True:
try:
values1, values2 = sess.run([ds1, ds2])
except tf.errors.OutOfRangeError:
break
values1 = tf.contrib.framework.nest.flatten(values1)
values2 = tf.contrib.framework.nest.flatten(values2)
for v1, v2 in zip(values1, values2):
v1, v2 = np.asarray(v1), np.asarray(v2)
if not np.all(v1 == v2):
return False
return True
def test_db_to_tfrecords():
test_dir = tempfile.mkdtemp(prefix='bobtest_')
output_path = os.path.join(test_dir, 'dev.tfrecords')
@@ -71,3 +101,19 @@ def test_tfrecord_counter():
finally:
shutil.rmtree(os.path.dirname(tfrecord_train))
def test_datasets_to_tfrecords():
runner = CliRunner()
with runner.isolated_filesystem():
output_path = './test'
args = (dummy_config, '--output', output_path)
result = runner.invoke(
datasets_to_tfrecords, args=args, standalone_mode=False)
assert_click_runner_result(result)
# read back the tfrecord
with tf.Session() as sess:
dataset2 = dataset_from_tfrecord(output_path)
dataset1 = load(
[dummy_config], attribute_name='dataset', entry_point_group='bob')
assert compare_datasets(dataset1, dataset2, sess)