Fixed bug in contrastive loss

parent 57d0dd6b
@@ -57,13 +57,26 @@ class TextDataShuffler(BaseDataShuffler):
         if len(d.shape) == 2:
             data = numpy.zeros(shape=(d.shape[0], d.shape[1], 1))
             data[:, :, 0] = d
+            data = self.rescale(data)
         else:
-            data = d
-        data = self.rescale(data)
+            data = self.bob2skimage(d)
         return data

+    def bob2skimage(self, bob_image):
+        """
+        Convert bob color image to the scikit-image format
+        """
+        skimage = numpy.zeros(shape=(bob_image.shape[1], bob_image.shape[2], 3))
+        skimage[:,:,0] = bob_image[0,:,:] #Copying red
+        skimage[:,:,1] = bob_image[1,:,:] #Copying green
+        skimage[:,:,2] = bob_image[2,:,:] #Copying blue
+        return skimage
+
     def get_batch(self):
         # Shuffling samples
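Note: Bob stores color images with the channel axis first (3 x height x width), while scikit-image expects channels last (height x width x 3), which is what bob2skimage builds plane by plane. A standalone NumPy sketch of the same reordering (hypothetical array, not part of the commit):

    import numpy

    bob_image = numpy.random.rand(3, 28, 28)         # channels-first color image (C, H, W)
    skimage = numpy.transpose(bob_image, (1, 2, 0))  # channels-last layout (H, W, C)
    assert skimage.shape == (28, 28, 3)
    assert numpy.array_equal(skimage[:, :, 0], bob_image[0, :, :])  # red plane preserved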
@@ -36,11 +36,12 @@ class ContrastiveLoss(BaseLoss):
         one = tf.constant(1.0)
         d = compute_euclidean_distance(left_feature, right_feature)

-        between_class = tf.exp(tf.mul(one - label, tf.square(d))) # (1-Y)*(d^2)
+        between_class = tf.mul(one - label, tf.square(d)) # (1-Y)*(d^2)

         max_part = tf.square(tf.maximum(self.contrastive_margin - d, 0))
         within_class = tf.mul(label, max_part) # (Y) * max((margin - d)^2, 0)

-        loss = 0.5 * tf.reduce_mean(within_class + between_class)
+        loss = 0.5 * (within_class + between_class)

-        return loss, tf.reduce_mean(between_class), tf.reduce_mean(within_class)
+        return tf.reduce_mean(loss), tf.reduce_mean(between_class), tf.reduce_mean(within_class)
+        #return loss, between_class, within_class, label, left_feature, right_feature, d
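Note: the corrected between_class term drops the spurious tf.exp, so the computation is the standard contrastive loss L = 0.5 * mean((1 - Y) * d^2 + Y * max(margin - d, 0)^2). A standalone NumPy sketch mirroring the code above (illustrative only, not part of the commit):

    import numpy

    def contrastive_loss_np(labels, left, right, margin=3.0):
        # Euclidean distance between each pair of embeddings
        d = numpy.linalg.norm(left - right, axis=1)
        between_class = (1.0 - labels) * numpy.square(d)                    # (1-Y) * d^2
        within_class = labels * numpy.square(numpy.maximum(margin - d, 0))  # Y * max(margin - d, 0)^2
        loss = 0.5 * (within_class + between_class)
        return loss.mean(), between_class.mean(), within_class.mean()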
@@ -112,8 +112,8 @@ def main():
     #                                            batch_size=VALIDATION_BATCH_SIZE)

     # Preparing the architecture
-    n_classes = len(train_data_shuffler.possible_labels)
-    #n_classes = 200
+    #n_classes = len(train_data_shuffler.possible_labels)
+    n_classes = 50

     cnn = True
     if cnn:
@@ -125,14 +125,17 @@
         #architecture = LenetDropout(default_feature_layer="fc2", n_classes=n_classes, conv1_output=4, conv2_output=8, use_gpu=USE_GPU)

-        loss = ContrastiveLoss()
-        #optimizer = tf.train.GradientDescentOptimizer(0.0001)
+        loss = ContrastiveLoss(contrastive_margin=3.)
+        optimizer = tf.train.GradientDescentOptimizer(0.00001)

         trainer = SiameseTrainer(architecture=architecture,
                                  loss=loss,
                                  iterations=ITERATIONS,
                                  snapshot=VALIDATION_TEST,
-                                 )
+                                 optimizer=optimizer)

         trainer.train(train_data_shuffler, validation_data_shuffler)
+        #trainer.train(train_data_shuffler)
     else:
         mlp = MLP(n_classes, hidden_layers=[15, 20])
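Note: with contrastive_margin=3., a pair whose label activates the max(margin - d, 0)^2 term is only penalized while its embedding distance d stays below the margin. A quick standalone check with illustrative distances (not from the commit):

    margin = 3.0
    for d in (0.5, 1.5, 2.5, 3.5):
        penalty = max(margin - d, 0) ** 2  # the margin term of the contrastive loss
        print("d = {0:.1f} -> penalty = {1:.2f}".format(d, penalty))
    # prints 6.25, 2.25, 0.25 and 0.00: distances beyond the margin contribute nothing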
@@ -64,7 +64,7 @@ class SiameseTrainer(Trainer):
        """
        Injecting data in the placeholder queue
        """
-       #for i in range(self.iterations+5):
+       # for i in range(self.iterations+5):
        while not thread_pool.should_stop():
            batch_left, batch_right, labels = train_data_shuffler.get_pair()
@@ -112,7 +112,7 @@ class SiameseTrainer(Trainer):
        train_left_graph = self.architecture.compute_graph(train_left_feature_batch)
        train_right_graph = self.architecture.compute_graph(train_right_label_batch)

-       loss_train, within_class, between_class = self.loss(train_labels_batch,
+       loss_train, between_class, within_class = self.loss(train_labels_batch,
                                                            train_left_graph,
                                                            train_right_graph)
@@ -154,21 +154,25 @@
        for step in range(self.iterations):
-           _, l, lr, summary = session.run([optimizer, loss_train, learning_rate, merged])
+           _, l, lr, summary = session.run(
+               [optimizer, loss_train, learning_rate, merged])
            #_, l, lr,b,w, summary = session.run([optimizer, loss_train, learning_rate,between_class,within_class, merged])
+           #_, l, lr= session.run([optimizer, loss_train, learning_rate])

            train_writer.add_summary(summary, step)
+           print str(step)
            #print str(step) + " loss: {0}, bc: {1}, wc: {2}".format(l, b, w)
            #print str(step) + " loss: {0}".format(l)
+           sys.stdout.flush()
            #import ipdb; ipdb.set_trace();

            if validation_data_shuffler is not None and step % self.snapshot == 0:
+               print str(step)
+               sys.stdout.flush()
                summary = session.run(merged_validation)
                train_writer.add_summary(summary, step)

                summary = analizer()
                train_writer.add_summary(summary, step)
-               print str(step)
-               sys.stdout.flush()

        print("#######DONE##########")
        self.architecture.save(hdf5)
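Note: the trainer now unpacks the three values returned by the loss in their actual order, (loss, between_class, within_class); previously the last two names were swapped, so the between- and within-class statistics were exchanged. A minimal stand-in illustrating the corrected unpacking (dummy values, not from the commit):

    def dummy_loss_terms():
        # Stand-in for self.loss(labels, left_graph, right_graph): returns (loss, between_class, within_class)
        return 0.45, 0.60, 0.30

    loss_train, between_class, within_class = dummy_loss_terms()
    assert (between_class, within_class) == (0.60, 0.30)  # names now match the returned order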