diff --git a/bob/learn/tensorflow/script/db_to_tfrecords.py b/bob/learn/tensorflow/script/db_to_tfrecords.py index fbef7a0013f1f6c62e2ca12b01050d771f103207..7460f7c664ec2726bdeae2bec5ca4bcfc0e11d4a 100644 --- a/bob/learn/tensorflow/script/db_to_tfrecords.py +++ b/bob/learn/tensorflow/script/db_to_tfrecords.py @@ -268,3 +268,70 @@ def datasets_to_tfrecords(dataset, output, force, **kwargs): os.remove(json_output) raise click.echo("Successfully wrote all files.") + + +@click.command(entry_point_group="bob.learn.tensorflow.config", cls=ConfigCommand) +@click.option( + "--dataset", + required=True, + cls=ResourceOption, + entry_point_group="bob.learn.tensorflow.dataset", + help="A tf.data.Dataset to be used.", +) +@click.option( + "--output", "-o", required=True, cls=ResourceOption, help="Name of the output file." +) +@click.option( + "--mean", + is_flag=True, + cls=ResourceOption, + help="If provided, the mean of data and labels will be saved in the hdf5 " + 'file as well. You can access them in the "mean" groups.', +) +@verbosity_option(cls=ResourceOption) +def dataset_to_hdf5(dataset, output, mean, **kwargs): + """Saves a tensorflow dataset into an HDF5 file + + It is assumed that the dataset returns a tuple of (data, label, key) and + the dataset is not batched. + """ + log_parameters(logger) + + data, label, key = dataset.make_one_shot_iterator().get_next() + + sess = tf.Session() + + extension = ".hdf5" + + if not output.endswith(extension): + output += extension + + create_directories_safe(os.path.dirname(output)) + + sample_count = 0 + data_mean = 0.0 + label_mean = 0.0 + + with HDF5File(output, "w") as f: + while True: + try: + d, l, k = sess.run([data, label, key]) + group = "/{}".format(sample_count) + f.create_group(group) + f.cd(group) + f["data"] = d + f["label"] = l + f["key"] = k + sample_count += 1 + if mean: + data_mean += (d - data_mean) / sample_count + label_mean += (l - label_mean) / sample_count + except tf.errors.OutOfRangeError: + break + if mean: + f.create_group("/mean") + f.cd("/mean") + f["data_mean"] = data_mean + f["label_mean"] = label_mean + + click.echo(f"Wrote {sample_count} samples into the hdf5 file.") diff --git a/setup.py b/setup.py index 559961b297969635c1a7c8109598a0cd94cad563..2b1997fbdcb0bf7d8bbe19fa0280a3f20d01f0e1 100644 --- a/setup.py +++ b/setup.py @@ -53,6 +53,7 @@ setup( 'bob.learn.tensorflow.cli': [ 'cache-dataset = bob.learn.tensorflow.script.cache_dataset:cache_dataset', 'compute-statistics = bob.learn.tensorflow.script.compute_statistics:compute_statistics', + 'dataset-to-hdf5 = bob.learn.tensorflow.script.db_to_tfrecords:dataset_to_hdf5', 'datasets-to-tfrecords = bob.learn.tensorflow.script.db_to_tfrecords:datasets_to_tfrecords', 'db-to-tfrecords = bob.learn.tensorflow.script.db_to_tfrecords:db_to_tfrecords', 'describe-tfrecord = bob.learn.tensorflow.script.db_to_tfrecords:describe_tfrecord',