Skip to content
Snippets Groups Projects
Commit 65cdd62c authored by Tiago de Freitas Pereira's avatar Tiago de Freitas Pereira
Browse files

Created mechanism to batch siamease data

parent 1b13bc83
No related branches found
No related tags found
1 merge request!21Resolve "Adopt to the Estimators API"
import tensorflow as tf
import numpy
DEFAULT_FEATURE = {'train/data': tf.FixedLenFeature([], tf.string),
'train/label': tf.FixedLenFeature([], tf.int64)}
......@@ -68,4 +68,80 @@ def append_image_augmentation(image, gray_scale=False,
image = tf.image.per_image_standardization(image)
return image
def siamease_pairs_generator(input_data, input_labels):
"""
Giving a list of samples and a list of labels, it dumps a series of
pairs for siamese nets.
**Parameters**
input_data: List of whatever representing the data samples
input_labels: List of the labels (needs to be in EXACT same order as input_data)
"""
# Lists that will be returned
left_data = []
right_data = []
labels = []
def append(left, right, label):
"""
Just appending one element in each list
"""
left_data.append(left)
right_data.append(right)
labels.append(label)
possible_labels = list(set(input_labels))
input_data = numpy.array(input_data)
input_labels = numpy.array(input_labels)
total_samples = input_data.shape[0]
# Filtering the samples by label and shuffling all the indexes
indexes_per_labels = dict()
for l in possible_labels:
indexes_per_labels[l] = numpy.where(input_labels == l)[0]
numpy.random.shuffle(indexes_per_labels[l])
left_possible_indexes = numpy.random.choice(possible_labels, total_samples, replace=True)
right_possible_indexes = numpy.random.choice(possible_labels, total_samples, replace=True)
genuine = True
for i in range(total_samples):
if genuine:
# Selecting the class
class_index = left_possible_indexes[i]
# Now selecting the samples for the pair
left = input_data[indexes_per_labels[class_index][numpy.random.randint(len(indexes_per_labels[class_index]))]]
right = input_data[indexes_per_labels[class_index][numpy.random.randint(len(indexes_per_labels[class_index]))]]
append(left, right, 0)
#yield left, right, 0
else:
# Selecting the 2 classes
class_index = list()
class_index.append(left_possible_indexes[i])
# Finding the right pair
j = i
# TODO: Lame solution. Fix this
while j < total_samples: # Here is an unidiretinal search for the negative pair
if left_possible_indexes[i] != right_possible_indexes[j]:
class_index.append(right_possible_indexes[j])
break
j += 1
if j < total_samples:
# Now selecting the samples for the pair
left = input_data[indexes_per_labels[class_index[0]][numpy.random.randint(len(indexes_per_labels[class_index[0]]))]]
right = input_data[indexes_per_labels[class_index[1]][numpy.random.randint(len(indexes_per_labels[class_index[1]]))]]
append(left, right, 1)
genuine = not genuine
return left_data, right_data, labels
#!/usr/bin/env python
# vim: set fileencoding=utf-8 :
# @author: Tiago de Freitas Pereira <tiago.pereira@idiap.ch>
import tensorflow as tf
from functools import partial
from . import append_image_augmentation, siamease_pairs_generator
def shuffle_data_and_labels_image_augmentation(filenames, labels, data_shape, data_type,
batch_size, epochs=None, buffer_size=10**3,
gray_scale=False,
output_shape=None,
random_flip=False,
random_brightness=False,
random_contrast=False,
random_saturation=False,
per_image_normalization=True):
"""
Dump random batches for siamese networks from a list of image paths and labels:
The list of files and labels should be in the same order e.g.
filenames = ['class_1_img1', 'class_1_img2', 'class_2_img1']
labels = [0, 0, 1]
The batches returned with tf.Session.run() with be in the following format:
**data** a dictionary containing the keys ['left', 'right'], each one representing
one element of the pair and **labels** which is [0, 1] where 0 is the genuine pair
and 1 is the impostor pair.
**Parameters**
filenames:
List containing the path of the images
labels:
List containing the labels (needs to be in EXACT same order as filenames)
data_shape:
Samples shape saved in the tf-record
data_type:
tf data type(https://www.tensorflow.org/versions/r0.12/resources/dims_types#data_types)
batch_size:
Size of the batch
epochs:
Number of epochs to be batched
buffer_size:
Size of the shuffle bucket
gray_scale:
Convert to gray scale?
output_shape:
If set, will randomly crop the image given the output shape
random_flip:
Randomly flip an image horizontally (https://www.tensorflow.org/api_docs/python/tf/image/random_flip_left_right)
random_brightness:
Adjust the brightness of an RGB image by a random factor (https://www.tensorflow.org/api_docs/python/tf/image/random_brightness)
random_contrast:
Adjust the contrast of an RGB image by a random factor (https://www.tensorflow.org/api_docs/python/tf/image/random_contrast)
random_saturation:
Adjust the saturation of an RGB image by a random factor (https://www.tensorflow.org/api_docs/python/tf/image/random_saturation)
per_image_normalization:
Linearly scales image to have zero mean and unit norm.
"""
dataset = create_dataset_from_path_augmentation(filenames, labels, data_shape,
data_type,
gray_scale=gray_scale,
output_shape=output_shape,
random_flip=random_flip,
random_brightness=random_brightness,
random_contrast=random_contrast,
random_saturation=random_saturation,
per_image_normalization=per_image_normalization)
dataset = dataset.shuffle(buffer_size).batch(batch_size).repeat(epochs)
#dataset = dataset.batch(buffer_size).batch(batch_size).repeat(epochs)
data, labels = dataset.make_one_shot_iterator().get_next()
return data, labels
def create_dataset_from_path_augmentation(filenames, labels,
data_shape, data_type,
gray_scale=False,
output_shape=None,
random_flip=False,
random_brightness=False,
random_contrast=False,
random_saturation=False,
per_image_normalization=True):
"""
Create dataset from a list of tf-record files
**Parameters**
filenames:
List containing the path of the images
labels:
List containing the labels (needs to be in EXACT same order as filenames)
data_shape:
Samples shape saved in the tf-record
data_type:
tf data type(https://www.tensorflow.org/versions/r0.12/resources/dims_types#data_types)
feature:
"""
parser = partial(image_augmentation_parser,
data_shape=data_shape,
data_type=data_type,
gray_scale=gray_scale,
output_shape=output_shape,
random_flip=random_flip,
random_brightness=random_brightness,
random_contrast=random_contrast,
random_saturation=random_saturation,
per_image_normalization=per_image_normalization)
left_data, right_data, siamese_labels = siamease_pairs_generator(filenames, labels)
dataset = tf.contrib.data.Dataset.from_tensor_slices((left_data, right_data, siamese_labels))
dataset = dataset.map(parser)
return dataset
def image_augmentation_parser(filename_left, filename_right, label, data_shape, data_type,
gray_scale=False,
output_shape=None,
random_flip=False,
random_brightness=False,
random_contrast=False,
random_saturation=False,
per_image_normalization=True):
"""
Parses a single tf.Example into image and label tensors.
"""
# Convert the image data from string back to the numbers
image_left = tf.cast(tf.image.decode_image(tf.read_file(filename_left)), tf.float32)
image_right = tf.cast(tf.image.decode_image(tf.read_file(filename_right)), tf.float32)
# Reshape image data into the original shape
image_left = tf.reshape(image_left, data_shape)
image_right = tf.reshape(image_right, data_shape)
#Applying image augmentation
image_left = append_image_augmentation(image_left, gray_scale=gray_scale,
output_shape=output_shape,
random_flip=random_flip,
random_brightness=random_brightness,
random_contrast=random_contrast,
random_saturation=random_saturation,
per_image_normalization=per_image_normalization)
image_right = append_image_augmentation(image_right, gray_scale=gray_scale,
output_shape=output_shape,
random_flip=random_flip,
random_brightness=random_brightness,
random_contrast=random_contrast,
random_saturation=random_saturation,
per_image_normalization=per_image_normalization)
image = dict()
image['left'] = image_left
image['right'] = image_right
label = tf.cast(label, tf.int64)
return image, label
#!/usr/bin/env python
# vim: set fileencoding=utf-8 :
# @author: Tiago de Freitas Pereira <tiago.pereira@idiap.ch>
import pkg_resources
import tensorflow as tf
from bob.learn.tensorflow.dataset.siamese_image import shuffle_data_and_labels_image_augmentation as siamese_batch
data_shape = (250, 250, 3) # size of atnt images
data_type = tf.float32
batch_size = 2
validation_batch_size = 250
epochs = 1
def test_siamese_dataset():
# Trainer logits
filenames = [pkg_resources.resource_filename(__name__, 'data/dummy_image_database/m301_01_p01_i0_0.png'),
pkg_resources.resource_filename(__name__, 'data/dummy_image_database/m301_01_p02_i0_0.png'),
pkg_resources.resource_filename(__name__, 'data/dummy_image_database/m304_01_p01_i0_0.png'),
pkg_resources.resource_filename(__name__, 'data/dummy_image_database/m304_02_f12_i0_0.png')]
labels = [0, 0, 1, 1]
data, label = siamese_batch(filenames, labels, data_shape, data_type, 2)
with tf.Session() as session:
d, l = session.run([data, label])
assert len(l) == 2
assert d['left'].shape == (2, 250, 250, 3)
assert d['right'].shape == (2, 250, 250, 3)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment