From 3316ab71223df01970377235a29f5c7c027391fe Mon Sep 17 00:00:00 2001
From: Samuel Gaist <samuel.gaist@idiap.ch>
Date: Mon, 23 Oct 2017 12:30:10 +0200
Subject: [PATCH] Removed code duplication between TFRecord and TFRecordImage.

---
Review notes (ignored by git-am, not part of the commit message):
- TFRecordImage.__init__ has to call super(TFRecordImage, self).__init__().
  super(TFRecord, self) starts the lookup *after* TFRecord, so
  TFRecord.__init__ is skipped and the keyword arguments are handed to
  object.__init__, which rejects them.
- The shared loader is named _load_features (single underscore). A
  double-underscore name is mangled per class, so a call written inside
  TFRecordImage would look up _TFRecordImage__load_features and fail.

 bob/learn/tensorflow/datashuffler/TFRecord.py | 44 +++++----
 .../tensorflow/datashuffler/TFRecordImage.py  | 95 +++++--------------
 2 files changed, 52 insertions(+), 87 deletions(-)

diff --git a/bob/learn/tensorflow/datashuffler/TFRecord.py b/bob/learn/tensorflow/datashuffler/TFRecord.py
index f63a76a3..a25f15d7 100755
--- a/bob/learn/tensorflow/datashuffler/TFRecord.py
+++ b/bob/learn/tensorflow/datashuffler/TFRecord.py
@@ -11,7 +11,7 @@ import numpy
 class TFRecord(object):
     """
     Generic datashuffler that wraps the batching using tfrecords.
-    
+
     **Parameters**
 
     filename_queue: Tensorflow producer
@@ -19,7 +19,7 @@ class TFRecord(object):
     input_dtype: Type of the raw data saved in the tf record
     batch_size: Size of the batch
     seed: Seed
-    prefetch_capacity: Capacity of the bucket for prefetching 
+    prefetch_capacity: Capacity of the bucket for prefetching
     prefetch_threads: Number of threads in the prefetching
     """
     def __init__(self,filename_queue,
@@ -36,7 +36,7 @@ class TFRecord(object):
 
         self.input_dtype = input_dtype
 
-        # TODO: Check if the bacth size is higher than the input data
+        # TODO: Check if the batch size is higher than the input data
         self.batch_size = batch_size
 
         # Preparing the inputs
@@ -47,14 +47,14 @@ class TFRecord(object):
         self.prefetch = True
         self.prefetch_capacity = prefetch_capacity
         self.prefetch_threads = prefetch_threads
-        
+
         self.data_ph = None
         self.label_ph = None
 
+
     def __call__(self, element, from_queue=False):
         """
         Return the necessary placeholder
-        
         """
 
         if not element in ["data", "label"]:
@@ -70,35 +70,47 @@ class TFRecord(object):
             return self.label_ph
 
 
-    def create_placeholders(self):
+    def _load_features(self):
+        """
+        Load features from queue
+        """
 
         feature = {'train/data': tf.FixedLenFeature([], tf.string),
                    'train/label': tf.FixedLenFeature([], tf.int64)}
 
         # Define a reader and read the next record
         reader = tf.TFRecordReader()
-        
+
         _, serialized_example = reader.read(self.filename_queue)
-        
-        
+
+
         # Decode the record read by the reader
         features = tf.parse_single_example(serialized_example, features=feature)
-        
+
         # Convert the image data from string back to the numbers
         image = tf.decode_raw(features['train/data'], self.input_dtype)
-        
+
         # Cast label data into int32
         label = tf.cast(features['train/label'], tf.int64)
-        
+
         # Reshape image data into the original shape
         image = tf.reshape(image, self.input_shape[1:])
-        
-        
+
+        return image, label
+
+
+    def create_placeholders(self):
+        """
+        Create placeholder data from features.
+        """
+        image, label = self._load_features()
+
+
         data_ph, label_ph = tf.train.shuffle_batch([image, label], batch_size=self.batch_size,
                          capacity=self.prefetch_capacity, num_threads=self.prefetch_threads,
                          min_after_dequeue=1, name="shuffle_batch")
-        
-        
+
+
         self.data_ph = data_ph
         self.label_ph = label_ph
 
diff --git a/bob/learn/tensorflow/datashuffler/TFRecordImage.py b/bob/learn/tensorflow/datashuffler/TFRecordImage.py
index ba3259ae..3f66b23b 100755
--- a/bob/learn/tensorflow/datashuffler/TFRecordImage.py
+++ b/bob/learn/tensorflow/datashuffler/TFRecordImage.py
@@ -1,7 +1,7 @@
 #!/usr/bin/env python
 # vim: set fileencoding=utf-8 :
 # @author: Tiago de Freitas Pereira <tiago.pereira@idiap.ch>
-# @date: Wed 11 May 2016 09:39:36 CEST 
+# @date: Wed 11 May 2016 09:39:36 CEST
 
 import numpy
 import tensorflow as tf
@@ -12,9 +12,9 @@ from .TFRecord import TFRecord
 class TFRecordImage(TFRecord):
     """
     Datashuffler that wraps the batching using tfrecords.
-    
-    This shuffler is more suitable for image datasets, because it does image data augmentation operations. 
-    
+
+    This shuffler is more suitable for image datasets, because it does image data augmentation operations.
+
     **Parameters**
 
     filename_queue: Tensorflow producer
@@ -23,7 +23,7 @@ class TFRecordImage(TFRecord):
     input_dtype: Type of the raw data saved in the tf record
     batch_size: Size of the batch
     seed: Seed
-    prefetch_capacity: Capacity of the bucket for prefetching 
+    prefetch_capacity: Capacity of the bucket for prefetching
     prefetch_threads: Number of threads in the prefetching
     shuffle: Shuffled the batch
     normalization: zero mean and unit std
@@ -39,7 +39,7 @@ class TFRecordImage(TFRecord):
                  batch_size=32,
                  seed=10,
                  prefetch_capacity=1000,
-                 prefetch_threads=5, 
+                 prefetch_threads=5,
                  shuffle=True,
                  normalization=False,
                  random_flip=True,
@@ -47,87 +47,40 @@ class TFRecordImage(TFRecord):
                  gray_scale=False
                  ):
 
-        # Setting the seed for the pseudo random number generator
-        self.seed = seed
-        numpy.random.seed(seed)
-
-        self.input_dtype = input_dtype
-
-        # TODO: Check if the bacth size is higher than the input data
-        self.batch_size = batch_size
-
-        # Preparing the inputs
-        self.filename_queue = filename_queue
-        self.input_shape = tuple(input_shape)
+        super(TFRecordImage, self).__init__(filename_queue=filename_queue,
+                                            input_shape=input_shape,
+                                            input_dtype=input_dtype,
+                                            batch_size=batch_size,
+                                            seed=seed,
+                                            prefetch_capacity=prefetch_capacity,
+                                            prefetch_threads=prefetch_threads)
+
         # Preparing the output
         self.output_shape = output_shape
 
-        # Prefetch variables
-        self.prefetch = True
-        self.prefetch_capacity = prefetch_capacity
-        self.prefetch_threads = prefetch_threads
-
-        self.data_ph = None
-        self.label_ph = None
-
         self.shuffle = shuffle
         self.normalization = normalization
         self.random_crop = random_crop
         self.random_flip = random_flip
         self.gray_scale = gray_scale
 
-    def __call__(self, element, from_queue=False):
-        """
-        Return the necessary placeholder
-
-        """
-
-        if not element in ["data", "label"]:
-            raise ValueError("Value '{0}' invalid. Options available are {1}".format(element, self.placeholder_options))
-
-        # If None, create the placeholders from scratch
-        if self.data_ph is None:
-            self.create_placeholders()
-
-        if element == "data":
-            return self.data_ph
-        else:
-            return self.label_ph
-
     def create_placeholders(self):
+        """
+        Reimplementation
+        """
 
-        feature = {'train/data': tf.FixedLenFeature([], tf.string),
-                   'train/label': tf.FixedLenFeature([], tf.int64)}
-
-        # Define a reader and read the next record
-        reader = tf.TFRecordReader()
-
-        _, serialized_example = reader.read(self.filename_queue)
-
-
-        # Decode the record read by the reader
-        features = tf.parse_single_example(serialized_example, features=feature)
-
-        # Convert the image data from string back to the numbers
-        image = tf.decode_raw(features['train/data'], self.input_dtype)
-        #image = tf.decode_raw(features['train/data'], tf.uint8)
-
-        # Cast label data into int32
-        label = tf.cast(features['train/label'], tf.int64)
-
-        # Reshape image data into the original shape
-        image = tf.reshape(image, self.input_shape[1:])
+        image, label = self._load_features()
 
         # Casting to float32
         image = tf.cast(image, tf.float32)
-        
+
         if self.gray_scale:
             image = tf.image.rgb_to_grayscale(image, name="rgb_to_gray")
             self.output_shape[3] = 1
-        
+
         if self.random_crop:
             image = tf.image.resize_image_with_crop_or_pad(image, self.output_shape[1], self.output_shape[2])
-        
+
         if self.random_flip:
             image = tf.image.random_flip_left_right(image)
 
@@ -136,7 +89,7 @@ class TFRecordImage(TFRecord):
             image = tf.image.per_image_standardization(image)
 
         image.set_shape(tuple(self.output_shape[1:]))
-        
+
         if self.shuffle:
             data_ph, label_ph = tf.train.shuffle_batch([image, label],
                                                        batch_size=self.batch_size,
@@ -145,8 +98,8 @@ class TFRecordImage(TFRecord):
         else:
             data_ph, label_ph = tf.train.batch([image, label], batch_size=self.batch_size,
                              capacity=self.prefetch_capacity, num_threads=self.prefetch_threads, name="batch")
-        
-        
+
+
         self.data_ph = data_ph
         self.label_ph = label_ph
 
-- 
GitLab