Skip to content
Snippets Groups Projects
Commit 57d0adb8 authored by Tiago Pereira's avatar Tiago Pereira
Browse files

Loaded triplet network from file

parent e346a1f1
No related branches found
No related tags found
No related merge requests found
...@@ -25,7 +25,7 @@ Some unit tests that create networks on the fly and load variables ...@@ -25,7 +25,7 @@ Some unit tests that create networks on the fly and load variables
batch_size = 16 batch_size = 16
validation_batch_size = 400 validation_batch_size = 400
iterations =300 iterations = 300
seed = 10 seed = 10
...@@ -119,67 +119,65 @@ def test_triplet_cnn_pretrained(): ...@@ -119,67 +119,65 @@ def test_triplet_cnn_pretrained():
# Creating datashufflers # Creating datashufflers
data_augmentation = ImageAugmentation() data_augmentation = ImageAugmentation()
train_data_shuffler = TripletMemory(train_data, train_labels, train_data_shuffler = TripletMemory(train_data, train_labels,
input_shape=[28, 28, 1], input_shape=[None, 28, 28, 1],
batch_size=batch_size, batch_size=batch_size,
data_augmentation=data_augmentation) data_augmentation=data_augmentation)
validation_data = numpy.reshape(validation_data, (validation_data.shape[0], 28, 28, 1)) validation_data = numpy.reshape(validation_data, (validation_data.shape[0], 28, 28, 1))
validation_data_shuffler = TripletMemory(validation_data, validation_labels, validation_data_shuffler = TripletMemory(validation_data, validation_labels,
input_shape=[28, 28, 1], input_shape=[None, 28, 28, 1],
batch_size=validation_batch_size) batch_size=validation_batch_size)
directory = "./temp/cnn" directory = "./temp/cnn"
directory2 = "./temp/cnn2"
# Creating a random network # Creating a random network
scratch = scratch_network() input_pl = train_data_shuffler("data", from_queue=False)
graph = dict()
graph['anchor'] = scratch_network(input_pl['anchor'])
graph['positive'] = scratch_network(input_pl['positive'])
graph['negative'] = scratch_network(input_pl['negative'])
# Loss for the softmax # Loss for the softmax
loss = TripletLoss(margin=4.) loss = TripletLoss(margin=4.)
# One graph trainer # One graph trainer
trainer = TripletTrainer(architecture=scratch, trainer = TripletTrainer(train_data_shuffler,
loss=loss,
iterations=iterations, iterations=iterations,
analizer=None, analizer=None,
prefetch=False, temp_dir=directory)
learning_rate=constant(0.05, name="regular_lr"),
optimizer=tf.train.AdamOptimizer(name="adam_pretrained_model"), trainer.create_network_from_scratch(graph=graph,
temp_dir=directory loss=loss,
) learning_rate=constant(0.01, name="regular_lr"),
trainer.train(train_data_shuffler) optimizer=tf.train.GradientDescentOptimizer(0.01))
trainer.train()
# Testing # Testing
eer = dummy_experiment(validation_data_shuffler, scratch) embedding = Embedding(trainer.data_ph['anchor'], trainer.graph['anchor'])
eer = dummy_experiment(validation_data_shuffler, embedding)
# The result is not so good # The result is not so good
assert eer < 0.25 assert eer < 0.25
del scratch del graph
del loss del loss
del trainer del trainer
# Training the network using a pre trained model # Training the network using a pre trained model
loss = TripletLoss(margin=4.) trainer = TripletTrainer(train_data_shuffler,
scratch = scratch_network() iterations=iterations*2,
trainer = TripletTrainer(architecture=scratch,
loss=loss,
iterations=iterations + 200,
analizer=None, analizer=None,
prefetch=False, temp_dir=directory)
learning_rate=None,
temp_dir=directory2, trainer.create_network_from_file(os.path.join(directory, "model.ckp"))
model_from_file=os.path.join(directory, "model.ckp") trainer.train()
)
trainer.train(train_data_shuffler) embedding = Embedding(trainer.data_ph['anchor'], trainer.graph['anchor'])
eer = dummy_experiment(validation_data_shuffler, embedding)
eer = dummy_experiment(validation_data_shuffler, scratch)
# Now it is better # Now it is better
assert eer < 0.15 assert eer < 0.20
shutil.rmtree(directory) shutil.rmtree(directory)
shutil.rmtree(directory2)
del scratch
del loss
del trainer del trainer
...@@ -200,9 +198,7 @@ def test_siamese_cnn_pretrained(): ...@@ -200,9 +198,7 @@ def test_siamese_cnn_pretrained():
input_shape=[None, 28, 28, 1], input_shape=[None, 28, 28, 1],
batch_size=validation_batch_size, batch_size=validation_batch_size,
normalizer=ScaleFactor()) normalizer=ScaleFactor())
directory = "./temp/cnn" directory = "./temp/cnn"
directory2 = "./temp/cnn2"
# Creating graph # Creating graph
input_pl = train_data_shuffler("data") input_pl = train_data_shuffler("data")
...@@ -225,7 +221,8 @@ def test_siamese_cnn_pretrained(): ...@@ -225,7 +221,8 @@ def test_siamese_cnn_pretrained():
trainer.train() trainer.train()
# Testing # Testing
embedding = Embedding(train_data_shuffler("data", from_queue=False)['left'], graph['left']) #embedding = Embedding(train_data_shuffler("data", from_queue=False)['left'], graph['left'])
embedding = Embedding(trainer.data_ph['left'], trainer.graph['left'])
eer = dummy_experiment(validation_data_shuffler, embedding) eer = dummy_experiment(validation_data_shuffler, embedding)
assert eer < 0.10 assert eer < 0.10
...@@ -241,14 +238,11 @@ def test_siamese_cnn_pretrained(): ...@@ -241,14 +238,11 @@ def test_siamese_cnn_pretrained():
trainer.create_network_from_file(os.path.join(directory, "model.ckp")) trainer.create_network_from_file(os.path.join(directory, "model.ckp"))
trainer.train() trainer.train()
#import ipdb; ipdb.set_trace() #embedding = Embedding(train_data_shuffler("data", from_queue=False)['left'], trainer.graph['left'])
embedding = Embedding(train_data_shuffler("data", from_queue=False)['left'], trainer.graph['left']) embedding = Embedding(trainer.data_ph['left'], trainer.graph['left'])
eer = dummy_experiment(validation_data_shuffler, embedding) eer = dummy_experiment(validation_data_shuffler, embedding)
assert eer < 0.10 assert eer < 0.10
shutil.rmtree(directory) shutil.rmtree(directory)
shutil.rmtree(directory2)
del graph
del loss
del trainer del trainer
...@@ -192,6 +192,45 @@ class TripletTrainer(Trainer): ...@@ -192,6 +192,45 @@ class TripletTrainer(Trainer):
# Creating the variables # Creating the variables
tf.global_variables_initializer().run(session=self.session) tf.global_variables_initializer().run(session=self.session)
def create_network_from_file(self, model_from_file):
    """
    Bootstrap the trainer's state from a previously saved checkpoint.

    Restores the TensorFlow meta-graph and variable values into
    ``self.session``, then re-binds the trainer's graph tensors,
    input placeholders, loss tensors and training ops from the named
    graph collections.

    NOTE(review): assumes the checkpoint was written by this trainer's
    save path, which registers the ``graph_*``, ``data_ph_*``,
    ``predictor_*`` etc. collections — an ``IndexError`` is raised here
    if a collection is missing; confirm against the saving code.

    **Parameters**

    model_from_file: str
        Checkpoint path prefix; the ``.meta`` suffix is appended to
        locate the meta-graph file.
    """
    #saver = self.architecture.load(self.model_from_file, clear_devices=False)

    # Rebuild the graph structure from the .meta file, then load the
    # trained variable values into the current session.
    self.saver = tf.train.import_meta_graph(model_from_file + ".meta")
    self.saver.restore(self.session, model_from_file)

    # Loading the graph from the graph pointers — one embedding branch
    # per triplet role (anchor / positive / negative).
    self.graph = dict()
    self.graph['anchor'] = tf.get_collection("graph_anchor")[0]
    self.graph['positive'] = tf.get_collection("graph_positive")[0]
    self.graph['negative'] = tf.get_collection("graph_negative")[0]

    # Loading the input placeholders from the pointers.
    self.data_ph = dict()
    self.data_ph['anchor'] = tf.get_collection("data_ph_anchor")[0]
    self.data_ph['positive'] = tf.get_collection("data_ph_positive")[0]
    self.data_ph['negative'] = tf.get_collection("data_ph_negative")[0]

    # Loading the loss tensors from the pointers (total triplet loss
    # plus its between-/within-class components).
    self.predictor = dict()
    self.predictor['loss'] = tf.get_collection("predictor_loss")[0]
    self.predictor['between_class'] = tf.get_collection("predictor_between_class_loss")[0]
    self.predictor['within_class'] = tf.get_collection("predictor_within_class_loss")[0]

    # Loading other training elements (optimizer op, LR schedule,
    # summary op and the step counter).
    self.optimizer = tf.get_collection("optimizer")[0]
    self.learning_rate = tf.get_collection("learning_rate")[0]
    self.summaries_train = tf.get_collection("summaries_train")[0]
    self.global_step = tf.get_collection("global_step")[0]

    # Signal that training resumes from a checkpoint rather than from
    # freshly initialized variables.
    self.from_scratch = False
def get_feed_dict(self, data_shuffler): def get_feed_dict(self, data_shuffler):
""" """
Given a data shuffler prepared the dictionary to be injected in the graph Given a data shuffler prepared the dictionary to be injected in the graph
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment