Skip to content
Snippets Groups Projects

Created click command to count elements in the tf-record

Merged Tiago de Freitas Pereira requested to merge tf-record-counter into master
All threads resolved!
Files
2
@@ -14,6 +14,8 @@ import tensorflow as tf
@@ -14,6 +14,8 @@ import tensorflow as tf
from bob.io.base import create_directories_safe
from bob.io.base import create_directories_safe
from bob.extension.scripts.click_helper import (
from bob.extension.scripts.click_helper import (
verbosity_option, ConfigCommand, ResourceOption, log_parameters)
verbosity_option, ConfigCommand, ResourceOption, log_parameters)
 
import numpy
 
logger = logging.getLogger(__name__)
logger = logging.getLogger(__name__)
@@ -230,3 +232,91 @@ def db_to_tfrecords(samples, reader, output, shuffle, allow_failures,
@@ -230,3 +232,91 @@ def db_to_tfrecords(samples, reader, output, shuffle, allow_failures,
pass
pass
click.echo("The total size of the tfrecords file will be roughly "
click.echo("The total size of the tfrecords file will be roughly "
"{} bytes".format(_bytes2human(total_size)))
"{} bytes".format(_bytes2human(total_size)))
 
 
 
@click.command(
    entry_point_group='bob.learn.tensorflow.config', cls=ConfigCommand)
@click.argument(
    'tf-record-path')
@click.option(
    '--shape',
    help='Shape of the data in the tf-record',
    show_default=True,
    multiple=True,
    required=True,
    default=[182, 182, 3]
)
@click.option(
    '--batch-size',
    help='Batch size',
    show_default=True,
    required=True,
    default=1
)
@verbosity_option(cls=ResourceOption)
def describe_tfrecord(tf_record_path, shape, batch_size, **kwargs):
    '''
    Very often you have a tf-record file, or a set of them, and you have no idea
    how many samples you have there.
    Even worse, you have no idea how many classes you have.

    This click command will solve this thing for you by doing the following::

        $ %(prog)s <tf-record-path> --shape 182 182 3

    '''
    # Every entry inside ``tf_record_path`` is assumed to be a tf-record file.
    tf_records = [os.path.join(tf_record_path, f)
                  for f in os.listdir(tf_record_path)]

    # ``num_epochs=1`` makes the queue raise ``OutOfRangeError`` after a
    # single pass over all records, which is what ends the counting loop.
    filename_queue = tf.train.string_input_producer(
        tf_records, num_epochs=1, name="input")

    # Serialized-example schema: raw image bytes, an int64 label and a
    # string key identifying the sample.
    feature = {'data': tf.FixedLenFeature([], tf.string),
               'label': tf.FixedLenFeature([], tf.int64),
               'key': tf.FixedLenFeature([], tf.string)
               }

    # Define a reader and read the next record
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)

    # Decode the record read by the reader
    features = tf.parse_single_example(serialized_example, features=feature)

    # Convert the image data from string back to the numbers
    image = tf.decode_raw(features['data'], tf.uint8)

    # Cast label data into int64 (original comment wrongly said int32)
    label = tf.cast(features['label'], tf.int64)
    img_name = tf.cast(features['key'], tf.string)

    # Reshape image data into the original shape
    image = tf.reshape(image, shape)

    # Getting the batches in order
    data_ph, label_ph, img_name_ph = tf.train.batch(
        [image, label, img_name], batch_size=batch_size,
        capacity=1000, num_threads=5, name="shuffle_batch")

    # Start the reading.  ``num_epochs`` is backed by a local variable, so
    # local variables must be initialized as well.
    session = tf.Session()
    tf.local_variables_initializer().run(session=session)
    tf.global_variables_initializer().run(session=session)

    # Preparing the batches
    thread_pool = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(coord=thread_pool, sess=session)

    labels = []
    logger.info("Counting in %s", tf_record_path)
    try:
        while True:
            # FIX: use a distinct name for the fetched batch instead of
            # shadowing the ``label`` graph tensor defined above.
            _, label_batch, _ = session.run([data_ph, label_ph, img_name_ph])
            labels.extend(label_batch.tolist())
    except tf.errors.OutOfRangeError:
        # Single-epoch queue exhausted: every record has been counted.
        pass
    finally:
        # FIX: the original only called ``request_stop()`` — the queue-runner
        # threads were never joined and the session never closed, leaking
        # both on every invocation (and on any unexpected error).
        thread_pool.request_stop()
        thread_pool.join(threads)
        session.close()

    print("Total samples: {0}".format(len(labels)))
    labels = set(labels)
    print("Total labels: {0}".format(len(labels)))
 
Loading