Skip to content
Snippets Groups Projects
Commit d47eae91 authored by Amir MOHAMMADI's avatar Amir MOHAMMADI
Browse files

Add dataset_to_hdf5 command

parent f376738a
No related branches found
No related tags found
No related merge requests found
...@@ -268,3 +268,70 @@ def datasets_to_tfrecords(dataset, output, force, **kwargs): ...@@ -268,3 +268,70 @@ def datasets_to_tfrecords(dataset, output, force, **kwargs):
os.remove(json_output) os.remove(json_output)
raise raise
click.echo("Successfully wrote all files.") click.echo("Successfully wrote all files.")
@click.command(entry_point_group="bob.learn.tensorflow.config", cls=ConfigCommand)
@click.option(
"--dataset",
required=True,
cls=ResourceOption,
entry_point_group="bob.learn.tensorflow.dataset",
help="A tf.data.Dataset to be used.",
)
@click.option(
"--output", "-o", required=True, cls=ResourceOption, help="Name of the output file."
)
@click.option(
"--mean",
is_flag=True,
cls=ResourceOption,
help="If provided, the mean of data and labels will be saved in the hdf5 "
'file as well. You can access them in the "mean" groups.',
)
@verbosity_option(cls=ResourceOption)
def dataset_to_hdf5(dataset, output, mean, **kwargs):
"""Saves a tensorflow dataset into an HDF5 file
It is assumed that the dataset returns a tuple of (data, label, key) and
the dataset is not batched.
"""
log_parameters(logger)
data, label, key = dataset.make_one_shot_iterator().get_next()
sess = tf.Session()
extension = ".hdf5"
if not output.endswith(extension):
output += extension
create_directories_safe(os.path.dirname(output))
sample_count = 0
data_mean = 0.0
label_mean = 0.0
with HDF5File(output, "w") as f:
while True:
try:
d, l, k = sess.run([data, label, key])
group = "/{}".format(sample_count)
f.create_group(group)
f.cd(group)
f["data"] = d
f["label"] = l
f["key"] = k
sample_count += 1
if mean:
data_mean += (d - data_mean) / sample_count
label_mean += (l - label_mean) / sample_count
except tf.errors.OutOfRangeError:
break
if mean:
f.create_group("/mean")
f.cd("/mean")
f["data_mean"] = data_mean
f["label_mean"] = label_mean
click.echo(f"Wrote {sample_count} samples into the hdf5 file.")
...@@ -53,6 +53,7 @@ setup( ...@@ -53,6 +53,7 @@ setup(
'bob.learn.tensorflow.cli': [ 'bob.learn.tensorflow.cli': [
'cache-dataset = bob.learn.tensorflow.script.cache_dataset:cache_dataset', 'cache-dataset = bob.learn.tensorflow.script.cache_dataset:cache_dataset',
'compute-statistics = bob.learn.tensorflow.script.compute_statistics:compute_statistics', 'compute-statistics = bob.learn.tensorflow.script.compute_statistics:compute_statistics',
'dataset-to-hdf5 = bob.learn.tensorflow.script.db_to_tfrecords:dataset_to_hdf5',
'datasets-to-tfrecords = bob.learn.tensorflow.script.db_to_tfrecords:datasets_to_tfrecords', 'datasets-to-tfrecords = bob.learn.tensorflow.script.db_to_tfrecords:datasets_to_tfrecords',
'db-to-tfrecords = bob.learn.tensorflow.script.db_to_tfrecords:db_to_tfrecords', 'db-to-tfrecords = bob.learn.tensorflow.script.db_to_tfrecords:db_to_tfrecords',
'describe-tfrecord = bob.learn.tensorflow.script.db_to_tfrecords:describe_tfrecord', 'describe-tfrecord = bob.learn.tensorflow.script.db_to_tfrecords:describe_tfrecord',
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment