From 046ec4f5381b2687abf94cb5eb5e6e578b4bbc62 Mon Sep 17 00:00:00 2001
From: Tiago Freitas Pereira <tiagofrepereira@gmail.com>
Date: Wed, 8 Aug 2018 12:05:47 +0200
Subject: [PATCH] Created click command to count elements in the tf-record

Fixed typo

Refactored the command line interface

Implemented test case

Changed prints to click.echo

[sphinx] Fixed warning
---
 bob/learn/tensorflow/dataset/tfrecords.py     | 91 +++++++++++++++++++
 .../tensorflow/script/db_to_tfrecords.py      | 39 ++++++++
 .../tensorflow/test/test_db_to_tfrecords.py   | 28 +++++-
 setup.py                                      |  3 +-
 4 files changed, 157 insertions(+), 4 deletions(-)

diff --git a/bob/learn/tensorflow/dataset/tfrecords.py b/bob/learn/tensorflow/dataset/tfrecords.py
index c04a8d50..1ce05f12 100644
--- a/bob/learn/tensorflow/dataset/tfrecords.py
+++ b/bob/learn/tensorflow/dataset/tfrecords.py
@@ -1,6 +1,9 @@
 from functools import partial
 import tensorflow as tf
 from . import append_image_augmentation, DEFAULT_FEATURE
+import os
+import logging
+logger = logging.getLogger(__name__)
 
 
 def example_parser(serialized_example, feature, data_shape, data_type):
@@ -389,3 +392,91 @@ def batch_data_and_labels_image_augmentation(tfrecord_filenames,
     features['key'] = key
 
     return features, labels
+
+
+def describe_tf_record(tf_record_path, shape, batch_size=1):
+    """
+    Count the samples and the distinct labels stored in a set of tf-records
+
+    Every file found in ``tf_record_path`` is read exactly once. Each record
+    must provide the ``data``, ``label`` and ``key`` features.
+
+    Parameters
+    ----------
+    tf_record_path: str
+      Base path containing your tf-record files
+
+    shape: tuple
+       Shape of one sample inside of the tf-record. The raw ``data`` bytes
+       are decoded as ``uint8``, so it has to account for every stored byte
+
+    batch_size: int
+      Number of records fetched per session call
+
+    Returns
+    -------
+    n_samples: int
+       Total number of samples
+
+    n_classes: int
+       Total number of classes
+    """
+
+    tf_records = [os.path.join(tf_record_path, f) for f in os.listdir(tf_record_path)]
+    filename_queue = tf.train.string_input_producer(tf_records, num_epochs=1, name="input")
+
+    feature = {'data': tf.FixedLenFeature([], tf.string),
+               'label': tf.FixedLenFeature([], tf.int64),
+               'key': tf.FixedLenFeature([], tf.string)
+               }
+
+    # Define a reader and read the next record
+    reader = tf.TFRecordReader()
+    _, serialized_example = reader.read(filename_queue)
+
+    # Decode the record read by the reader
+    features = tf.parse_single_example(serialized_example, features=feature)
+
+    # Convert the raw data bytes back to numbers and restore their shape
+    image = tf.reshape(tf.decode_raw(features['data'], tf.uint8), shape)
+
+    # Labels stay int64 and keys stay strings, as declared in `feature`
+    label = tf.cast(features['label'], tf.int64)
+    img_name = tf.cast(features['key'], tf.string)
+
+    # Batch the records. The last batch may hold fewer than `batch_size`
+    # records; without `allow_smaller_final_batch` it would be dropped and
+    # the trailing samples would never be counted
+    _, label_ph, _ = tf.train.batch(
+        [image, label, img_name], batch_size=batch_size, capacity=1000,
+        num_threads=5, allow_smaller_final_batch=True, name="shuffle_batch")
+
+    logger.info("Counting in %s", tf_record_path)
+    labels = set()
+    counter = 0
+
+    # The context manager closes the session even if the counting fails
+    with tf.Session() as session:
+        tf.local_variables_initializer().run(session=session)
+        tf.global_variables_initializer().run(session=session)
+
+        # Start the threads that fill the input queues
+        thread_pool = tf.train.Coordinator()
+        threads = tf.train.start_queue_runners(coord=thread_pool, sess=session)
+        try:
+            while True:
+                # Only the labels are needed to count samples and classes
+                batch_labels = session.run(label_ph)
+                counter += len(batch_labels)
+                labels.update(batch_labels)
+        except tf.errors.OutOfRangeError:
+            # Raised once the single epoch over the files is exhausted
+            pass
+        finally:
+            thread_pool.request_stop()
+
+        # Wait for the threads; this also re-raises any error they reported
+        thread_pool.join(threads)
+
+    return counter, len(labels)
+
diff --git a/bob/learn/tensorflow/script/db_to_tfrecords.py b/bob/learn/tensorflow/script/db_to_tfrecords.py
index b24f78aa..b5948a85 100644
--- a/bob/learn/tensorflow/script/db_to_tfrecords.py
+++ b/bob/learn/tensorflow/script/db_to_tfrecords.py
@@ -14,6 +14,9 @@ import tensorflow as tf
 from bob.io.base import create_directories_safe
 from bob.extension.scripts.click_helper import (
     verbosity_option, ConfigCommand, ResourceOption, log_parameters)
+import numpy
+from bob.learn.tensorflow.dataset.tfrecords import describe_tf_record
+
 
 logger = logging.getLogger(__name__)
 
@@ -230,3 +233,39 @@ def db_to_tfrecords(samples, reader, output, shuffle, allow_failures,
             pass
     click.echo("The total size of the tfrecords file will be roughly "
                "{} bytes".format(_bytes2human(total_size)))
+               
+
+@click.command()
+@click.argument(
+    'tf-record-path',
+    nargs=1)
+@click.argument(
+    'shape',
+    type=int,
+    nargs=-1,
+    required=True  # at least one dimension is needed to reshape the data
+    )
+@click.option(
+    '--batch-size',
+    help='Batch size',
+    show_default=True,
+    default=1000
+    )
+@verbosity_option(cls=ResourceOption)
+def describe_tfrecord(tf_record_path, shape, batch_size, **kwargs):
+    '''
+    Very often you have a tf-record file, or a set of them, and you have no idea
+    how many samples you have there.
+    Even worse, you have no idea how many classes you have.
+    
+    This click command will solve this thing for you by doing the following::
+    
+        $ %(prog)s <tf-record-path> 182 182 3
+    
+    '''
+    n_samples, n_labels = describe_tf_record(tf_record_path, shape, batch_size)
+    click.echo("#############################################")
+    click.echo("Number of samples {0}".format(n_samples))
+    click.echo("Number of labels {0}".format(n_labels))
+    click.echo("#############################################")
+
diff --git a/bob/learn/tensorflow/test/test_db_to_tfrecords.py b/bob/learn/tensorflow/test/test_db_to_tfrecords.py
index 5027fcfb..f990c06d 100644
--- a/bob/learn/tensorflow/test/test_db_to_tfrecords.py
+++ b/bob/learn/tensorflow/test/test_db_to_tfrecords.py
@@ -3,9 +3,10 @@ import shutil
 import pkg_resources
 import tempfile
 from click.testing import CliRunner
-
-from bob.learn.tensorflow.script.db_to_tfrecords import db_to_tfrecords
-
+import bob.io.base
+from bob.learn.tensorflow.script.db_to_tfrecords import db_to_tfrecords, describe_tf_record
+from bob.learn.tensorflow.utils import load_mnist, create_mnist_tfrecord
+ 
 regenerate_reference = False
 
 dummy_config = pkg_resources.resource_filename(
@@ -47,3 +48,24 @@ def test_db_to_tfrecords_size_estimate():
 
     finally:
         shutil.rmtree(test_dir)
+
+
+def test_tfrecord_counter():
+    # The data is saved as float but read back as uint8, hence 3136 bytes
+    shape = (3136,)
+    batch_size = 1000
+
+    # A private temporary directory always exists when the cleanup runs and
+    # does not clobber anything in the current working directory
+    test_dir = tempfile.mkdtemp()
+    tfrecord_train = os.path.join(test_dir, "train_mnist.tfrecord")
+    try:
+        train_data, train_labels, _, _ = load_mnist()
+        create_mnist_tfrecord(
+            tfrecord_train, train_data, train_labels, n_samples=6000)
+        n_samples, n_labels = describe_tf_record(test_dir, shape, batch_size)
+        assert n_samples == 6000
+        assert n_labels == 10
+    finally:
+        shutil.rmtree(test_dir)
+
diff --git a/setup.py b/setup.py
index 8a0a51a7..ec23c444 100644
--- a/setup.py
+++ b/setup.py
@@ -52,11 +52,12 @@ setup(
         'bob.learn.tensorflow.cli': [
             'compute_statistics = bob.learn.tensorflow.script.compute_statistics:compute_statistics',
             'db_to_tfrecords = bob.learn.tensorflow.script.db_to_tfrecords:db_to_tfrecords',
+            'describe_tfrecord = bob.learn.tensorflow.script.db_to_tfrecords:describe_tfrecord',
             'eval = bob.learn.tensorflow.script.eval:eval',
             'predict_bio = bob.learn.tensorflow.script.predict_bio:predict_bio',
             'train = bob.learn.tensorflow.script.train:train',
             'train_and_evaluate = bob.learn.tensorflow.script.train_and_evaluate:train_and_evaluate',
-            'style_transfer = bob.learn.tensorflow.script.style_transfer:style_transfer'
+            'style_transfer = bob.learn.tensorflow.script.style_transfer:style_transfer',
         ],
     },
 
-- 
GitLab