diff --git a/bob/learn/tensorflow/dataset/tfrecords.py b/bob/learn/tensorflow/dataset/tfrecords.py index c04a8d502a618a850271d64beb1dea26c5baac42..1ce05f122a3eb989c00a9a6f635606c76fcfcee0 100644 --- a/bob/learn/tensorflow/dataset/tfrecords.py +++ b/bob/learn/tensorflow/dataset/tfrecords.py @@ -1,6 +1,9 @@ from functools import partial import tensorflow as tf from . import append_image_augmentation, DEFAULT_FEATURE +import os +import logging +logger = logging.getLogger(__name__) def example_parser(serialized_example, feature, data_shape, data_type): @@ -389,3 +392,91 @@ def batch_data_and_labels_image_augmentation(tfrecord_filenames, features['key'] = key return features, labels + + +def describe_tf_record(tf_record_path, shape, batch_size=1): + """ + Describe the number of samples and the number of classes of a tf-record + + Parameters + ---------- + + tf_record_path: str + Base path containing your tf-record files + + shape: tuple + Shape inside of the tf-record + + batch_size: int + Well, batch size + + + Returns + ------- + + n_samples: int + Total number of samples + + n_classes: int + Total number of classes + + """ + + + tf_records = [os.path.join(tf_record_path, f) for f in os.listdir(tf_record_path)] + filename_queue = tf.train.string_input_producer(tf_records, num_epochs=1, name="input") + + feature = {'data': tf.FixedLenFeature([], tf.string), + 'label': tf.FixedLenFeature([], tf.int64), + 'key': tf.FixedLenFeature([], tf.string) + } + + # Define a reader and read the next record + reader = tf.TFRecordReader() + + _, serialized_example = reader.read(filename_queue) + + # Decode the record read by the reader + features = tf.parse_single_example(serialized_example, features=feature) + + # Convert the image data from string back to the numbers + image = tf.decode_raw(features['data'], tf.uint8) + + # Cast label data into int32 + label = tf.cast(features['label'], tf.int64) + img_name = tf.cast(features['key'], tf.string) + + # Reshape image data into the original shape + image = tf.reshape(image, shape) + + # Getting the batches in order + data_ph, label_ph, img_name_ph = tf.train.batch([image, label, img_name], batch_size=batch_size, + capacity=1000, num_threads=5, name="shuffle_batch") + + # Start the reading + session = tf.Session() + tf.local_variables_initializer().run(session=session) + tf.global_variables_initializer().run(session=session) + + # Preparing the batches + thread_pool = tf.train.Coordinator() + threads = tf.train.start_queue_runners(coord=thread_pool, sess=session) + + + logger.info("Counting in %s", tf_record_path) + labels = set() + counter = 0 + try: + while(True): + _, label, _ = session.run([data_ph, label_ph, img_name_ph]) + counter += len(label) + + for i in set(label): + labels.add(i) + + except tf.errors.OutOfRangeError: + pass + + thread_pool.request_stop() + return counter, len(labels) + diff --git a/bob/learn/tensorflow/script/db_to_tfrecords.py b/bob/learn/tensorflow/script/db_to_tfrecords.py index b24f78aa208d519f2fdb53c09ff55ac35ecdf5e0..b5948a85527388c3bc35a5c4bc17a87d86882f88 100644 --- a/bob/learn/tensorflow/script/db_to_tfrecords.py +++ b/bob/learn/tensorflow/script/db_to_tfrecords.py @@ -14,6 +14,9 @@ import tensorflow as tf from bob.io.base import create_directories_safe from bob.extension.scripts.click_helper import ( verbosity_option, ConfigCommand, ResourceOption, log_parameters) +import numpy +from bob.learn.tensorflow.dataset.tfrecords import describe_tf_record + logger = logging.getLogger(__name__) @@ -230,3 +233,39 @@ def db_to_tfrecords(samples, reader, output, shuffle, allow_failures, pass click.echo("The total size of the tfrecords file will be roughly " "{} bytes".format(_bytes2human(total_size))) + + +@click.command() +@click.argument( + 'tf-record-path', + nargs=1) +@click.argument( + 'shape', + type=int, + nargs=-1 + ) +@click.option( + '--batch-size', + help='Batch size', + show_default=True, + required=True, + default=1000 + ) +@verbosity_option(cls=ResourceOption) +def describe_tfrecord(tf_record_path, shape, batch_size, **kwargs): + ''' + Very often you have a tf-record file, or a set of them, and you have no idea + how many samples you have there. + Even worse, you have no idea how many classes you have. + + This click command will solve this thing for you by doing the following:: + + $ %(prog)s <tf-record-path> 182 182 3 + + ''' + n_samples, n_labels = describe_tf_record(tf_record_path, shape, batch_size) + click.echo("#############################################") + click.echo("Number of samples {0}".format(n_samples)) + click.echo("Number of labels {0}".format(n_labels)) + click.echo("#############################################") + diff --git a/bob/learn/tensorflow/test/test_db_to_tfrecords.py b/bob/learn/tensorflow/test/test_db_to_tfrecords.py index 5027fcfb5d107f7fa9fb9faf5628618947f34ef0..f990c06deff87535c0ca3ad8c64aa8f7114fd540 100644 --- a/bob/learn/tensorflow/test/test_db_to_tfrecords.py +++ b/bob/learn/tensorflow/test/test_db_to_tfrecords.py @@ -3,9 +3,10 @@ import shutil import pkg_resources import tempfile from click.testing import CliRunner - -from bob.learn.tensorflow.script.db_to_tfrecords import db_to_tfrecords - +import bob.io.base +from bob.learn.tensorflow.script.db_to_tfrecords import db_to_tfrecords, describe_tf_record +from bob.learn.tensorflow.utils import load_mnist, create_mnist_tfrecord + regenerate_reference = False dummy_config = pkg_resources.resource_filename( @@ -47,3 +48,24 @@ def test_db_to_tfrecords_size_estimate(): finally: shutil.rmtree(test_dir) + + +def test_tfrecord_counter(): + tfrecord_train = "./tf-train-test/train_mnist.tfrecord" + shape = (3136,) # I'm saving the thing as float + batch_size = 1000 + + try: + train_data, train_labels, validation_data, validation_labels = load_mnist() + bob.io.base.create_directories_safe(os.path.dirname(tfrecord_train)) + create_mnist_tfrecord( + tfrecord_train, train_data, train_labels, n_samples=6000) + + n_samples, n_labels = describe_tf_record(os.path.dirname(tfrecord_train), shape, batch_size) + + assert n_samples == 6000 + assert n_labels == 10 + + finally: + shutil.rmtree(os.path.dirname(tfrecord_train)) + diff --git a/setup.py b/setup.py index 8a0a51a7ac58ba078cd3894021d099a733bda25f..ec23c444958d1c06b99dffb83ebbf52f4249e6d7 100644 --- a/setup.py +++ b/setup.py @@ -52,11 +52,12 @@ setup( 'bob.learn.tensorflow.cli': [ 'compute_statistics = bob.learn.tensorflow.script.compute_statistics:compute_statistics', 'db_to_tfrecords = bob.learn.tensorflow.script.db_to_tfrecords:db_to_tfrecords', + 'describe_tfrecord = bob.learn.tensorflow.script.db_to_tfrecords:describe_tfrecord', 'eval = bob.learn.tensorflow.script.eval:eval', 'predict_bio = bob.learn.tensorflow.script.predict_bio:predict_bio', 'train = bob.learn.tensorflow.script.train:train', 'train_and_evaluate = bob.learn.tensorflow.script.train_and_evaluate:train_and_evaluate', - 'style_transfer = bob.learn.tensorflow.script.style_transfer:style_transfer' + 'style_transfer = bob.learn.tensorflow.script.style_transfer:style_transfer', ], },