Commit b08455ce authored by Tiago Pereira's avatar Tiago Pereira
Browse files

Optimizing triplet batching #28

parent f7760dd6
Pipeline #11399 failed with stages
......@@ -65,11 +65,12 @@ class Triplet(Base):
# searching for random triplets
offset_class = 0
for i in range(input_data.shape[0]):
anchor = input_data[indexes_per_labels[offset_class][numpy.random.randint(len(indexes_per_labels[offset_class]))], ...]
anchor = input_data[indexes_per_labels[self.possible_labels[offset_class]][numpy.random.randint(len(indexes_per_labels[self.possible_labels[offset_class]]))], ...]
positive = input_data[indexes_per_labels[offset_class][numpy.random.randint(len(indexes_per_labels[offset_class]))], ...]
positive = input_data[indexes_per_labels[self.possible_labels[offset_class]][numpy.random.randint(len(indexes_per_labels[self.possible_labels[offset_class]]))], ...]
# Changing the class
offset_class += 1
......@@ -77,26 +78,27 @@ class Triplet(Base):
if offset_class == len(self.possible_labels):
offset_class = 0
negative = input_data[indexes_per_labels[offset_class][numpy.random.randint(len(indexes_per_labels[offset_class]))], ...]
negative = input_data[indexes_per_labels[self.possible_labels[offset_class]][numpy.random.randint(len(indexes_per_labels[self.possible_labels[offset_class]]))], ...]
yield anchor, positive, negative
def get_one_triplet(self, input_data, input_labels):
# Getting a pair of clients
#index = numpy.random.choice(len(self.possible_labels), 2, replace=False)
#index[0] = self.possible_labels[index[0]]
#index[1] = self.possible_labels[index[1]]
index = numpy.random.choice(len(self.possible_labels), 2, replace=False)
index[0] = self.possible_labels[index[0]]
index[1] = self.possible_labels[index[1]]
# Getting the indexes of the data from a particular client
#indexes = numpy.where(input_labels == index[0])[0]
#numpy.random.shuffle(indexes)
indexes = numpy.where(input_labels == index[0])[0]
numpy.random.shuffle(indexes)
# Picking a positive pair
#data_anchor = input_data[indexes[0], ...]
#data_positive = input_data[indexes[1], ...]
data_anchor = input_data[indexes[0], ...]
data_positive = input_data[indexes[1], ...]
# Picking a negative sample
#indexes = numpy.where(input_labels == index[1])[0]
#numpy.random.shuffle(indexes)
#data_negative = input_data[indexes[0], ...]
indexes = numpy.where(input_labels == index[1])[0]
numpy.random.shuffle(indexes)
data_negative = input_data[indexes[0], ...]
#return data_anchor, data_positive, data_negative
return data_anchor, data_positive, data_negative
......@@ -86,7 +86,7 @@ class TripletDisk(Triplet, Disk):
# TODO: very bad solution to deal with bob.shape images an tf shape images
self.bob_shape = tuple([input_shape[3]] + list(input_shape[1:3]))
def get_batch(self):
def _fetch_batch(self):
"""
Get a random pair of samples
......@@ -96,16 +96,30 @@ class TripletDisk(Triplet, Disk):
**Return**
"""
shape = [self.batch_size] + list(self.input_shape[1:])
triplets = self.get_triplets(self.data, self.labels)
sample_a = numpy.zeros(shape=shape, dtype=self.input_dtype)
sample_p = numpy.zeros(shape=shape, dtype=self.input_dtype)
sample_n = numpy.zeros(shape=shape, dtype=self.input_dtype)
for i in range(self.data.shape[0]):
for i in range(shape[0]):
file_name_a, file_name_p, file_name_n = self.get_one_triplet(self.data, self.labels)
sample_a[i, ...] = self.normalize_sample(self.load_from_file(str(file_name_a)))
sample_p[i, ...] = self.normalize_sample(self.load_from_file(str(file_name_p)))
sample_n[i, ...] = self.normalize_sample(self.load_from_file(str(file_name_n)))
anchor_filename, positive_filename, negative_filename = triplets.next()
return [sample_a, sample_p, sample_n]
anchor = self.load_from_file(str(anchor_filename))
positive = self.load_from_file(str(positive_filename))
negative = self.load_from_file(str(negative_filename))
# Applying the data augmentation
if self.data_augmentation is not None:
d = self.bob2skimage(self.data_augmentation(self.skimage2bob(anchor)))
anchor = d
d = self.bob2skimage(self.data_augmentation(self.skimage2bob(positive)))
positive = d
d = self.bob2skimage(self.data_augmentation(self.skimage2bob(negative)))
negative = d
# Scaling
anchor = self.normalize_sample(anchor).astype(self.input_dtype)
positive = self.normalize_sample(positive).astype(self.input_dtype)
negative = self.normalize_sample(negative).astype(self.input_dtype)
yield anchor, positive, negative
......@@ -84,15 +84,6 @@ class TripletMemory(Triplet, Memory):
**Return**
"""
#shape = [self.batch_size] + list(self.input_shape[1:])
#sample_a = numpy.zeros(shape=shape, dtype=self.input_dtype)
#sample_p = numpy.zeros(shape=shape, dtype=self.input_dtype)
#sample_n = numpy.zeros(shape=shape, dtype=self.input_dtype)
#for i in range(shape[0]):
# sample_a[i, ...], sample_p[i, ...], sample_n[i, ...] = self.get_one_triplet(self.data, self.labels)
triplets = self.get_triplets(self.data, self.labels)
for i in range(self.data.shape[0]):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment