def siamease_pairs_generator(input_data, input_labels):
    """
    Given a list of samples and a list of labels, builds a series of
    pairs for siamese nets.

    One candidate pair is produced per input sample, alternating between a
    genuine pair (both elements drawn from the same class, label 0) and an
    impostor pair (elements drawn from two different classes, label 1).
    Samples are drawn with replacement using ``numpy.random``, so a genuine
    pair may contain the same sample twice.

    When only one class is present no impostor pair can be built; those
    slots are skipped and only the genuine pairs are returned.

    **Parameters**

      input_data: List of whatever representing the data samples

      input_labels: List of the labels (needs to be in EXACT same order as
                    input_data). Labels are expected to be scalars.

    **Returns**

      A tuple ``(left_data, right_data, labels)`` of three lists with the
      same length: the left element of each pair, the right element of each
      pair and the pair label (0 genuine, 1 impostor).

    **Raises**

      ValueError: if ``input_data`` and ``input_labels`` differ in length.
    """

    input_data = numpy.array(input_data)
    input_labels = numpy.array(input_labels)
    total_samples = input_data.shape[0]

    if input_labels.shape[0] != total_samples:
        raise ValueError(
            "input_data and input_labels must have the same length "
            "({0} != {1})".format(total_samples, input_labels.shape[0]))

    # Lists that will be returned
    left_data = []
    right_data = []
    labels = []

    if total_samples == 0:
        return left_data, right_data, labels

    # label_ids[k] is the position of input_labels[k] inside possible_labels,
    # which lets classes be addressed by a plain integer id.
    possible_labels, label_ids = numpy.unique(input_labels, return_inverse=True)
    label_ids = label_ids.reshape(-1)
    total_classes = len(possible_labels)
    indexes_per_label = [numpy.flatnonzero(label_ids == class_id)
                         for class_id in range(total_classes)]

    def draw_sample(class_id):
        """
        Picks one random sample belonging to the given class id
        """
        candidates = indexes_per_label[class_id]
        return input_data[candidates[numpy.random.randint(len(candidates))]]

    for i in range(total_samples):
        genuine = (i % 2 == 0)

        if not genuine and total_classes < 2:
            # A single class cannot produce an impostor pair
            continue

        left_class = numpy.random.randint(total_classes)
        if genuine:
            right_class = left_class
        else:
            # A non-zero offset modulo the number of classes always lands on a
            # different class, so no impostor pair is ever dropped.
            offset = numpy.random.randint(1, total_classes)
            right_class = (left_class + offset) % total_classes

        left_data.append(draw_sample(left_class))
        right_data.append(draw_sample(right_class))
        labels.append(0 if genuine else 1)

    return left_data, right_data, labels
# ---------------------------------------------------------------------------
# bob/learn/tensorflow/dataset/siamese_image.py
# ---------------------------------------------------------------------------

def shuffle_data_and_labels_image_augmentation(filenames, labels, data_shape, data_type,
                                               batch_size, epochs=None, buffer_size=10**3,
                                               gray_scale=False,
                                               output_shape=None,
                                               random_flip=False,
                                               random_brightness=False,
                                               random_contrast=False,
                                               random_saturation=False,
                                               per_image_normalization=True):
    """
    Build shuffled batches of image pairs for siamese networks from a list
    of image paths and labels.

    The list of files and labels should be in the same order e.g.
    filenames = ['class_1_img1', 'class_1_img2', 'class_2_img1']
    labels = [0, 0, 1]

    The tensors returned, once evaluated with tf.Session.run(), have the
    following format: **data** is a dictionary with the keys
    ['left', 'right'], each one holding one element of the pair, and
    **labels** holds the pair labels, where 0 is a genuine pair and 1 is an
    impostor pair.

    **Parameters**

       filenames:
          List containing the path of the images

       labels:
          List containing the labels (needs to be in EXACT same order as filenames)

       data_shape:
          Shape every decoded image is reshaped to

       data_type:
          tf data type. It is forwarded down to the parser, which currently
          does not use it (images are always cast to tf.float32)

       batch_size:
          Size of the batch

       epochs:
          Value given to ``Dataset.repeat``

       buffer_size:
          Size of the shuffle bucket

       gray_scale:
          Convert to gray scale?

       output_shape:
          If set, will randomly crop the image given the output shape

       random_flip:
          Randomly flip an image horizontally

       random_brightness:
          Adjust the brightness of an RGB image by a random factor

       random_contrast:
          Adjust the contrast of an RGB image by a random factor

       random_saturation:
          Adjust the saturation of an RGB image by a random factor

       per_image_normalization:
          Linearly scales image to have zero mean and unit norm.
    """

    augmentation_options = dict(gray_scale=gray_scale,
                                output_shape=output_shape,
                                random_flip=random_flip,
                                random_brightness=random_brightness,
                                random_contrast=random_contrast,
                                random_saturation=random_saturation,
                                per_image_normalization=per_image_normalization)

    pairs = create_dataset_from_path_augmentation(filenames, labels, data_shape,
                                                  data_type,
                                                  **augmentation_options)

    # Shuffle individual pairs first, then group them in batches and repeat
    batches = pairs.shuffle(buffer_size)
    batches = batches.batch(batch_size)
    batches = batches.repeat(epochs)

    iterator = batches.make_one_shot_iterator()
    data, labels = iterator.get_next()
    return data, labels


def create_dataset_from_path_augmentation(filenames, labels,
                                          data_shape, data_type,
                                          gray_scale=False,
                                          output_shape=None,
                                          random_flip=False,
                                          random_brightness=False,
                                          random_contrast=False,
                                          random_saturation=False,
                                          per_image_normalization=True):
    """
    Create a dataset of siamese image pairs from a list of image paths.

    The paths are first paired with ``siamease_pairs_generator`` and every
    pair is then mapped through ``image_augmentation_parser``.

    **Parameters**

       filenames:
          List containing the path of the images

       labels:
          List containing the labels (needs to be in EXACT same order as filenames)

       data_shape:
          Shape every decoded image is reshaped to

       data_type:
          tf data type, forwarded to the parser

       gray_scale, output_shape, random_flip, random_brightness,
       random_contrast, random_saturation, per_image_normalization:
          Augmentation options, forwarded to the parser
    """

    left_paths, right_paths, pair_labels = siamease_pairs_generator(filenames, labels)
    pairs = tf.contrib.data.Dataset.from_tensor_slices((left_paths, right_paths, pair_labels))

    parse_pair = partial(image_augmentation_parser,
                         data_shape=data_shape,
                         data_type=data_type,
                         gray_scale=gray_scale,
                         output_shape=output_shape,
                         random_flip=random_flip,
                         random_brightness=random_brightness,
                         random_contrast=random_contrast,
                         random_saturation=random_saturation,
                         per_image_normalization=per_image_normalization)

    return pairs.map(parse_pair)


def image_augmentation_parser(filename_left, filename_right, label, data_shape, data_type,
                              gray_scale=False,
                              output_shape=None,
                              random_flip=False,
                              random_brightness=False,
                              random_contrast=False,
                              random_saturation=False,
                              per_image_normalization=True):
    """
    Reads the two image files of a pair and returns the augmented images
    together with the pair label.

    Returns a tuple ``(image, label)`` where ``image`` is a dictionary with
    the keys 'left' and 'right' and ``label`` is cast to tf.int64.

    NOTE: ``data_type`` is accepted but not used here; both images are cast
    to tf.float32.
    """

    augmentation_options = dict(gray_scale=gray_scale,
                                output_shape=output_shape,
                                random_flip=random_flip,
                                random_brightness=random_brightness,
                                random_contrast=random_contrast,
                                random_saturation=random_saturation,
                                per_image_normalization=per_image_normalization)

    # Read and decode both files, casting the pixels to float
    raw_left = tf.image.decode_image(tf.read_file(filename_left))
    image_left = tf.cast(raw_left, tf.float32)
    raw_right = tf.image.decode_image(tf.read_file(filename_right))
    image_right = tf.cast(raw_right, tf.float32)

    # Give both images the expected static shape
    image_left = tf.reshape(image_left, data_shape)
    image_right = tf.reshape(image_right, data_shape)

    # Apply the same augmentation options to each side (left first)
    image_left = append_image_augmentation(image_left, **augmentation_options)
    image_right = append_image_augmentation(image_right, **augmentation_options)

    image = {'left': image_left, 'right': image_right}
    return image, tf.cast(label, tf.int64)


# ---------------------------------------------------------------------------
# bob/learn/tensorflow/test/test_dataset.py
# ---------------------------------------------------------------------------

data_shape = (250, 250, 3)  # shape the test images are reshaped to
data_type = tf.float32
batch_size = 2
validation_batch_size = 250
epochs = 1


def test_siamese_dataset():
    """
    Fetches one batch of two pairs and checks the shapes of the result.
    """

    # Two images for each of the two labels
    filenames = [
        pkg_resources.resource_filename(__name__, 'data/dummy_image_database/m301_01_p01_i0_0.png'),
        pkg_resources.resource_filename(__name__, 'data/dummy_image_database/m301_01_p02_i0_0.png'),
        pkg_resources.resource_filename(__name__, 'data/dummy_image_database/m304_01_p01_i0_0.png'),
        pkg_resources.resource_filename(__name__, 'data/dummy_image_database/m304_02_f12_i0_0.png'),
    ]
    labels = [0, 0, 1, 1]

    data, label = siamese_batch(filenames, labels, data_shape, data_type, 2)
    with tf.Session() as session:
        fetched = session.run([data, label])
        d = fetched[0]
        l = fetched[1]
        assert len(l) == 2
        assert d['left'].shape == (2, 250, 250, 3)
        assert d['right'].shape == (2, 250, 250, 3)