Commit 3130fbc4 authored by Tiago Pereira's avatar Tiago Pereira
Browse files

Optmize the siamease disk

parent 84ae5441
Pipeline #11397 canceled with stages
......@@ -66,10 +66,11 @@ class Siamese(Base):
indexes_per_labels[l] = numpy.where(input_labels == l)[0]
numpy.random.shuffle(indexes_per_labels[l])
left_possible_indexes = numpy.random.choice(len(self.possible_labels), input_data.shape[0], replace=True)
right_possible_indexes = numpy.random.choice(len(self.possible_labels), input_data.shape[0], replace=True)
left_possible_indexes = numpy.random.choice(self.possible_labels, input_data.shape[0], replace=True)
right_possible_indexes = numpy.random.choice(self.possible_labels, input_data.shape[0], replace=True)
genuine = True
for i in range(input_data.shape[0]):
if genuine:
......
......@@ -82,7 +82,7 @@ class SiameseDisk(Siamese, Disk):
# TODO: very bad solution to deal with bob.shape images an tf shape images
self.bob_shape = tuple([input_shape[3]] + list(input_shape[1:3]))
def get_batch(self):
def _fetch_batch(self, zero_one_labels=True):
"""
Get a random pair of samples
......@@ -91,19 +91,23 @@ class SiameseDisk(Siamese, Disk):
**Return**
"""
shape = [self.batch_size] + list(self.input_shape[1:])
sample_l = numpy.zeros(shape=shape, dtype=self.input_dtype)
sample_r = numpy.zeros(shape=shape, dtype=self.input_dtype)
labels_siamese = numpy.zeros(shape=shape[0], dtype=self.input_dtype)
pairs_generator = self.get_genuine_or_not(self.data, self.labels)
for i in range(self.data.shape[0]):
genuine = True
for i in range(shape[0]):
file_name, file_name_p = self.get_genuine_or_not(self.data, self.labels, genuine=genuine)
sample_l[i, ...] = self.normalize_sample(self.load_from_file(str(file_name)))
sample_r[i, ...] = self.normalize_sample(self.load_from_file(str(file_name_p)))
left_filename, right_filename, label = pairs_generator.next()
left = self.load_from_file(left_filename)
right = self.load_from_file(right_filename)
labels_siamese[i] = not genuine
genuine = not genuine
# Applying the data augmentation
if self.data_augmentation is not None:
d = self.bob2skimage(self.data_augmentation(self.skimage2bob(left)))
left = d
return sample_l, sample_r, labels_siamese
d = self.bob2skimage(self.data_augmentation(self.skimage2bob(right)))
right = d
left = self.normalize_sample(left)
right = self.normalize_sample(right)
yield left.astype(self.input_dtype), right.astype(self.input_dtype), label
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment