Implemented the prefetch for the siamese trainer

parent 86f2c9cf
...@@ -45,14 +45,14 @@ class Analizer: ...@@ -45,14 +45,14 @@ class Analizer:
def __call__(self): def __call__(self):
# Extracting features for enrollment # Extracting features for enrollment
enroll_data, enroll_labels = self.data_shuffler.get_batch(train_dataset=False) enroll_data, enroll_labels = self.data_shuffler.get_batch()
enroll_features = self.machine(enroll_data, session=self.session) enroll_features = self.machine(enroll_data, session=self.session)
del enroll_data del enroll_data
#import ipdb; ipdb.set_trace(); #import ipdb; ipdb.set_trace();
# Extracting features for probing # Extracting features for probing
probe_data, probe_labels = self.data_shuffler.get_batch(train_dataset=False) probe_data, probe_labels = self.data_shuffler.get_batch()
probe_features = self.machine(probe_data, session=self.session) probe_features = self.machine(probe_data, session=self.session)
del probe_data del probe_data
......
...@@ -11,10 +11,8 @@ class BaseDataShuffler(object): ...@@ -11,10 +11,8 @@ class BaseDataShuffler(object):
def __init__(self, data, labels, def __init__(self, data, labels,
input_shape, input_shape,
input_dtype="float64", input_dtype="float64",
perc_train=0.9,
scale=True, scale=True,
train_batch_size=1, batch_size=1):
validation_batch_size=300):
""" """
The class provides base functionalities to shuffle the data The class provides base functionalities to shuffle the data
...@@ -32,54 +30,36 @@ class BaseDataShuffler(object): ...@@ -32,54 +30,36 @@ class BaseDataShuffler(object):
self.input_dtype = input_dtype self.input_dtype = input_dtype
# TODO: Check if the batch size is higher than the input data # TODO: Check if the batch size is higher than the input data
self.train_batch_size = train_batch_size self.batch_size = batch_size
self.validation_batch_size = validation_batch_size
self.data = data self.data = data
self.train_shape = tuple([train_batch_size] + input_shape) self.shape = tuple([batch_size] + input_shape)
self.validation_shape = tuple([validation_batch_size] + input_shape)
# TODO: Check if the labels goes from O to N-1
self.labels = labels self.labels = labels
self.possible_labels = list(set(self.labels)) self.possible_labels = list(set(self.labels))
# Computing the data samples for train and validation # Computing the data samples for train and validation
self.n_samples = len(self.labels) self.n_samples = len(self.labels)
self.n_train_samples = int(round(self.n_samples * perc_train))
self.n_validation_samples = self.n_samples - self.n_train_samples
# Shuffling all the indexes # Shuffling all the indexes
self.indexes = numpy.array(range(self.n_samples)) self.indexes = numpy.array(range(self.n_samples))
numpy.random.shuffle(self.indexes) numpy.random.shuffle(self.indexes)
# Spliting the data between train and validation def get_placeholders_forprefetch(self, name=""):
self.train_data = self.data[self.indexes[0:self.n_train_samples], ...]
self.train_labels = self.labels[self.indexes[0:self.n_train_samples]]
self.validation_data = self.data[self.indexes[self.n_train_samples:
self.n_train_samples + self.n_validation_samples], ...]
self.validation_labels = self.labels[self.indexes[self.n_train_samples:
self.n_train_samples + self.n_validation_samples]]
def get_placeholders_forprefetch(self, name="", train_dataset=True):
""" """
Returns a place holder with the size of your batch Returns a place holder with the size of your batch
""" """
data = tf.placeholder(tf.float32, shape=tuple([None] + list(self.shape[1:])), name=name)
shape = self.train_shape if train_dataset else self.validation_shape
data = tf.placeholder(tf.float32, shape=tuple([None] + list(shape[1:])), name=name)
labels = tf.placeholder(tf.int64, shape=[None, ]) labels = tf.placeholder(tf.int64, shape=[None, ])
return data, labels return data, labels
def get_placeholders(self, name="", train_dataset=True): def get_placeholders(self, name=""):
""" """
Returns a place holder with the size of your batch Returns a place holder with the size of your batch
""" """
data = tf.placeholder(tf.float32, shape=self.shape, name=name)
shape = self.train_shape if train_dataset else self.validation_shape labels = tf.placeholder(tf.int64, shape=self.shape[0])
data = tf.placeholder(tf.float32, shape=shape, name=name)
labels = tf.placeholder(tf.int64, shape=shape[0])
return data, labels return data, labels
......
...@@ -20,10 +20,8 @@ class MemoryDataShuffler(BaseDataShuffler): ...@@ -20,10 +20,8 @@ class MemoryDataShuffler(BaseDataShuffler):
def __init__(self, data, labels, def __init__(self, data, labels,
input_shape, input_shape,
input_dtype="float64", input_dtype="float64",
perc_train=0.9,
scale=True, scale=True,
train_batch_size=1, batch_size=1):
validation_batch_size=300):
""" """
Shuffler that deal with memory datasets Shuffler that deal with memory datasets
...@@ -41,36 +39,22 @@ class MemoryDataShuffler(BaseDataShuffler): ...@@ -41,36 +39,22 @@ class MemoryDataShuffler(BaseDataShuffler):
labels=labels, labels=labels,
input_shape=input_shape, input_shape=input_shape,
input_dtype=input_dtype, input_dtype=input_dtype,
perc_train=perc_train,
scale=scale, scale=scale,
train_batch_size=train_batch_size, batch_size=batch_size
validation_batch_size=validation_batch_size
) )
self.train_data = self.train_data.astype(input_dtype) self.data = self.data.astype(input_dtype)
self.validation_data = self.validation_data.astype(input_dtype)
if self.scale: if self.scale:
self.train_data *= self.scale_value self.data *= self.scale_value
self.validation_data *= self.scale_value
def get_batch(self, train_dataset=True):
if train_dataset: def get_batch(self):
n_samples = self.train_batch_size
data = self.train_data
label = self.train_labels
else:
n_samples = self.validation_batch_size
data = self.validation_data
label = self.validation_labels
# Shuffling samples # Shuffling samples
indexes = numpy.array(range(data.shape[0])) indexes = numpy.array(range(self.data.shape[0]))
numpy.random.shuffle(indexes) numpy.random.shuffle(indexes)
selected_data = data[indexes[0:n_samples], :, :, :] selected_data = self.data[indexes[0:self.batch_size], :, :, :]
selected_labels = label[indexes[0:n_samples]] selected_labels = self.labels[indexes[0:self.batch_size]]
return selected_data.astype("float32"), selected_labels return selected_data.astype("float32"), selected_labels
...@@ -83,23 +67,13 @@ class MemoryDataShuffler(BaseDataShuffler): ...@@ -83,23 +67,13 @@ class MemoryDataShuffler(BaseDataShuffler):
**Return** **Return**
""" """
data = numpy.zeros(shape=self.shape, dtype='float32')
if train_dataset: data_p = numpy.zeros(shape=self.shape, dtype='float32')
target_data = self.train_data labels_siamese = numpy.zeros(shape=self.shape[0], dtype='float32')
target_labels = self.train_labels
shape = self.train_shape
else:
target_data = self.validation_data
target_labels = self.validation_labels
shape = self.validation_shape
data = numpy.zeros(shape=shape, dtype='float32')
data_p = numpy.zeros(shape=shape, dtype='float32')
labels_siamese = numpy.zeros(shape=shape[0], dtype='float32')
genuine = True genuine = True
for i in range(shape[0]): for i in range(self.shape[0]):
data[i, ...], data_p[i, ...] = self.get_genuine_or_not(target_data, target_labels, genuine=genuine) data[i, ...], data_p[i, ...] = self.get_genuine_or_not(self.data, self.labels, genuine=genuine)
if zero_one_labels: if zero_one_labels:
labels_siamese[i] = not genuine labels_siamese[i] = not genuine
else: else:
...@@ -107,3 +81,62 @@ class MemoryDataShuffler(BaseDataShuffler): ...@@ -107,3 +81,62 @@ class MemoryDataShuffler(BaseDataShuffler):
genuine = not genuine genuine = not genuine
return data, data_p, labels_siamese return data, data_p, labels_siamese
def get_triplet(self, n_labels, n_triplets=1, is_target_set_train=True):
    """
    Get a batch of random triplets (anchor, positive, negative)

    The anchor and the positive sample come from the same label, the
    negative sample comes from a different one.

    **Parameters**
        n_labels: Number of labels to draw from; the two labels of each
            triplet are sampled from ``range(n_labels)``, so labels are
            expected to be encoded as ``0 .. n_labels - 1``
        n_triplets: Number of triplets in the returned batch
        is_target_set_train: Kept only for backward compatibility and
            ignored. The samples are always taken from ``self.data`` and
            ``self.labels``

    **Return**
        ``data_a, data_p, data_n, labels_a, labels_p, labels_n``: three
        float32 arrays of shape ``(n_triplets, w, h, c)`` followed by three
        float32 label vectors of length ``n_triplets``
    """

    def get_one_triplet(input_data, input_labels):
        # Getting a pair of distinct clients
        index = numpy.random.choice(n_labels, 2, replace=False)
        label_positive = index[0]
        label_negative = index[1]

        # Getting the indexes of the data from the positive client
        indexes = numpy.where(input_labels == label_positive)[0]
        numpy.random.shuffle(indexes)

        # Picking a positive pair (needs at least two samples of that client)
        data_anchor = input_data[indexes[0], :, :, :]
        data_positive = input_data[indexes[1], :, :, :]

        # Picking a negative sample
        indexes = numpy.where(input_labels == label_negative)[0]
        numpy.random.shuffle(indexes)
        data_negative = input_data[indexes[0], :, :, :]

        # Anchor and positive share the same label, hence it is returned twice
        return data_anchor, data_positive, data_negative, label_positive, label_positive, label_negative

    # There are no separate train/validation attributes to choose from on
    # this shuffler, so the single dataset is always the target
    target_data = self.data
    target_labels = self.labels

    c = target_data.shape[3]
    w = target_data.shape[1]
    h = target_data.shape[2]

    data_a = numpy.zeros(shape=(n_triplets, w, h, c), dtype='float32')
    data_p = numpy.zeros(shape=(n_triplets, w, h, c), dtype='float32')
    data_n = numpy.zeros(shape=(n_triplets, w, h, c), dtype='float32')
    labels_a = numpy.zeros(shape=n_triplets, dtype='float32')
    labels_p = numpy.zeros(shape=n_triplets, dtype='float32')
    labels_n = numpy.zeros(shape=n_triplets, dtype='float32')

    for i in range(n_triplets):
        data_a[i, :, :, :], data_p[i, :, :, :], data_n[i, :, :, :], \
            labels_a[i], labels_p[i], labels_n[i] = \
            get_one_triplet(target_data, target_labels)

    return data_a, data_p, data_n, labels_a, labels_p, labels_n
#!/usr/bin/env python
# vim: set fileencoding=utf-8 :
# @author: Tiago de Freitas Pereira <tiago.pereira@idiap.ch>
# @date: Wed 11 May 2016 09:39:36 CEST
import numpy
from .MemoryDataShuffler import MemoryDataShuffler
class MemoryPairDataShuffler(MemoryDataShuffler):
    """
    In-memory shuffler that serves random pairs and triplets of samples.

    The sampling relies on ``self.train_data``, ``self.train_labels``,
    ``self.validation_data``, ``self.validation_labels``, ``self.train_shape``,
    ``self.validation_shape`` and ``self.total_labels``.
    NOTE(review): these attributes are not set in this class; presumably the
    parent class provides them - confirm.
    """

    def __init__(self, data, labels, input_shape, perc_train=0.9, scale=True, train_batch_size=1, validation_batch_size=300):
        """
        The class provide some functionalities for shuffling data

        **Parameters**
            data: Samples, forwarded to the parent shuffler
            labels: Labels of the samples, forwarded to the parent shuffler
            input_shape: Forwarded to the parent shuffler
            perc_train: Forwarded to the parent shuffler
            scale: Forwarded to the parent shuffler
            train_batch_size: Forwarded to the parent shuffler multiplied by two
            validation_batch_size: Forwarded to the parent shuffler
        """

        # The train batch size is doubled before reaching the parent
        # NOTE(review): the reason for the doubling is not visible here
        super(MemoryPairDataShuffler, self).__init__(data, labels,
                                                     input_shape=input_shape,
                                                     perc_train=perc_train,
                                                     scale=scale,
                                                     train_batch_size=train_batch_size*2,
                                                     validation_batch_size=validation_batch_size)

    def get_pair(self, train_dataset=True, zero_one_labels=True):
        """
        Get a batch of random pairs of samples

        The pairs alternate between genuine (same client) and impostor
        (different clients), starting with a genuine one.

        **Parameters**
            train_dataset: If True the pairs come from the train set,
                otherwise from the validation set
            zero_one_labels: If True genuine pairs are labeled 0 and impostor
                pairs 1; otherwise genuine pairs are labeled -1 and impostor
                pairs +1

        **Return**
            ``data, data_p, labels_siamese``: two float32 arrays with the
            shape of the selected set and a float32 label vector
        """

        def get_genuine_or_not(input_data, input_labels, genuine=True):

            if genuine:
                # TODO: THIS KEY SELECTION NEEDS TO BE MORE EFFICIENT

                # Getting a client
                index = numpy.random.randint(self.total_labels)

                # Getting the indexes of the data from a particular client
                indexes = numpy.where(input_labels == index)[0]
                numpy.random.shuffle(indexes)

                # Picking a pair (needs at least two samples of that client)
                data = input_data[indexes[0], ...]
                data_p = input_data[indexes[1], ...]

            else:
                # Picking a pair from different clients
                index = numpy.random.choice(self.total_labels, 2, replace=False)

                # Getting the indexes of the two clients
                indexes = numpy.where(input_labels == index[0])[0]
                indexes_p = numpy.where(input_labels == index[1])[0]
                numpy.random.shuffle(indexes)
                numpy.random.shuffle(indexes_p)

                # Picking a pair
                data = input_data[indexes[0], ...]
                data_p = input_data[indexes_p[0], ...]

            return data, data_p

        if train_dataset:
            target_data = self.train_data
            target_labels = self.train_labels
            shape = self.train_shape
        else:
            target_data = self.validation_data
            target_labels = self.validation_labels
            shape = self.validation_shape

        data = numpy.zeros(shape=shape, dtype='float32')
        data_p = numpy.zeros(shape=shape, dtype='float32')
        labels_siamese = numpy.zeros(shape=shape[0], dtype='float32')

        genuine = True
        for i in range(shape[0]):
            data[i, ...], data_p[i, ...] = get_genuine_or_not(target_data, target_labels, genuine=genuine)
            if zero_one_labels:
                labels_siamese[i] = not genuine
            else:
                labels_siamese[i] = -1 if genuine else +1
            # Alternating genuine and impostor pairs
            genuine = not genuine

        return data, data_p, labels_siamese

    def get_triplet(self, n_labels, n_triplets=1, is_target_set_train=True):
        """
        Get a batch of random triplets (anchor, positive, negative)

        **Parameters**
            n_labels: Number of labels to draw from; the two labels of each
                triplet are sampled from ``range(n_labels)``
            n_triplets: Number of triplets in the returned batch
            is_target_set_train: Defining the target set to get the batch

        **Return**
            ``data_a, data_p, data_n, labels_a, labels_p, labels_n``: three
            float32 arrays of shape ``(n_triplets, w, h, c)`` followed by
            three float32 label vectors of length ``n_triplets``
        """

        def get_one_triplet(input_data, input_labels):

            # Getting a pair of clients
            index = numpy.random.choice(n_labels, 2, replace=False)
            label_positive = index[0]
            label_negative = index[1]

            # Getting the indexes of the data from a particular client
            indexes = numpy.where(input_labels == index[0])[0]
            numpy.random.shuffle(indexes)

            # Picking a positive pair
            data_anchor = input_data[indexes[0], :, :, :]
            data_positive = input_data[indexes[1], :, :, :]

            # Picking a negative sample
            indexes = numpy.where(input_labels == index[1])[0]
            numpy.random.shuffle(indexes)
            data_negative = input_data[indexes[0], :, :, :]

            # Anchor and positive share the same label, hence it is returned twice
            return data_anchor, data_positive, data_negative, label_positive, label_positive, label_negative

        if is_target_set_train:
            target_data = self.train_data
            target_labels = self.train_labels
        else:
            target_data = self.validation_data
            target_labels = self.validation_labels

        c = target_data.shape[3]
        w = target_data.shape[1]
        h = target_data.shape[2]

        data_a = numpy.zeros(shape=(n_triplets, w, h, c), dtype='float32')
        data_p = numpy.zeros(shape=(n_triplets, w, h, c), dtype='float32')
        data_n = numpy.zeros(shape=(n_triplets, w, h, c), dtype='float32')
        labels_a = numpy.zeros(shape=n_triplets, dtype='float32')
        labels_p = numpy.zeros(shape=n_triplets, dtype='float32')
        labels_n = numpy.zeros(shape=n_triplets, dtype='float32')

        for i in range(n_triplets):
            data_a[i, :, :, :], data_p[i, :, :, :], data_n[i, :, :, :], \
                labels_a[i], labels_p[i], labels_n[i] = \
                get_one_triplet(target_data, target_labels)

        return data_a, data_p, data_n, labels_a, labels_p, labels_n
...@@ -21,10 +21,8 @@ class TextDataShuffler(BaseDataShuffler): ...@@ -21,10 +21,8 @@ class TextDataShuffler(BaseDataShuffler):
def __init__(self, data, labels, def __init__(self, data, labels,
input_shape, input_shape,
input_dtype="float64", input_dtype="float64",
perc_train=0.9,
scale=True, scale=True,
train_batch_size=1, batch_size=1):
validation_batch_size=300):
""" """
Shuffler that deal with file list Shuffler that deal with file list
...@@ -48,10 +46,8 @@ class TextDataShuffler(BaseDataShuffler): ...@@ -48,10 +46,8 @@ class TextDataShuffler(BaseDataShuffler):
labels=labels, labels=labels,
input_shape=input_shape, input_shape=input_shape,
input_dtype=input_dtype, input_dtype=input_dtype,
perc_train=perc_train,
scale=scale, scale=scale,
train_batch_size=train_batch_size, batch_size=batch_size
validation_batch_size=validation_batch_size
) )
def load_from_file(self, file_name, shape): def load_from_file(self, file_name, shape):
...@@ -64,38 +60,27 @@ class TextDataShuffler(BaseDataShuffler): ...@@ -64,38 +60,27 @@ class TextDataShuffler(BaseDataShuffler):
return data return data
def get_batch(self, train_dataset=True): def get_batch(self):
if train_dataset:
batch_size = self.train_batch_size
shape = self.train_shape
files_names = self.train_data
label = self.train_labels
else:
batch_size = self.validation_batch_size
shape = self.validation_shape
files_names = self.validation_data
label = self.validation_labels
# Shuffling samples # Shuffling samples
indexes = numpy.array(range(files_names.shape[0])) indexes = numpy.array(range(self.data.shape[0]))
numpy.random.shuffle(indexes) numpy.random.shuffle(indexes)
selected_data = numpy.zeros(shape=shape) selected_data = numpy.zeros(shape=self.shape)
for i in range(batch_size): for i in range(self.batch_size):
file_name = files_names[indexes[i]] file_name = self.data[indexes[i]]
data = self.load_from_file(file_name, shape) data = self.load_from_file(file_name, self.shape)
selected_data[i, ...] = data selected_data[i, ...] = data
if self.scale: if self.scale:
selected_data[i, ...] *= self.scale_value selected_data[i, ...] *= self.scale_value
selected_labels = label[indexes[0:batch_size]] selected_labels = self.labels[indexes[0:self.batch_size]]
return selected_data.astype("float32"), selected_labels return selected_data.astype("float32"), selected_labels
def get_pair(self, train_dataset=True, zero_one_labels=True): def get_pair(self, zero_one_labels=True):
""" """
Get a random pair of samples Get a random pair of samples
...@@ -105,24 +90,15 @@ class TextDataShuffler(BaseDataShuffler): ...@@ -105,24 +90,15 @@ class TextDataShuffler(BaseDataShuffler):
**Return** **Return**
""" """
if train_dataset: data = numpy.zeros(shape=self.shape, dtype='float32')
target_data = self.train_data data_p = numpy.zeros(shape=self.shape, dtype='float32')
target_labels = self.train_labels labels_siamese = numpy.zeros(shape=self.shape[0], dtype='float32')
shape = self.train_shape
else:
target_data = self.validation_data
target_labels = self.validation_labels
shape = self.validation_shape
data = numpy.zeros(shape=shape, dtype='float32')
data_p = numpy.zeros(shape=shape, dtype='float32')
labels_siamese = numpy.zeros(shape=shape[0], dtype='float32')
genuine = True genuine = True
for i in range(shape[0]): for i in range(self.shape[0]):
file_name, file_name_p = self.get_genuine_or_not(target_data, target_labels, genuine=genuine) file_name, file_name_p = self.get_genuine_or_not(self.data, self.labels, genuine=genuine)
data[i, ...] = self.load_from_file(str(file_name), shape) data[i, ...] = self.load_from_file(str(file_name), self.shape)
data_p[i, ...] = self.load_from_file(str(file_name_p), shape) data_p[i, ...] = self.load_from_file(str(file_name_p), self.shape)
if zero_one_labels: if zero_one_labels:
labels_siamese[i] = not genuine labels_siamese[i] = not genuine
......
#!/usr/bin/env python
# vim: set fileencoding=utf-8 :
# @author: Tiago de Freitas Pereira <tiago.pereira@idiap.ch>
# @date: Wed 11 May 2016 09:39:36 CEST
import numpy
from .TextDataShuffler import TextDataShuffler
class TextPairDataShuffler(TextDataShuffler):
def __init__(self, data, labels, input_shape, perc_train=0.9, scale=True, train_batch_size=1, validation_batch_size=300):
    """
    The class provide some functionalities for shuffling data

    **Parameters**
        data: Forwarded to the parent shuffler
            (presumably a list of file names - confirm against the parent)
        labels: Forwarded to the parent shuffler
        input_shape: Forwarded to the parent shuffler
        perc_train: Forwarded to the parent shuffler
        scale: Forwarded to the parent shuffler
        train_batch_size: Forwarded to the parent shuffler multiplied by two
        validation_batch_size: Forwarded to the parent shuffler
    """

    # The train batch size is doubled before reaching the parent
    # NOTE(review): the reason for the doubling is not visible here
    super(TextPairDataShuffler, self).__init__(data, labels,
                                               input_shape=input_shape,
                                               perc_train=perc_train,
                                               scale=scale,
                                               train_batch_size=train_batch_size*2,
                                               validation_batch_size=validation_batch_size)
def get_pair(self, train_dataset=True, zero_one_labels=True):
"""
Get a random pair of samples
**Parameters**
is_target_set_train: Defining the target set to get the batch
**Return**
"""
def get_genuine_or_not(input_data, input_labels, genuine=True):
if genuine:
# TODO: THIS KEY SELECTION NEEDS TO BE MORE EFFICIENT
# Getting a client
index = numpy.random.randint(self.total_labels)
# Getting the indexes of the data from a particular client
indexes = numpy.where(input_labels == index)[0]
numpy.random.shuffle(indexes)
# Picking a pair
data = input_data[indexes[0]]
data_p = input_data[indexes[1]]
else:
# Picking a pair from different clients
index = numpy.random.choice(self.total_labels, 2, replace=False)
# Getting the indexes of the two clients
indexes = numpy.where(input_labels == index[0])[0]
indexes_p = numpy.where(input_labels == index[1])[0]
numpy.random.shuffle(indexes)
numpy.random.shuffle(indexes_p)
# Picking a pair
data = input_data[indexes[0]]
data_p = input_data[indexes_p[0]]
return data, data_p
if train_dataset:
target_data = self.train_data
target_labels = self.train_labels
shape = self.train_shape
else:
target_data = self.validation_data
target_labels = self.validation_labels
shape = self.validation_shape
data = numpy.zeros(shape=shape, dtype='float32')
data_p = numpy.zeros(shape=shape, dtype='float32')
labels_siamese = numpy.zeros(shape=shape[0], dtype='float32')
genuine = True
for i in range(shape[0]):
data[i, ...], data_p[i, ...] = get_genuine_or_not(target_data, target_labels, genuine=genuine)
if zero_one_labels:
labels_siamese[i] = not genuine
else:
labels_siamese[i] = -1 if genuine else +1
genuine = not genuine