Commit c2d24975 authored by Tiago Pereira's avatar Tiago Pereira
Browse files

Optimizing siamese batching

parent 5873254e
Pipeline #11393 canceled with stages
......@@ -55,35 +55,38 @@ class Siamese(Base):
self.data_ph_from_queue['right'] = self.data_ph['right']
self.label_ph_from_queue = self.label_ph
def get_genuine_or_not(self, input_data, input_labels, genuine=True):
    """
    Draw one random pair of samples.

    **Parameters**

      input_data: Array of samples, indexed along its first axis.
      input_labels: Label of each sample in ``input_data``.
      genuine: If True both samples come from the same label, otherwise
        they come from two different labels.

    **Return**

      A tuple ``(sample_l, sample_r)``.
    """
    if not genuine:
        # Impostor pair: draw two distinct positions in possible_labels and
        # overwrite them, in place, with the labels they point to.
        chosen = numpy.random.choice(len(self.possible_labels), 2, replace=False)
        chosen[0] = self.possible_labels[int(chosen[0])]
        chosen[1] = self.possible_labels[int(chosen[1])]

        # Sample indexes belonging to each of the two labels
        pool_l = numpy.where(input_labels == chosen[0])[0]
        pool_r = numpy.where(input_labels == chosen[1])[0]
        numpy.random.shuffle(pool_l)
        numpy.random.shuffle(pool_r)

        # First entry of each shuffled pool forms the pair
        return input_data[pool_l[0], ...], input_data[pool_r[0], ...]

    # Genuine pair: draw a single label ...
    position = numpy.random.randint(len(self.possible_labels))
    client = int(self.possible_labels[position])

    # ... and take the first two entries of its shuffled sample indexes
    pool = numpy.where(input_labels == client)[0]
    numpy.random.shuffle(pool)
    return input_data[pool[0], ...], input_data[pool[1], ...]
def get_genuine_or_not(self, input_data, input_labels):
    """
    Creates a generator that yields genuine and impostor pairs.

    Pairs alternate, starting with a genuine one, and one pair is yielded
    per sample in ``input_data`` (``input_data.shape[0]`` pairs in total).

    **Parameters**

      input_data: Array of samples, indexed along its first axis.
      input_labels: Label of each sample in ``input_data``.

    **Return**

      Yields tuples ``(left, right, label)`` where ``label`` is 0 for a
      genuine pair (same class) and 1 for an impostor pair (different
      classes).
    """
    # Grouping (and shuffling) the sample indexes of every label once, so the
    # loop below does not need to scan ``input_labels`` for each pair.
    indexes_per_labels = dict()
    for label in self.possible_labels:
        indexes_per_labels[label] = numpy.where(input_labels == label)[0]
        numpy.random.shuffle(indexes_per_labels[label])

    def pick(label):
        # Returns one uniformly drawn sample of the given label
        pool = indexes_per_labels[label]
        return input_data[pool[numpy.random.randint(len(pool))]]

    genuine = True
    for _ in range(input_data.shape[0]):
        if genuine:
            # Selecting the class. The random draw is a *position* in
            # possible_labels; it has to be mapped to the label itself since
            # indexes_per_labels is keyed by label value.
            position = numpy.random.randint(len(self.possible_labels))
            label = self.possible_labels[position]

            # Now selecting the samples for the pair.
            # NOTE: the two draws are independent, so left and right can be
            # the very same sample.
            left = pick(label)
            right = pick(label)
            yield left, right, 0
        else:
            # Selecting the 2 classes (distinct positions, mapped to labels)
            positions = numpy.random.choice(len(self.possible_labels), 2, replace=False)

            # Now selecting the samples for the pair
            left = pick(self.possible_labels[positions[0]])
            right = pick(self.possible_labels[positions[1]])
            yield left, right, 1

        genuine = not genuine
......@@ -83,29 +83,25 @@ class SiameseMemory(Siamese, Memory):
**Return**
"""
#shape = [self.batch_size] + list(self.input_shape[1:])
#sample_l = numpy.zeros(shape=shape, dtype=self.input_dtype)
#sample_r = numpy.zeros(shape=shape, dtype=self.input_dtype)
#labels_siamese = numpy.zeros(shape=shape[0], dtype=self.input_dtype)
genuine = True
#genuine = True
pairs_generator = self.get_genuine_or_not(self.data, self.labels)
for i in range(self.data.shape[0]):
left, right = self.get_genuine_or_not(self.data, self.labels, genuine=genuine)
if zero_one_labels:
label = not genuine
else:
label = -1 if genuine else +1
genuine = not genuine
#left, right = self.get_genuine_or_not(self.data, self.labels, genuine=genuine)
#if zero_one_labels:
# label = not genuine
#else:
# label = -1 if genuine else +1
#genuine = not genuine
left, right, label = pairs_generator.next()
# Applying the data augmentation
if self.data_augmentation is not None:
for i in range(left.shape[0]):
d = self.bob2skimage(self.data_augmentation(self.skimage2bob(left)))
left = d
d = self.bob2skimage(self.data_augmentation(self.skimage2bob(left)))
left = d
d = self.bob2skimage(self.data_augmentation(self.skimage2bob(right)))
right = d
d = self.bob2skimage(self.data_augmentation(self.skimage2bob(right)))
right = d
left = self.normalize_sample(left)
right = self.normalize_sample(right)
......
......@@ -66,6 +66,7 @@ def test_siamesememory_shuffler():
batch_size=16)
batch = data_shuffler.get_batch()
assert len(batch) == 3
assert batch[0].shape == (16, 28, 28, 1)
assert batch[1].shape == (16, 28, 28, 1)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment