Commit f7760dd6 authored by Tiago Pereira's avatar Tiago Pereira
Browse files

Optimizing triplet batching

parent 3130fbc4
Pipeline #11398 canceled with stages
......@@ -15,6 +15,8 @@ class Triplet(Base):
The first element is the batch for the anchor, the second one is the batch for the positive class, w.r.t the
anchor, and the last one is the batch for the negative class , w.r.t the anchor.
Here, an epoch is not all possible triplets. An epoch is when you pass thought all the samples at least once.
"""
......@@ -53,23 +55,48 @@ class Triplet(Base):
self.data_ph_from_queue['positive'] = self.data_ph['positive']
self.data_ph_from_queue['negative'] = self.data_ph['negative']
def get_one_triplet(self, input_data, input_labels):
def get_triplets(self, input_data, input_labels):
# Shuffling all the indexes
indexes_per_labels = dict()
for l in self.possible_labels:
indexes_per_labels[l] = numpy.where(input_labels == l)[0]
numpy.random.shuffle(indexes_per_labels[l])
# searching for random triplets
offset_class = 0
for i in range(input_data.shape[0]):
anchor = input_data[indexes_per_labels[offset_class][numpy.random.randint(len(indexes_per_labels[offset_class]))], ...]
positive = input_data[indexes_per_labels[offset_class][numpy.random.randint(len(indexes_per_labels[offset_class]))], ...]
# Changing the class
offset_class += 1
if offset_class == len(self.possible_labels):
offset_class = 0
negative = input_data[indexes_per_labels[offset_class][numpy.random.randint(len(indexes_per_labels[offset_class]))], ...]
yield anchor, positive, negative
# Getting a pair of clients
index = numpy.random.choice(len(self.possible_labels), 2, replace=False)
index[0] = self.possible_labels[index[0]]
index[1] = self.possible_labels[index[1]]
#index = numpy.random.choice(len(self.possible_labels), 2, replace=False)
#index[0] = self.possible_labels[index[0]]
#index[1] = self.possible_labels[index[1]]
# Getting the indexes of the data from a particular client
indexes = numpy.where(input_labels == index[0])[0]
numpy.random.shuffle(indexes)
#indexes = numpy.where(input_labels == index[0])[0]
#numpy.random.shuffle(indexes)
# Picking a positive pair
data_anchor = input_data[indexes[0], ...]
data_positive = input_data[indexes[1], ...]
#data_anchor = input_data[indexes[0], ...]
#data_positive = input_data[indexes[1], ...]
# Picking a negative sample
indexes = numpy.where(input_labels == index[1])[0]
numpy.random.shuffle(indexes)
data_negative = input_data[indexes[0], ...]
#indexes = numpy.where(input_labels == index[1])[0]
#numpy.random.shuffle(indexes)
#data_negative = input_data[indexes[0], ...]
return data_anchor, data_positive, data_negative
#return data_anchor, data_positive, data_negative
......@@ -74,7 +74,7 @@ class TripletMemory(Triplet, Memory):
self.data = self.data.astype(input_dtype)
def get_batch(self):
def _fetch_batch(self):
"""
Get a random triplet
......@@ -84,30 +84,35 @@ class TripletMemory(Triplet, Memory):
**Return**
"""
shape = [self.batch_size] + list(self.input_shape[1:])
#shape = [self.batch_size] + list(self.input_shape[1:])
sample_a = numpy.zeros(shape=shape, dtype=self.input_dtype)
sample_p = numpy.zeros(shape=shape, dtype=self.input_dtype)
sample_n = numpy.zeros(shape=shape, dtype=self.input_dtype)
#sample_a = numpy.zeros(shape=shape, dtype=self.input_dtype)
#sample_p = numpy.zeros(shape=shape, dtype=self.input_dtype)
#sample_n = numpy.zeros(shape=shape, dtype=self.input_dtype)
for i in range(shape[0]):
sample_a[i, ...], sample_p[i, ...], sample_n[i, ...] = self.get_one_triplet(self.data, self.labels)
#for i in range(shape[0]):
# sample_a[i, ...], sample_p[i, ...], sample_n[i, ...] = self.get_one_triplet(self.data, self.labels)
# Applying the data augmentation
if self.data_augmentation is not None:
for i in range(sample_a.shape[0]):
d = self.bob2skimage(self.data_augmentation(self.skimage2bob(sample_a[i, ...])))
sample_a[i, ...] = d
triplets = self.get_triplets(self.data, self.labels)
d = self.bob2skimage(self.data_augmentation(self.skimage2bob(sample_p[i, ...])))
sample_p[i, ...] = d
for i in range(self.data.shape[0]):
d = self.bob2skimage(self.data_augmentation(self.skimage2bob(sample_n[i, ...])))
sample_n[i, ...] = d
anchor, positive, negative = triplets.next()
# Scaling
sample_a = self.normalize_sample(sample_a)
sample_p = self.normalize_sample(sample_p)
sample_n = self.normalize_sample(sample_n)
# Applying the data augmentation
if self.data_augmentation is not None:
d = self.bob2skimage(self.data_augmentation(self.skimage2bob(anchor)))
anchor = d
return [sample_a.astype(self.input_dtype), sample_p.astype(self.input_dtype), sample_n.astype(self.input_dtype)]
d = self.bob2skimage(self.data_augmentation(self.skimage2bob(positive)))
positive = d
d = self.bob2skimage(self.data_augmentation(self.skimage2bob(negative)))
negative = d
# Scaling
anchor = self.normalize_sample(anchor).astype(self.input_dtype)
positive = self.normalize_sample(positive).astype(self.input_dtype)
negative = self.normalize_sample(negative).astype(self.input_dtype)
yield anchor, positive, negative
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment