Commit 57d0adb8 authored by Tiago Pereira's avatar Tiago Pereira
Browse files

Loaded triplet network from file

parent e346a1f1
......@@ -25,7 +25,7 @@ Some unit tests that create networks on the fly and load variables
batch_size = 16
validation_batch_size = 400
iterations =300
iterations = 300
seed = 10
......@@ -119,67 +119,65 @@ def test_triplet_cnn_pretrained():
# Creating datashufflers
data_augmentation = ImageAugmentation()
train_data_shuffler = TripletMemory(train_data, train_labels,
input_shape=[28, 28, 1],
input_shape=[None, 28, 28, 1],
batch_size=batch_size,
data_augmentation=data_augmentation)
validation_data = numpy.reshape(validation_data, (validation_data.shape[0], 28, 28, 1))
validation_data_shuffler = TripletMemory(validation_data, validation_labels,
input_shape=[28, 28, 1],
input_shape=[None, 28, 28, 1],
batch_size=validation_batch_size)
directory = "./temp/cnn"
directory2 = "./temp/cnn2"
# Creating a random network
scratch = scratch_network()
input_pl = train_data_shuffler("data", from_queue=False)
graph = dict()
graph['anchor'] = scratch_network(input_pl['anchor'])
graph['positive'] = scratch_network(input_pl['positive'])
graph['negative'] = scratch_network(input_pl['negative'])
# Loss for the softmax
loss = TripletLoss(margin=4.)
# One graph trainer
trainer = TripletTrainer(architecture=scratch,
loss=loss,
trainer = TripletTrainer(train_data_shuffler,
iterations=iterations,
analizer=None,
prefetch=False,
learning_rate=constant(0.05, name="regular_lr"),
optimizer=tf.train.AdamOptimizer(name="adam_pretrained_model"),
temp_dir=directory
)
trainer.train(train_data_shuffler)
temp_dir=directory)
trainer.create_network_from_scratch(graph=graph,
loss=loss,
learning_rate=constant(0.01, name="regular_lr"),
optimizer=tf.train.GradientDescentOptimizer(0.01))
trainer.train()
# Testing
eer = dummy_experiment(validation_data_shuffler, scratch)
embedding = Embedding(trainer.data_ph['anchor'], trainer.graph['anchor'])
eer = dummy_experiment(validation_data_shuffler, embedding)
# The result is not so good
assert eer < 0.25
del scratch
del graph
del loss
del trainer
# Training the network using a pre trained model
loss = TripletLoss(margin=4.)
scratch = scratch_network()
trainer = TripletTrainer(architecture=scratch,
loss=loss,
iterations=iterations + 200,
trainer = TripletTrainer(train_data_shuffler,
iterations=iterations*2,
analizer=None,
prefetch=False,
learning_rate=None,
temp_dir=directory2,
model_from_file=os.path.join(directory, "model.ckp")
)
temp_dir=directory)
trainer.create_network_from_file(os.path.join(directory, "model.ckp"))
trainer.train()
trainer.train(train_data_shuffler)
embedding = Embedding(trainer.data_ph['anchor'], trainer.graph['anchor'])
eer = dummy_experiment(validation_data_shuffler, embedding)
eer = dummy_experiment(validation_data_shuffler, scratch)
# Now it is better
assert eer < 0.15
assert eer < 0.20
shutil.rmtree(directory)
shutil.rmtree(directory2)
del scratch
del loss
del trainer
......@@ -200,9 +198,7 @@ def test_siamese_cnn_pretrained():
input_shape=[None, 28, 28, 1],
batch_size=validation_batch_size,
normalizer=ScaleFactor())
directory = "./temp/cnn"
directory2 = "./temp/cnn2"
# Creating graph
input_pl = train_data_shuffler("data")
......@@ -225,7 +221,8 @@ def test_siamese_cnn_pretrained():
trainer.train()
# Testing
embedding = Embedding(train_data_shuffler("data", from_queue=False)['left'], graph['left'])
#embedding = Embedding(train_data_shuffler("data", from_queue=False)['left'], graph['left'])
embedding = Embedding(trainer.data_ph['left'], trainer.graph['left'])
eer = dummy_experiment(validation_data_shuffler, embedding)
assert eer < 0.10
......@@ -241,14 +238,11 @@ def test_siamese_cnn_pretrained():
trainer.create_network_from_file(os.path.join(directory, "model.ckp"))
trainer.train()
#import ipdb; ipdb.set_trace()
embedding = Embedding(train_data_shuffler("data", from_queue=False)['left'], trainer.graph['left'])
#embedding = Embedding(train_data_shuffler("data", from_queue=False)['left'], trainer.graph['left'])
embedding = Embedding(trainer.data_ph['left'], trainer.graph['left'])
eer = dummy_experiment(validation_data_shuffler, embedding)
assert eer < 0.10
shutil.rmtree(directory)
shutil.rmtree(directory2)
del graph
del loss
del trainer
......@@ -192,6 +192,45 @@ class TripletTrainer(Trainer):
# Creating the variables
tf.global_variables_initializer().run(session=self.session)
    def create_network_from_file(self, model_from_file):
        """
        Bootstrap the trainer from a previously saved model checkpoint.

        Restores the TensorFlow meta-graph from ``model_from_file + ".meta"``
        and the variable values from the checkpoint itself, then re-binds all
        the trainer's graph pointers (triplet branches, input placeholders,
        loss terms, optimizer, learning rate, summaries and global step) from
        the named graph collections that were registered when the model was
        saved.

        ** Parameters **

        model_from_file: Path prefix of the checkpoint to load (the
            corresponding ``<model_from_file>.meta`` file must exist).
        """
        # Rebuild the graph structure from the .meta file, then restore the
        # trained variable values into the trainer's session.
        self.saver = tf.train.import_meta_graph(model_from_file + ".meta")
        self.saver.restore(self.session, model_from_file)

        # Loading the graph from the graph pointers
        # NOTE(review): each tf.get_collection(...)[0] assumes the collection
        # was populated (by create_network_from_scratch) before saving; an
        # IndexError here means the checkpoint was written by an older/other
        # trainer.
        self.graph = dict()
        self.graph['anchor'] = tf.get_collection("graph_anchor")[0]
        self.graph['positive'] = tf.get_collection("graph_positive")[0]
        self.graph['negative'] = tf.get_collection("graph_negative")[0]

        # Loading the placeholders from the pointers
        self.data_ph = dict()
        self.data_ph['anchor'] = tf.get_collection("data_ph_anchor")[0]
        self.data_ph['positive'] = tf.get_collection("data_ph_positive")[0]
        self.data_ph['negative'] = tf.get_collection("data_ph_negative")[0]

        # Loading loss from the pointers (total triplet loss plus its
        # between-class and within-class distance components)
        self.predictor = dict()
        self.predictor['loss'] = tf.get_collection("predictor_loss")[0]
        self.predictor['between_class'] = tf.get_collection("predictor_between_class_loss")[0]
        self.predictor['within_class'] = tf.get_collection("predictor_within_class_loss")[0]

        # Loading other elements
        self.optimizer = tf.get_collection("optimizer")[0]
        self.learning_rate = tf.get_collection("learning_rate")[0]
        self.summaries_train = tf.get_collection("summaries_train")[0]
        self.global_step = tf.get_collection("global_step")[0]

        # Signal to train() that the graph is already built and restored,
        # so it must not re-initialize variables from scratch.
        self.from_scratch = False
def get_feed_dict(self, data_shuffler):
"""
Given a data shuffler prepared the dictionary to be injected in the graph
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment