From 3316ab71223df01970377235a29f5c7c027391fe Mon Sep 17 00:00:00 2001
From: Samuel Gaist <samuel.gaist@idiap.ch>
Date: Mon, 23 Oct 2017 12:30:10 +0200
Subject: [PATCH] Removed code duplication between TFRecord and TFRecordImage.

---
 bob/learn/tensorflow/datashuffler/TFRecord.py | 44 +++++----
 .../tensorflow/datashuffler/TFRecordImage.py  | 95 +++++--------------
 2 files changed, 52 insertions(+), 87 deletions(-)

diff --git a/bob/learn/tensorflow/datashuffler/TFRecord.py b/bob/learn/tensorflow/datashuffler/TFRecord.py
index f63a76a3..a25f15d7 100755
--- a/bob/learn/tensorflow/datashuffler/TFRecord.py
+++ b/bob/learn/tensorflow/datashuffler/TFRecord.py
@@ -11,7 +11,7 @@ import numpy
 class TFRecord(object):
     """
     Generic datashuffler that wraps the batching using tfrecords.
-    
+
     **Parameters**
 
       filename_queue: Tensorflow producer
@@ -19,7 +19,7 @@ class TFRecord(object):
       input_dtype: Type of the raw data saved in the tf record
       batch_size: Size of the batch
       seed: Seed
-      prefetch_capacity: Capacity of the bucket for prefetching 
+      prefetch_capacity: Capacity of the bucket for prefetching
       prefetch_threads: Number of threads in the prefetching
     """
     def __init__(self,filename_queue,
@@ -36,7 +36,7 @@ class TFRecord(object):
 
         self.input_dtype = input_dtype
 
-        # TODO: Check if the bacth size is higher than the input data
+        # TODO: Check if the batch size is higher than the input data
         self.batch_size = batch_size
 
         # Preparing the inputs
@@ -47,14 +47,14 @@ class TFRecord(object):
         self.prefetch = True
         self.prefetch_capacity = prefetch_capacity
         self.prefetch_threads = prefetch_threads
-        
+
         self.data_ph = None
         self.label_ph = None
 
+
     def __call__(self, element, from_queue=False):
         """
         Return the necessary placeholder
-        
         """
 
         if not element in ["data", "label"]:
@@ -70,35 +70,47 @@ class TFRecord(object):
             return self.label_ph
 
 
-    def create_placeholders(self):
+    def __load_features(self):
+        """
+        Read one record from ``self.filename_queue``, decode and reshape it, and return the ``(image, label)`` tensors.
+        """
 
         feature = {'train/data': tf.FixedLenFeature([], tf.string),
                    'train/label': tf.FixedLenFeature([], tf.int64)}
 
         # Define a reader and read the next record
         reader = tf.TFRecordReader()
-        
+
         _, serialized_example = reader.read(self.filename_queue)
-        
-        
+
+
         # Decode the record read by the reader
         features = tf.parse_single_example(serialized_example, features=feature)
-        
+
         # Convert the image data from string back to the numbers
         image = tf.decode_raw(features['train/data'], self.input_dtype)
-        
+
         # Cast label data into int32
         label = tf.cast(features['train/label'], tf.int64)
-        
+
         # Reshape image data into the original shape
         image = tf.reshape(image, self.input_shape[1:])
-        
-        
+
+        return image, label
+
+
+    def create_placeholders(self):
+        """
+        Build the shuffled data/label batch tensors and store them in ``self.data_ph`` and ``self.label_ph``.
+        """
+        image, label = self.__load_features()
+
+
         data_ph, label_ph = tf.train.shuffle_batch([image, label], batch_size=self.batch_size,
                          capacity=self.prefetch_capacity, num_threads=self.prefetch_threads,
                          min_after_dequeue=1, name="shuffle_batch")
-        
-        
+
+
         self.data_ph = data_ph
         self.label_ph = label_ph
 
diff --git a/bob/learn/tensorflow/datashuffler/TFRecordImage.py b/bob/learn/tensorflow/datashuffler/TFRecordImage.py
index ba3259ae..3f66b23b 100755
--- a/bob/learn/tensorflow/datashuffler/TFRecordImage.py
+++ b/bob/learn/tensorflow/datashuffler/TFRecordImage.py
@@ -1,7 +1,7 @@
 #!/usr/bin/env python
 # vim: set fileencoding=utf-8 :
 # @author: Tiago de Freitas Pereira <tiago.pereira@idiap.ch>
-# @date: Wed 11 May 2016 09:39:36 CEST 
+# @date: Wed 11 May 2016 09:39:36 CEST
 
 import numpy
 import tensorflow as tf
@@ -12,9 +12,9 @@ from .TFRecord import TFRecord
 class TFRecordImage(TFRecord):
     """
     Datashuffler that wraps the batching using tfrecords.
-    
-    This shuffler is more suitable for image datasets, because it does image data augmentation operations.    
-    
+
+    This shuffler is more suitable for image datasets, because it does image data augmentation operations.
+
     **Parameters**
 
       filename_queue: Tensorflow producer
@@ -23,7 +23,7 @@ class TFRecordImage(TFRecord):
       input_dtype: Type of the raw data saved in the tf record
       batch_size: Size of the batch
       seed: Seed
-      prefetch_capacity: Capacity of the bucket for prefetching 
+      prefetch_capacity: Capacity of the bucket for prefetching
       prefetch_threads: Number of threads in the prefetching
       shuffle: Shuffled the batch
       normalization: zero mean and unit std
@@ -39,7 +39,7 @@ class TFRecordImage(TFRecord):
                          batch_size=32,
                          seed=10,
                          prefetch_capacity=1000,
-                         prefetch_threads=5, 
+                         prefetch_threads=5,
                          shuffle=True,
                          normalization=False,
                          random_flip=True,
@@ -47,87 +47,40 @@ class TFRecordImage(TFRecord):
                          gray_scale=False
                          ):
 
-        # Setting the seed for the pseudo random number generator
-        self.seed = seed
-        numpy.random.seed(seed)
-
-        self.input_dtype = input_dtype
-
-        # TODO: Check if the bacth size is higher than the input data
-        self.batch_size = batch_size
-
-        # Preparing the inputs
-        self.filename_queue = filename_queue
-        self.input_shape = tuple(input_shape)
+        super(TFRecordImage, self).__init__(filename_queue=filename_queue,
+                                            input_shape=input_shape,
+                                            input_dtype=input_dtype,
+                                            batch_size=batch_size,
+                                            seed=seed,
+                                            prefetch_capacity=prefetch_capacity,
+                                            prefetch_threads=prefetch_threads)
+        # Preparing the output
         self.output_shape = output_shape
 
-        # Prefetch variables
-        self.prefetch = True
-        self.prefetch_capacity = prefetch_capacity
-        self.prefetch_threads = prefetch_threads
-        
-        self.data_ph = None
-        self.label_ph = None
-        
         self.shuffle = shuffle
         self.normalization = normalization
         self.random_crop = random_crop
         self.random_flip = random_flip
         self.gray_scale = gray_scale
 
-    def __call__(self, element, from_queue=False):
-        """
-        Return the necessary placeholder
-        
-        """
-
-        if not element in ["data", "label"]:
-            raise ValueError("Value '{0}' invalid. Options available are {1}".format(element, self.placeholder_options))
-
-        # If None, create the placeholders from scratch
-        if self.data_ph is None:
-            self.create_placeholders()
-
-        if element == "data":
-            return self.data_ph
-        else:
-            return self.label_ph
-
 
     def create_placeholders(self):
+        """
+        Override of ``TFRecord.create_placeholders``: casts the image to float32 and applies the optional image operations before batching into ``data_ph``/``label_ph``.
+        """
 
-        feature = {'train/data': tf.FixedLenFeature([], tf.string),
-                   'train/label': tf.FixedLenFeature([], tf.int64)}
-
-        # Define a reader and read the next record
-        reader = tf.TFRecordReader()
-        
-        _, serialized_example = reader.read(self.filename_queue)
-        
-        
-        # Decode the record read by the reader
-        features = tf.parse_single_example(serialized_example, features=feature)
-        
-        # Convert the image data from string back to the numbers
-        image = tf.decode_raw(features['train/data'], self.input_dtype)
-        #image = tf.decode_raw(features['train/data'], tf.uint8)
-        
-        # Cast label data into int32
-        label = tf.cast(features['train/label'], tf.int64)
-        
-        # Reshape image data into the original shape
-        image = tf.reshape(image, self.input_shape[1:])
+        image, label = self._TFRecord__load_features()  # helper is name-mangled in TFRecord; self.__load_features() would resolve to _TFRecordImage__load_features
 
         # Casting to float32
         image = tf.cast(image, tf.float32)
-        
+
         if self.gray_scale:
             image = tf.image.rgb_to_grayscale(image, name="rgb_to_gray")
             self.output_shape[3] = 1
-        
+
         if self.random_crop:
             image = tf.image.resize_image_with_crop_or_pad(image, self.output_shape[1], self.output_shape[2])
-            
+
         if self.random_flip:
             image = tf.image.random_flip_left_right(image)
 
@@ -136,7 +89,7 @@ class TFRecordImage(TFRecord):
             image = tf.image.per_image_standardization(image)
 
         image.set_shape(tuple(self.output_shape[1:]))
-        
+
 
         if self.shuffle:
             data_ph, label_ph = tf.train.shuffle_batch([image, label], batch_size=self.batch_size,
@@ -145,8 +98,8 @@ class TFRecordImage(TFRecord):
         else:
             data_ph, label_ph = tf.train.batch([image, label], batch_size=self.batch_size,
                              capacity=self.prefetch_capacity, num_threads=self.prefetch_threads, name="batch")
-        
-        
+
+
         self.data_ph = data_ph
         self.label_ph = label_ph
 
-- 
GitLab