Skip to content
Snippets Groups Projects
Commit 5873254e authored by Tiago Pereira's avatar Tiago Pereira
Browse files

Siamese nets with generators

parent 617bd939
Branches
Tags
No related merge requests found
Pipeline #
......@@ -75,8 +75,6 @@ class Memory(Base):
indexes = numpy.array(range(self.data.shape[0]))
numpy.random.shuffle(indexes)
#import ipdb; ipdb.set_trace();
for i in range(len(indexes)):
sample = self.data[indexes[i], ...]
......@@ -91,32 +89,3 @@ class Memory(Base):
sample = self.normalize_sample(sample)
yield [sample, label]
def get_batch(self):
"""
Shuffle the Memory dataset and get a random batch.
** Returns **
data:
Selected samples
labels:
Correspondent labels
"""
if self.generator is None:
self.generator = self._fetch_batch()
holder = []
try:
for i in range(self.batch_size):
data = self.generator.next()
holder.append(data)
if len(holder) == self.batch_size:
return self._aggregate_batch(holder, False)
except StopIteration:
self.generator = None
self.epoch += 1
return self._aggregate_batch(holder, False)
......@@ -14,7 +14,10 @@ class Siamese(Base):
Basically the py:meth:`get_batch` method provides you 3 elements in the returned list.
The first two are the batch data, and the last is the label. Either `0` for samples from the same class or `1`
for samples from different class.
for samples from different class.
Here, an epoch is not all possible pairs. An epoch is when you pass thought all the samples at least once.
"""
......@@ -52,7 +55,6 @@ class Siamese(Base):
self.data_ph_from_queue['right'] = self.data_ph['right']
self.label_ph_from_queue = self.label_ph
def get_genuine_or_not(self, input_data, input_labels, genuine=True):
if genuine:
......
......@@ -73,7 +73,7 @@ class SiameseMemory(Siamese, Memory):
numpy.random.seed(seed)
self.data = self.data.astype(input_dtype)
def get_batch(self, zero_one_labels=True):
def _fetch_batch(self, zero_one_labels=True):
"""
Get a random pair of samples
......@@ -83,31 +83,33 @@ class SiameseMemory(Siamese, Memory):
**Return**
"""
shape = [self.batch_size] + list(self.input_shape[1:])
#shape = [self.batch_size] + list(self.input_shape[1:])
sample_l = numpy.zeros(shape=shape, dtype=self.input_dtype)
sample_r = numpy.zeros(shape=shape, dtype=self.input_dtype)
labels_siamese = numpy.zeros(shape=shape[0], dtype=self.input_dtype)
#sample_l = numpy.zeros(shape=shape, dtype=self.input_dtype)
#sample_r = numpy.zeros(shape=shape, dtype=self.input_dtype)
#labels_siamese = numpy.zeros(shape=shape[0], dtype=self.input_dtype)
genuine = True
for i in range(shape[0]):
sample_l[i, ...], sample_r[i, ...] = self.get_genuine_or_not(self.data, self.labels, genuine=genuine)
for i in range(self.data.shape[0]):
left, right = self.get_genuine_or_not(self.data, self.labels, genuine=genuine)
if zero_one_labels:
labels_siamese[i] = not genuine
label = not genuine
else:
labels_siamese[i] = -1 if genuine else +1
label = -1 if genuine else +1
genuine = not genuine
# Applying the data augmentation
if self.data_augmentation is not None:
for i in range(sample_l.shape[0]):
d = self.bob2skimage(self.data_augmentation(self.skimage2bob(sample_l[i, ...])))
sample_l[i, ...] = d
# Applying the data augmentation
if self.data_augmentation is not None:
for i in range(left.shape[0]):
d = self.bob2skimage(self.data_augmentation(self.skimage2bob(left)))
left = d
d = self.bob2skimage(self.data_augmentation(self.skimage2bob(sample_r[i, ...])))
sample_r[i, ...] = d
d = self.bob2skimage(self.data_augmentation(self.skimage2bob(right)))
right = d
sample_l = self.normalize_sample(sample_l)
sample_r = self.normalize_sample(sample_r)
left = self.normalize_sample(left)
right = self.normalize_sample(right)
return [sample_l.astype(self.input_dtype), sample_r.astype(self.input_dtype), labels_siamese]
yield left.astype(self.input_dtype), right.astype(self.input_dtype), label
#return [sample_l.astype(self.input_dtype), sample_r.astype(self.input_dtype), labels_siamese]
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment