Skip to content
Snippets Groups Projects
Commit 2ab6b68d authored by Tiago de Freitas Pereira's avatar Tiago de Freitas Pereira
Browse files

Fixed issue #21

parent 5f82ac33
Branches
Tags
No related merge requests found
Pipeline #
...@@ -23,6 +23,12 @@ class Triplet(Base): ...@@ -23,6 +23,12 @@ class Triplet(Base):
self.data2_placeholder = None self.data2_placeholder = None
self.data3_placeholder = None self.data3_placeholder = None
def set_placeholders(self, data, data2, data3):
    """
    Store the three given placeholders on this object.

    ** Parameters **
      data: Value stored in `self.data_placeholder`
      data2: Value stored in `self.data2_placeholder`
      data3: Value stored in `self.data3_placeholder`
    """
    (self.data_placeholder,
     self.data2_placeholder,
     self.data3_placeholder) = data, data2, data3
def get_placeholders(self, name=""): def get_placeholders(self, name=""):
""" """
Returns a place holder with the size of your batch Returns a place holder with the size of your batch
......
...@@ -6,12 +6,13 @@ ...@@ -6,12 +6,13 @@
import numpy import numpy
import bob.io.base import bob.io.base
import os import os
from bob.learn.tensorflow.datashuffler import Memory, ImageAugmentation from bob.learn.tensorflow.datashuffler import Memory, ImageAugmentation, TripletMemory, SiameseMemory
from bob.learn.tensorflow.loss import BaseLoss from bob.learn.tensorflow.loss import BaseLoss, TripletLoss, ContrastiveLoss
from bob.learn.tensorflow.trainers import Trainer, constant from bob.learn.tensorflow.trainers import Trainer, constant, TripletTrainer, SiameseTrainer
from bob.learn.tensorflow.utils import load_mnist from bob.learn.tensorflow.utils import load_mnist
from bob.learn.tensorflow.network import SequenceNetwork from bob.learn.tensorflow.network import SequenceNetwork
from bob.learn.tensorflow.layers import Conv2D, FullyConnected from bob.learn.tensorflow.layers import Conv2D, FullyConnected
from test_cnn import dummy_experiment
import tensorflow as tf import tensorflow as tf
import shutil import shutil
...@@ -99,7 +100,7 @@ def test_cnn_pretrained(): ...@@ -99,7 +100,7 @@ def test_cnn_pretrained():
scratch = scratch_network() scratch = scratch_network()
trainer = Trainer(architecture=scratch, trainer = Trainer(architecture=scratch,
loss=loss, loss=loss,
iterations=iterations+200, iterations=iterations + 200,
analizer=None, analizer=None,
prefetch=False, prefetch=False,
learning_rate=None, learning_rate=None,
...@@ -118,3 +119,144 @@ def test_cnn_pretrained(): ...@@ -118,3 +119,144 @@ def test_cnn_pretrained():
del loss del loss
del trainer del trainer
def test_triplet_cnn_pretrained():
    """
    Train a triplet network from scratch, then train a second one starting
    from the model file left in the first trainer's temp directory.

    The first run must reach an EER below 0.25 and the second, bootstrapped
    from `model.ckp`, an EER below 0.15. Both temp directories are removed
    even when an assertion or a training step fails, so a failed run does not
    leave stale checkpoints behind for the next one.
    """
    train_data, train_labels, validation_data, validation_labels = load_mnist()
    train_data = numpy.reshape(train_data, (train_data.shape[0], 28, 28, 1))

    # Creating datashufflers
    data_augmentation = ImageAugmentation()
    train_data_shuffler = TripletMemory(train_data, train_labels,
                                        input_shape=[28, 28, 1],
                                        batch_size=batch_size,
                                        data_augmentation=data_augmentation)
    validation_data = numpy.reshape(validation_data, (validation_data.shape[0], 28, 28, 1))
    validation_data_shuffler = TripletMemory(validation_data, validation_labels,
                                             input_shape=[28, 28, 1],
                                             batch_size=validation_batch_size)

    directory = "./temp/cnn"
    directory2 = "./temp/cnn2"

    try:
        # Creating a random network
        scratch = scratch_network()

        # Triplet loss
        loss = TripletLoss(margin=4.)

        # One graph trainer
        trainer = TripletTrainer(architecture=scratch,
                                 loss=loss,
                                 iterations=iterations,
                                 analizer=None,
                                 prefetch=False,
                                 learning_rate=constant(0.05, name="regular_lr"),
                                 optimizer=tf.train.AdamOptimizer(name="adam_pretrained_model"),
                                 temp_dir=directory
                                 )
        trainer.train(train_data_shuffler)

        # Testing
        eer = dummy_experiment(validation_data_shuffler, scratch)

        # The result is not so good
        assert eer < 0.25

        del scratch
        del loss
        del trainer

        # Training the network using a pre trained model
        loss = TripletLoss(margin=4.)
        scratch = scratch_network()
        trainer = TripletTrainer(architecture=scratch,
                                 loss=loss,
                                 iterations=iterations + 200,
                                 analizer=None,
                                 prefetch=False,
                                 learning_rate=None,
                                 temp_dir=directory2,
                                 # presumably the checkpoint written by the first trainer
                                 model_from_file=os.path.join(directory, "model.ckp")
                                 )
        trainer.train(train_data_shuffler)

        eer = dummy_experiment(validation_data_shuffler, scratch)

        # Now it is better
        assert eer < 0.15

        del scratch
        del loss
        del trainer
    finally:
        # Always clean up; ignore_errors so a directory that was never created
        # does not mask the original failure
        shutil.rmtree(directory, ignore_errors=True)
        shutil.rmtree(directory2, ignore_errors=True)
def test_siamese_cnn_pretrained():
    """
    Train a siamese network from scratch, then train a second one starting
    from the model file left in the first trainer's temp directory.

    The first run must reach an EER below 0.28 and the second, bootstrapped
    from `model.ckp`, an EER below 0.25. Both temp directories are removed
    even when an assertion or a training step fails, so a failed run does not
    leave stale checkpoints behind for the next one.
    """
    train_data, train_labels, validation_data, validation_labels = load_mnist()
    train_data = numpy.reshape(train_data, (train_data.shape[0], 28, 28, 1))

    # Creating datashufflers
    data_augmentation = ImageAugmentation()
    train_data_shuffler = SiameseMemory(train_data, train_labels,
                                        input_shape=[28, 28, 1],
                                        batch_size=batch_size,
                                        data_augmentation=data_augmentation)
    validation_data = numpy.reshape(validation_data, (validation_data.shape[0], 28, 28, 1))
    validation_data_shuffler = SiameseMemory(validation_data, validation_labels,
                                             input_shape=[28, 28, 1],
                                             batch_size=validation_batch_size)

    directory = "./temp/cnn"
    directory2 = "./temp/cnn2"

    try:
        # Creating a random network
        scratch = scratch_network()

        # Contrastive loss
        loss = ContrastiveLoss(contrastive_margin=4.)

        # One graph trainer
        trainer = SiameseTrainer(architecture=scratch,
                                 loss=loss,
                                 iterations=iterations,
                                 analizer=None,
                                 prefetch=False,
                                 learning_rate=constant(0.05, name="regular_lr"),
                                 optimizer=tf.train.AdamOptimizer(name="adam_pretrained_model"),
                                 temp_dir=directory
                                 )
        trainer.train(train_data_shuffler)

        # Testing
        eer = dummy_experiment(validation_data_shuffler, scratch)

        # The result is not so good
        assert eer < 0.28

        del scratch
        del loss
        del trainer

        # Training the network using a pre trained model
        loss = ContrastiveLoss(contrastive_margin=4.)
        scratch = scratch_network()
        trainer = SiameseTrainer(architecture=scratch,
                                 loss=loss,
                                 iterations=iterations + 1000,
                                 analizer=None,
                                 prefetch=False,
                                 learning_rate=None,
                                 temp_dir=directory2,
                                 # presumably the checkpoint written by the first trainer
                                 model_from_file=os.path.join(directory, "model.ckp")
                                 )
        trainer.train(train_data_shuffler)

        eer = dummy_experiment(validation_data_shuffler, scratch)

        # Now it is better
        assert eer < 0.25

        del scratch
        del loss
        del trainer
    finally:
        # Always clean up; ignore_errors so a directory that was never created
        # does not mask the original failure
        shutil.rmtree(directory, ignore_errors=True)
        shutil.rmtree(directory2, ignore_errors=True)
...@@ -151,6 +151,46 @@ class SiameseTrainer(Trainer): ...@@ -151,6 +151,46 @@ class SiameseTrainer(Trainer):
tf.get_collection("validation_placeholder_data2")[0], tf.get_collection("validation_placeholder_data2")[0],
tf.get_collection("validation_placeholder_label")[0]) tf.get_collection("validation_placeholder_label")[0])
def bootstrap_graphs(self, train_data_shuffler, validation_data_shuffler):
    """
    Delegate to the parent `bootstrap_graphs`, then register the Siamese
    between-class / within-class graphs in TensorFlow collections and persist
    the placeholders.

    ** Parameters **
      train_data_shuffler: Data shuffler for training
      validation_data_shuffler: Data shuffler for validation; when it is None
        the validation graphs are not registered
    """
    super(SiameseTrainer, self).bootstrap_graphs(train_data_shuffler, validation_data_shuffler)

    # Each graph is stored in a collection named after its attribute
    graph_names = ["between_class_graph_train", "within_class_graph_train"]
    if validation_data_shuffler is not None:
        graph_names += ["between_class_graph_validation", "within_class_graph_validation"]

    for graph_name in graph_names:
        tf.add_to_collection(graph_name, getattr(self, graph_name))

    self.bootstrap_placeholders(train_data_shuffler, validation_data_shuffler)
def bootstrap_graphs_fromfile(self, train_data_shuffler, validation_data_shuffler):
    """
    Delegate to the parent `bootstrap_graphs_fromfile`, then read the Siamese
    between-class / within-class graphs back from the TensorFlow collections
    and restore the placeholders.

    ** Parameters **
      train_data_shuffler: Data shuffler for training
      validation_data_shuffler: Data shuffler for validation; when it is None
        the validation graphs are not restored

    ** Returns **
      Whatever the parent `bootstrap_graphs_fromfile` returned
    """
    saver = super(SiameseTrainer, self).bootstrap_graphs_fromfile(train_data_shuffler, validation_data_shuffler)

    graph_names = ["between_class_graph_train", "within_class_graph_train"]
    if validation_data_shuffler is not None:
        graph_names += ["between_class_graph_validation", "within_class_graph_validation"]

    for graph_name in graph_names:
        # Take the first entry of the collection that shares the attribute's name
        setattr(self, graph_name, tf.get_collection(graph_name)[0])

    self.bootstrap_placeholders_fromfile(train_data_shuffler, validation_data_shuffler)
    return saver
def compute_graph(self, data_shuffler, prefetch=False, name="", training=True): def compute_graph(self, data_shuffler, prefetch=False, name="", training=True):
""" """
Computes the graph for the trainer. Computes the graph for the trainer.
......
...@@ -118,6 +118,46 @@ class TripletTrainer(Trainer): ...@@ -118,6 +118,46 @@ class TripletTrainer(Trainer):
self.between_class_graph_validation = None self.between_class_graph_validation = None
self.within_class_graph_validation = None self.within_class_graph_validation = None
def bootstrap_graphs(self, train_data_shuffler, validation_data_shuffler):
    """
    Delegate to the parent `bootstrap_graphs`, then register the triplet
    between-class / within-class graphs in TensorFlow collections and persist
    the placeholders.

    ** Parameters **
      train_data_shuffler: Data shuffler for training
      validation_data_shuffler: Data shuffler for validation; when it is None
        the validation graphs are not registered
    """
    super(TripletTrainer, self).bootstrap_graphs(train_data_shuffler, validation_data_shuffler)

    # Each graph is stored in a collection named after its attribute
    graph_names = ["between_class_graph_train", "within_class_graph_train"]
    if validation_data_shuffler is not None:
        graph_names += ["between_class_graph_validation", "within_class_graph_validation"]

    for graph_name in graph_names:
        tf.add_to_collection(graph_name, getattr(self, graph_name))

    self.bootstrap_placeholders(train_data_shuffler, validation_data_shuffler)
def bootstrap_graphs_fromfile(self, train_data_shuffler, validation_data_shuffler):
    """
    Delegate to the parent `bootstrap_graphs_fromfile`, then read the triplet
    between-class / within-class graphs back from the TensorFlow collections
    and restore the placeholders.

    ** Parameters **
      train_data_shuffler: Data shuffler for training
      validation_data_shuffler: Data shuffler for validation; when it is None
        the validation graphs are not restored

    ** Returns **
      Whatever the parent `bootstrap_graphs_fromfile` returned
    """
    saver = super(TripletTrainer, self).bootstrap_graphs_fromfile(train_data_shuffler, validation_data_shuffler)

    graph_names = ["between_class_graph_train", "within_class_graph_train"]
    if validation_data_shuffler is not None:
        graph_names += ["between_class_graph_validation", "within_class_graph_validation"]

    for graph_name in graph_names:
        # Take the first entry of the collection that shares the attribute's name
        setattr(self, graph_name, tf.get_collection(graph_name)[0])

    self.bootstrap_placeholders_fromfile(train_data_shuffler, validation_data_shuffler)
    return saver
def bootstrap_placeholders(self, train_data_shuffler, validation_data_shuffler): def bootstrap_placeholders(self, train_data_shuffler, validation_data_shuffler):
""" """
Persist the placeholders Persist the placeholders
...@@ -251,6 +291,7 @@ class TripletTrainer(Trainer): ...@@ -251,6 +291,7 @@ class TripletTrainer(Trainer):
self.within_class_graph_train, self.within_class_graph_train,
self.learning_rate, self.summaries_train], self.learning_rate, self.summaries_train],
feed_dict=feed_dict) feed_dict=feed_dict)
logger.info("Loss training set step={0} = {1}".format(step, l)) logger.info("Loss training set step={0} = {1}".format(step, l))
self.train_summary_writter.add_summary(summary, step) self.train_summary_writter.add_summary(summary, step)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment