Commit 2ab6b68d authored by Tiago de Freitas Pereira's avatar Tiago de Freitas Pereira
Browse files

Fixed issue #21

parent 5f82ac33
Pipeline #6176 failed with stages
in 11 minutes and 17 seconds
......@@ -23,6 +23,12 @@ class Triplet(Base):
self.data2_placeholder = None
self.data3_placeholder = None
def set_placeholders(self, data, data2, data3):
self.data_placeholder = data
self.data2_placeholder = data2
self.data3_placeholder = data3
def get_placeholders(self, name=""):
"""
Returns a place holder with the size of your batch
......
......@@ -6,12 +6,13 @@
import numpy
import bob.io.base
import os
from bob.learn.tensorflow.datashuffler import Memory, ImageAugmentation
from bob.learn.tensorflow.loss import BaseLoss
from bob.learn.tensorflow.trainers import Trainer, constant
from bob.learn.tensorflow.datashuffler import Memory, ImageAugmentation, TripletMemory, SiameseMemory
from bob.learn.tensorflow.loss import BaseLoss, TripletLoss, ContrastiveLoss
from bob.learn.tensorflow.trainers import Trainer, constant, TripletTrainer, SiameseTrainer
from bob.learn.tensorflow.utils import load_mnist
from bob.learn.tensorflow.network import SequenceNetwork
from bob.learn.tensorflow.layers import Conv2D, FullyConnected
from test_cnn import dummy_experiment
import tensorflow as tf
import shutil
......@@ -99,7 +100,7 @@ def test_cnn_pretrained():
scratch = scratch_network()
trainer = Trainer(architecture=scratch,
loss=loss,
iterations=iterations+200,
iterations=iterations + 200,
analizer=None,
prefetch=False,
learning_rate=None,
......@@ -118,3 +119,144 @@ def test_cnn_pretrained():
del loss
del trainer
def test_triplet_cnn_pretrained():
train_data, train_labels, validation_data, validation_labels = load_mnist()
train_data = numpy.reshape(train_data, (train_data.shape[0], 28, 28, 1))
# Creating datashufflers
data_augmentation = ImageAugmentation()
train_data_shuffler = TripletMemory(train_data, train_labels,
input_shape=[28, 28, 1],
batch_size=batch_size,
data_augmentation=data_augmentation)
validation_data = numpy.reshape(validation_data, (validation_data.shape[0], 28, 28, 1))
validation_data_shuffler = TripletMemory(validation_data, validation_labels,
input_shape=[28, 28, 1],
batch_size=validation_batch_size)
directory = "./temp/cnn"
directory2 = "./temp/cnn2"
# Creating a random network
scratch = scratch_network()
# Loss for the softmax
loss = TripletLoss(margin=4.)
# One graph trainer
trainer = TripletTrainer(architecture=scratch,
loss=loss,
iterations=iterations,
analizer=None,
prefetch=False,
learning_rate=constant(0.05, name="regular_lr"),
optimizer=tf.train.AdamOptimizer(name="adam_pretrained_model"),
temp_dir=directory
)
trainer.train(train_data_shuffler)
# Testing
eer = dummy_experiment(validation_data_shuffler, scratch)
# The result is not so good
assert eer < 0.25
del scratch
del loss
del trainer
# Training the network using a pre trained model
loss = TripletLoss(margin=4.)
scratch = scratch_network()
trainer = TripletTrainer(architecture=scratch,
loss=loss,
iterations=iterations + 200,
analizer=None,
prefetch=False,
learning_rate=None,
temp_dir=directory2,
model_from_file=os.path.join(directory, "model.ckp")
)
trainer.train(train_data_shuffler)
eer = dummy_experiment(validation_data_shuffler, scratch)
# Now it is better
assert eer < 0.15
shutil.rmtree(directory)
shutil.rmtree(directory2)
del scratch
del loss
del trainer
def test_siamese_cnn_pretrained():
train_data, train_labels, validation_data, validation_labels = load_mnist()
train_data = numpy.reshape(train_data, (train_data.shape[0], 28, 28, 1))
# Creating datashufflers
data_augmentation = ImageAugmentation()
train_data_shuffler = SiameseMemory(train_data, train_labels,
input_shape=[28, 28, 1],
batch_size=batch_size,
data_augmentation=data_augmentation)
validation_data = numpy.reshape(validation_data, (validation_data.shape[0], 28, 28, 1))
validation_data_shuffler = SiameseMemory(validation_data, validation_labels,
input_shape=[28, 28, 1],
batch_size=validation_batch_size)
directory = "./temp/cnn"
directory2 = "./temp/cnn2"
# Creating a random network
scratch = scratch_network()
# Loss for the softmax
loss = ContrastiveLoss(contrastive_margin=4.)
# One graph trainer
trainer = SiameseTrainer(architecture=scratch,
loss=loss,
iterations=iterations,
analizer=None,
prefetch=False,
learning_rate=constant(0.05, name="regular_lr"),
optimizer=tf.train.AdamOptimizer(name="adam_pretrained_model"),
temp_dir=directory
)
trainer.train(train_data_shuffler)
# Testing
eer = dummy_experiment(validation_data_shuffler, scratch)
# The result is not so good
assert eer < 0.28
del scratch
del loss
del trainer
# Training the network using a pre trained model
loss = ContrastiveLoss(contrastive_margin=4.)
scratch = scratch_network()
trainer = SiameseTrainer(architecture=scratch,
loss=loss,
iterations=iterations + 1000,
analizer=None,
prefetch=False,
learning_rate=None,
temp_dir=directory2,
model_from_file=os.path.join(directory, "model.ckp")
)
trainer.train(train_data_shuffler)
eer = dummy_experiment(validation_data_shuffler, scratch)
# Now it is better
assert eer < 0.25
shutil.rmtree(directory)
shutil.rmtree(directory2)
del scratch
del loss
del trainer
......@@ -151,6 +151,46 @@ class SiameseTrainer(Trainer):
tf.get_collection("validation_placeholder_data2")[0],
tf.get_collection("validation_placeholder_label")[0])
def bootstrap_graphs(self, train_data_shuffler, validation_data_shuffler):
"""
Create all the necessary graphs for training, validation and inference graphs
"""
super(SiameseTrainer, self).bootstrap_graphs(train_data_shuffler, validation_data_shuffler)
# Triplet specific
tf.add_to_collection("between_class_graph_train", self.between_class_graph_train)
tf.add_to_collection("within_class_graph_train", self.within_class_graph_train)
# Creating validation graph
if validation_data_shuffler is not None:
tf.add_to_collection("between_class_graph_validation", self.between_class_graph_validation)
tf.add_to_collection("within_class_graph_validation", self.within_class_graph_validation)
self.bootstrap_placeholders(train_data_shuffler, validation_data_shuffler)
def bootstrap_graphs_fromfile(self, train_data_shuffler, validation_data_shuffler):
"""
Bootstrap all the necessary data from file
** Parameters **
session: Tensorflow session
train_data_shuffler: Data shuffler for training
validation_data_shuffler: Data shuffler for validation
"""
saver = super(SiameseTrainer, self).bootstrap_graphs_fromfile(train_data_shuffler, validation_data_shuffler)
self.between_class_graph_train = tf.get_collection("between_class_graph_train")[0]
self.within_class_graph_train = tf.get_collection("within_class_graph_train")[0]
if validation_data_shuffler is not None:
self.between_class_graph_validation = tf.get_collection("between_class_graph_validation")[0]
self.within_class_graph_validation = tf.get_collection("within_class_graph_validation")[0]
self.bootstrap_placeholders_fromfile(train_data_shuffler, validation_data_shuffler)
return saver
def compute_graph(self, data_shuffler, prefetch=False, name="", training=True):
"""
Computes the graph for the trainer.
......
......@@ -118,6 +118,46 @@ class TripletTrainer(Trainer):
self.between_class_graph_validation = None
self.within_class_graph_validation = None
def bootstrap_graphs(self, train_data_shuffler, validation_data_shuffler):
"""
Create all the necessary graphs for training, validation and inference graphs
"""
super(TripletTrainer, self).bootstrap_graphs(train_data_shuffler, validation_data_shuffler)
# Triplet specific
tf.add_to_collection("between_class_graph_train", self.between_class_graph_train)
tf.add_to_collection("within_class_graph_train", self.within_class_graph_train)
# Creating validation graph
if validation_data_shuffler is not None:
tf.add_to_collection("between_class_graph_validation", self.between_class_graph_validation)
tf.add_to_collection("within_class_graph_validation", self.within_class_graph_validation)
self.bootstrap_placeholders(train_data_shuffler, validation_data_shuffler)
def bootstrap_graphs_fromfile(self, train_data_shuffler, validation_data_shuffler):
"""
Bootstrap all the necessary data from file
** Parameters **
session: Tensorflow session
train_data_shuffler: Data shuffler for training
validation_data_shuffler: Data shuffler for validation
"""
saver = super(TripletTrainer, self).bootstrap_graphs_fromfile(train_data_shuffler, validation_data_shuffler)
self.between_class_graph_train = tf.get_collection("between_class_graph_train")[0]
self.within_class_graph_train = tf.get_collection("within_class_graph_train")[0]
if validation_data_shuffler is not None:
self.between_class_graph_validation = tf.get_collection("between_class_graph_validation")[0]
self.within_class_graph_validation = tf.get_collection("within_class_graph_validation")[0]
self.bootstrap_placeholders_fromfile(train_data_shuffler, validation_data_shuffler)
return saver
def bootstrap_placeholders(self, train_data_shuffler, validation_data_shuffler):
"""
Persist the placeholders
......@@ -251,6 +291,7 @@ class TripletTrainer(Trainer):
self.within_class_graph_train,
self.learning_rate, self.summaries_train],
feed_dict=feed_dict)
logger.info("Loss training set step={0} = {1}".format(step, l))
self.train_summary_writter.add_summary(summary, step)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment