Commit 26fda161 authored by Tiago de Freitas Pereira

Implemented random triplet batching

parent 34d01c01
Merge request !24: Implemented random triplet batching
@@ -70,6 +70,67 @@ def append_image_augmentation(image, gray_scale=False,
    return image
def arrange_indexes_by_label(input_labels, possible_labels):
    """
    Groups the sample indexes by label and shuffles the indexes within
    each label.
    """
    indexes_per_labels = dict()
    for l in possible_labels:
        indexes_per_labels[l] = numpy.where(input_labels == l)[0]
        numpy.random.shuffle(indexes_per_labels[l])
    return indexes_per_labels
def triplets_random_generator(input_data, input_labels):
    """
    Given a list of samples and a list of labels, generates a series of
    triplets for triplet networks.

    **Parameters**

    input_data: List of whatever representing the data samples

    input_labels: List of the labels (needs to be in EXACT same order as input_data)
    """
    anchor = []
    positive = []
    negative = []

    def append(anchor_sample, positive_sample, negative_sample):
        """
        Appends one element to each of the three lists
        """
        anchor.append(anchor_sample)
        positive.append(positive_sample)
        negative.append(negative_sample)

    possible_labels = list(set(input_labels))
    input_data = numpy.array(input_data)
    input_labels = numpy.array(input_labels)
    total_samples = input_data.shape[0]

    indexes_per_labels = arrange_indexes_by_label(input_labels, possible_labels)

    # Searching for random triplets: anchor and positive are drawn from the
    # current class, the negative from the next class (round-robin)
    offset_class = 0
    for _ in range(total_samples):
        label = possible_labels[offset_class]
        anchor_sample = input_data[numpy.random.choice(indexes_per_labels[label]), ...]
        positive_sample = input_data[numpy.random.choice(indexes_per_labels[label]), ...]

        # Changing the class for the negative sample
        offset_class = (offset_class + 1) % len(possible_labels)
        label = possible_labels[offset_class]
        negative_sample = input_data[numpy.random.choice(indexes_per_labels[label]), ...]

        append(str(anchor_sample), str(positive_sample), str(negative_sample))

    return anchor, positive, negative
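
# A minimal usage sketch with hypothetical toy data (not part of the module):
# the generator emits one triplet per input sample, drawing the anchor and
# positive from one class and the negative from the next class.
#
#   data = ['a0', 'a1', 'b0', 'b1']
#   labels = [0, 0, 1, 1]
#   anchors, positives, negatives = triplets_random_generator(data, labels)
#   assert len(anchors) == len(data)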
def siamease_pairs_generator(input_data, input_labels):
"""
Giving a list of samples and a list of labels, it dumps a series of
@@ -101,10 +162,11 @@ def siamease_pairs_generator(input_data, input_labels):
    total_samples = input_data.shape[0]

    # Filtering the samples by label and shuffling all the indexes
    indexes_per_labels = arrange_indexes_by_label(input_labels, possible_labels)

    left_possible_indexes = numpy.random.choice(possible_labels, total_samples, replace=True)
    right_possible_indexes = numpy.random.choice(possible_labels, total_samples, replace=True)
#!/usr/bin/env python
# vim: set fileencoding=utf-8 :
# @author: Tiago de Freitas Pereira <tiago.pereira@idiap.ch>
import tensorflow as tf
from functools import partial
from . import append_image_augmentation, triplets_random_generator
def shuffle_data_and_labels_image_augmentation(filenames, labels, data_shape, data_type,
                                               batch_size, epochs=None, buffer_size=10**3,
                                               gray_scale=False,
                                               output_shape=None,
                                               random_flip=False,
                                               random_brightness=False,
                                               random_contrast=False,
                                               random_saturation=False,
                                               per_image_normalization=True):
"""
Dump random batches for triplee networks from a list of image paths and labels:
The list of files and labels should be in the same order e.g.
filenames = ['class_1_img1', 'class_1_img2', 'class_2_img1']
labels = [0, 0, 1]
The batches returned with tf.Session.run() with be in the following format:
**data** a dictionary containing the keys ['anchor', 'positive', 'negative'].
**Parameters**
filenames:
List containing the path of the images
labels:
List containing the labels (needs to be in EXACT same order as filenames)
data_shape:
Samples shape saved in the tf-record
data_type:
tf data type(https://www.tensorflow.org/versions/r0.12/resources/dims_types#data_types)
batch_size:
Size of the batch
epochs:
Number of epochs to be batched
buffer_size:
Size of the shuffle bucket
gray_scale:
Convert to gray scale?
output_shape:
If set, will randomly crop the image given the output shape
random_flip:
Randomly flip an image horizontally (https://www.tensorflow.org/api_docs/python/tf/image/random_flip_left_right)
random_brightness:
Adjust the brightness of an RGB image by a random factor (https://www.tensorflow.org/api_docs/python/tf/image/random_brightness)
random_contrast:
Adjust the contrast of an RGB image by a random factor (https://www.tensorflow.org/api_docs/python/tf/image/random_contrast)
random_saturation:
Adjust the saturation of an RGB image by a random factor (https://www.tensorflow.org/api_docs/python/tf/image/random_saturation)
per_image_normalization:
Linearly scales image to have zero mean and unit norm.
"""
    dataset = create_dataset_from_path_augmentation(filenames, labels, data_shape,
                                                    data_type,
                                                    gray_scale=gray_scale,
                                                    output_shape=output_shape,
                                                    random_flip=random_flip,
                                                    random_brightness=random_brightness,
                                                    random_contrast=random_contrast,
                                                    random_saturation=random_saturation,
                                                    per_image_normalization=per_image_normalization)

    dataset = dataset.shuffle(buffer_size).batch(batch_size).repeat(epochs)
    data = dataset.make_one_shot_iterator().get_next()
    return data
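
# A minimal usage sketch, assuming `filenames`/`labels` point at existing
# images (mirrors the test at the bottom of this commit): the returned value
# is a dict of batched image tensors keyed by 'anchor'/'positive'/'negative'.
#
#   data = shuffle_data_and_labels_image_augmentation(
#       filenames, labels, data_shape=(250, 250, 3), data_type=tf.float32,
#       batch_size=2, output_shape=(50, 50))
#   with tf.Session() as session:
#       triplet = session.run(data)
#       assert triplet['anchor'].shape == (2, 50, 50, 3)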
def create_dataset_from_path_augmentation(filenames, labels,
                                          data_shape, data_type=tf.float32,
                                          gray_scale=False,
                                          output_shape=None,
                                          random_flip=False,
                                          random_brightness=False,
                                          random_contrast=False,
                                          random_saturation=False,
                                          per_image_normalization=True):
"""
Create dataset from a list of tf-record files
**Parameters**
filenames:
List containing the path of the images
labels:
List containing the labels (needs to be in EXACT same order as filenames)
data_shape:
Samples shape saved in the tf-record
data_type:
tf data type(https://www.tensorflow.org/versions/r0.12/resources/dims_types#data_types)
feature:
"""
    parser = partial(image_augmentation_parser,
                     data_shape=data_shape,
                     data_type=data_type,
                     gray_scale=gray_scale,
                     output_shape=output_shape,
                     random_flip=random_flip,
                     random_brightness=random_brightness,
                     random_contrast=random_contrast,
                     random_saturation=random_saturation,
                     per_image_normalization=per_image_normalization)

    anchor_data, positive_data, negative_data = triplets_random_generator(filenames, labels)

    dataset = tf.contrib.data.Dataset.from_tensor_slices((anchor_data, positive_data, negative_data))
    dataset = dataset.map(parser)
    return dataset
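
# Note: `from_tensor_slices` receives three parallel lists of file paths, so
# each dataset element is an (anchor, positive, negative) tuple of strings;
# `dataset.map(parser)` unpacks that tuple into the three positional
# arguments of `image_augmentation_parser` below.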
def image_augmentation_parser(anchor, positive, negative, data_shape, data_type=tf.float32,
                              gray_scale=False,
                              output_shape=None,
                              random_flip=False,
                              random_brightness=False,
                              random_contrast=False,
                              random_saturation=False,
                              per_image_normalization=True):
"""
Parses a single tf.Example into image and label tensors.
"""
    triplet = dict()
    for n, v in zip(['anchor', 'positive', 'negative'], [anchor, positive, negative]):
        # Convert the image data from string back to the numbers
        image = tf.cast(tf.image.decode_image(tf.read_file(v)), data_type)

        # Reshape image data into the original shape
        image = tf.reshape(image, data_shape)

        # Applying image augmentation
        image = append_image_augmentation(image, gray_scale=gray_scale,
                                          output_shape=output_shape,
                                          random_flip=random_flip,
                                          random_brightness=random_brightness,
                                          random_contrast=random_contrast,
                                          random_saturation=random_saturation,
                                          per_image_normalization=per_image_normalization)

        triplet[n] = image

    return triplet
@@ -5,6 +5,7 @@
import pkg_resources
import tensorflow as tf
from bob.learn.tensorflow.dataset.siamese_image import shuffle_data_and_labels_image_augmentation as siamese_batch
from bob.learn.tensorflow.dataset.triplet_image import shuffle_data_and_labels_image_augmentation as triplet_batch
data_shape = (250, 250, 3)
output_shape = (50, 50)
@@ -14,24 +15,25 @@ validation_batch_size = 250
epochs = 1
# Trainer logits
filenames = [pkg_resources.resource_filename(__name__, 'data/dummy_image_database/m301_01_p01_i0_0.png'),
             pkg_resources.resource_filename(__name__, 'data/dummy_image_database/m301_01_p02_i0_0.png'),
             pkg_resources.resource_filename(__name__, 'data/dummy_image_database/m301_01_p01_i0_0.png'),
             pkg_resources.resource_filename(__name__, 'data/dummy_image_database/m301_01_p02_i0_0.png'),
             pkg_resources.resource_filename(__name__, 'data/dummy_image_database/m301_01_p01_i0_0.png'),
             pkg_resources.resource_filename(__name__, 'data/dummy_image_database/m301_01_p02_i0_0.png'),
             pkg_resources.resource_filename(__name__, 'data/dummy_image_database/m304_01_p01_i0_0.png'),
             pkg_resources.resource_filename(__name__, 'data/dummy_image_database/m304_02_f12_i0_0.png'),
             pkg_resources.resource_filename(__name__, 'data/dummy_image_database/m304_01_p01_i0_0.png'),
             pkg_resources.resource_filename(__name__, 'data/dummy_image_database/m304_02_f12_i0_0.png'),
             pkg_resources.resource_filename(__name__, 'data/dummy_image_database/m304_01_p01_i0_0.png'),
             pkg_resources.resource_filename(__name__, 'data/dummy_image_database/m304_02_f12_i0_0.png')]

labels = [0, 0, 0, 0, 0, 0,
          1, 1, 1, 1, 1, 1]
def test_siamese_dataset():
    data, label = siamese_batch(filenames, labels, data_shape, data_type, 2, per_image_normalization=False, output_shape=output_shape)
@@ -41,3 +43,13 @@ def test_siamese_dataset():
    assert d['left'].shape == (2, 50, 50, 3)
    assert d['right'].shape == (2, 50, 50, 3)
def test_triplet_dataset():
    data = triplet_batch(filenames, labels, data_shape, data_type, 2, per_image_normalization=False, output_shape=output_shape)

    with tf.Session() as session:
        d = session.run([data])[0]
        assert len(d.keys()) == 3
        assert d['anchor'].shape == (2, 50, 50, 3)
        assert d['positive'].shape == (2, 50, 50, 3)
        assert d['negative'].shape == (2, 50, 50, 3)