Skip to content
Snippets Groups Projects

Created click command to count elements in the tf-record

Merged Tiago de Freitas Pereira requested to merge tf-record-counter into master
4 files
+ 152
4
Compare changes
  • Side-by-side
  • Inline
Files
4
from functools import partial
from functools import partial
import tensorflow as tf
import tensorflow as tf
from . import append_image_augmentation, DEFAULT_FEATURE
from . import append_image_augmentation, DEFAULT_FEATURE
 
import os
 
import logging
 
logger = logging.getLogger(__name__)
def example_parser(serialized_example, feature, data_shape, data_type):
def example_parser(serialized_example, feature, data_shape, data_type):
@@ -389,3 +392,86 @@ def batch_data_and_labels_image_augmentation(tfrecord_filenames,
@@ -389,3 +392,86 @@ def batch_data_and_labels_image_augmentation(tfrecord_filenames,
features['key'] = key
features['key'] = key
return features, labels
return features, labels
 
 
 
def describe_tf_record(tf_record_path, shape, batch_size=1):
 
"""
 
Describe the number of samples and the number of classes of a tf-record
 
 
Parameters
 
----------
 
 
tf_record_path: str
 
Base path containing your tf-record files
 
 
shape: tuple
 
Shape inside of the tf-record
 
 
batch_size: int
 
Well, batch size
 
 
 
Returns
 
-------
 
number of samples and the number of classes in the TFRecord (or a set of them)
 
 
"""
 
 
 
tf_records = [os.path.join(tf_record_path, f) for f in os.listdir(tf_record_path)]
 
filename_queue = tf.train.string_input_producer(tf_records, num_epochs=1, name="input")
 
 
feature = {'data': tf.FixedLenFeature([], tf.string),
 
'label': tf.FixedLenFeature([], tf.int64),
 
'key': tf.FixedLenFeature([], tf.string)
 
}
 
 
# Define a reader and read the next record
 
reader = tf.TFRecordReader()
 
 
_, serialized_example = reader.read(filename_queue)
 
 
# Decode the record read by the reader
 
features = tf.parse_single_example(serialized_example, features=feature)
 
 
# Convert the image data from string back to the numbers
 
image = tf.decode_raw(features['data'], tf.uint8)
 
 
# Cast label data into int32
 
label = tf.cast(features['label'], tf.int64)
 
img_name = tf.cast(features['key'], tf.string)
 
 
# Reshape image data into the original shape
 
image = tf.reshape(image, shape)
 
 
# Getting the batches in order
 
data_ph, label_ph, img_name_ph = tf.train.batch([image, label, img_name], batch_size=batch_size,
 
capacity=1000, num_threads=5, name="shuffle_batch")
 
 
# Start the reading
 
session = tf.Session()
 
tf.local_variables_initializer().run(session=session)
 
tf.global_variables_initializer().run(session=session)
 
 
# Preparing the batches
 
thread_pool = tf.train.Coordinator()
 
threads = tf.train.start_queue_runners(coord=thread_pool, sess=session)
 
 
 
logger.info("Counting in %s", tf_record_path)
 
labels = set()
 
counter = 0
 
try:
 
while(True):
 
_, label, _ = session.run([data_ph, label_ph, img_name_ph])
 
counter += len(label)
 
 
for i in set(label):
 
labels.add(i)
 
 
except tf.errors.OutOfRangeError:
 
pass
 
 
thread_pool.request_stop()
 
return counter, len(labels)
 
Loading