Fixed issue #13

parent 21229e01
......@@ -36,7 +36,6 @@ class Base(object):
self.scale_value = 0.00390625
self.input_dtype = input_dtype
# TODO: Check if the bacth size is higher than the input data
self.batch_size = batch_size
......@@ -60,6 +59,10 @@ class Base(object):
self.data_augmentation = data_augmentation
self.deployment_shape = [None] + list(input_shape)
def set_placeholders(self, data, label):
self.data_placeholder = data
self.label_placeholder = label
def get_placeholders(self, name=""):
"""
Returns a place holder with the size of your batch
......@@ -156,3 +159,4 @@ class Base(object):
......@@ -21,6 +21,11 @@ class Siamese(Base):
super(Siamese, self).__init__(**kwargs)
self.data2_placeholder = None
def set_placeholders(self, data, data2, label):
self.data_placeholder = data
self.data2_placeholder = data2
self.label_placeholder = label
def get_placeholders(self, name=""):
"""
Returns a place holder with the size of your batch
......
......@@ -121,6 +121,7 @@ def test_cnn_trainer():
# At least 80% of accuracy
assert accuracy > 80.
shutil.rmtree(directory)
del chopra
def test_siamesecnn_trainer():
......@@ -166,6 +167,7 @@ def test_siamesecnn_trainer():
# At least 80% of accuracy
assert eer < 0.25
shutil.rmtree(directory)
del chopra
def test_tripletcnn_trainer():
......@@ -212,3 +214,4 @@ def test_tripletcnn_trainer():
# At least 80% of accuracy
assert eer < 0.25
shutil.rmtree(directory)
del chopra
......@@ -68,7 +68,7 @@ def test_cnn_trainer_scratch():
scratch = scratch_network()
trainer = Trainer(architecture=scratch,
loss=loss,
iterations=iterations,
iterations=iterations+200,
analizer=None,
prefetch=False,
learning_rate=constant(0.05, name="lr2"),
......@@ -76,5 +76,8 @@ def test_cnn_trainer_scratch():
model_from_file=os.path.join(directory, "model.ckp"))
trainer.train(train_data_shuffler)
accuracy = validate_network(validation_data, validation_labels, directory)
assert accuracy > 90
accuracy = validate_network(validation_data, validation_labels, directory2)
assert accuracy > 85
shutil.rmtree(directory)
shutil.rmtree(directory2)
......@@ -102,6 +102,42 @@ class SiameseTrainer(Trainer):
self.between_class_graph_validation = None
self.within_class_graph_validation = None
def bootstrap_placeholders(self, train_data_shuffler, validation_data_shuffler):
"""
Persist the placeholders
"""
# Persisting the placeholders
if self.prefetch:
batch, batch2, label = train_data_shuffler.get_placeholders_forprefetch()
else:
batch, batch2, label = train_data_shuffler.get_placeholders()
tf.add_to_collection("train_placeholder_data", batch)
tf.add_to_collection("train_placeholder_data2", batch2)
tf.add_to_collection("train_placeholder_label", label)
# Creating validation graph
if validation_data_shuffler is not None:
batch, batch2, label = validation_data_shuffler.get_placeholders()
tf.add_to_collection("validation_placeholder_data", batch)
tf.add_to_collection("validation_placeholder_data2", batch2)
tf.add_to_collection("validation_placeholder_label", label)
def bootstrap_placeholders_fromfile(self, train_data_shuffler, validation_data_shuffler):
"""
Load placeholders from file
"""
train_data_shuffler.set_placeholders(tf.get_collection("train_placeholder_data")[0],
tf.get_collection("train_placeholder_data2")[0],
tf.get_collection("train_placeholder_label")[0])
if validation_data_shuffler is not None:
train_data_shuffler.set_placeholders(tf.get_collection("validation_placeholder_data")[0],
tf.get_collection("validation_placeholder_data2")[0],
tf.get_collection("validation_placeholder_label")[0])
def compute_graph(self, data_shuffler, prefetch=False, name="", train=True):
"""
Computes the graph for the trainer.
......
......@@ -75,10 +75,7 @@ class Trainer(object):
self.loss = loss
self.temp_dir = temp_dir
#self.base_learning_rate = base_learning_rate
self.learning_rate = learning_rate
#self.weight_decay = weight_decay
#self.decay_steps = decay_steps
self.iterations = iterations
self.snapshot = snapshot
......@@ -116,12 +113,12 @@ class Trainer(object):
"""
Computes the graph for the trainer.
** Parameters **
data_shuffler: Data shuffler
prefetch:
prefetch: Uses prefetch
name: Name of the graph
training: Is it a training graph?
"""
# Defining place holders
......@@ -203,8 +200,7 @@ class Trainer(object):
if self.validation_summary_writter is None:
self.validation_summary_writter = tf.train.SummaryWriter(os.path.join(self.temp_dir, 'validation'), session.graph)
summaries = []
summaries.append(summary_pb2.Summary.Value(tag="loss", simple_value=float(l)))
summaries = [summary_pb2.Summary.Value(tag="loss", simple_value=float(l))]
self.validation_summary_writter.add_summary(summary_pb2.Summary(value=summaries), step)
logger.info("Loss VALIDATION set step={0} = {1}".format(step, l))
......@@ -212,7 +208,6 @@ class Trainer(object):
"""
Creates a simple tensorboard summary with the value of the loss and learning rate
"""
# Train summary
tf.scalar_summary('loss', self.training_graph, name="train")
tf.scalar_summary('lr', self.learning_rate, name="train")
......@@ -251,12 +246,9 @@ class Trainer(object):
session.run(self.enqueue_op, feed_dict=feed_dict)
def create_graphs(self, train_data_shuffler, validation_data_shuffler):
def bootstrap_graphs(self, train_data_shuffler, validation_data_shuffler):
"""
:param train_data_shuffler:
:param validation_data_shuffler:
:return:
Create all the necessary graphs for training, validation and inference graphs
"""
# Creating train graph
......@@ -269,11 +261,70 @@ class Trainer(object):
tf.add_to_collection("inference_placeholder", self.architecture.inference_placeholder)
tf.add_to_collection("inference_graph", self.architecture.inference_graph)
# Creating validation graph
if validation_data_shuffler is not None:
# Creating validation graph
self.validation_graph = self.compute_graph(validation_data_shuffler, name="validation", training=False)
tf.add_to_collection("validation_graph", self.validation_graph)
batch, label = validation_data_shuffler.get_placeholders()
tf.add_to_collection("validation_placeholder_data", batch)
tf.add_to_collection("validation_placeholder_label", label)
self.bootstrap_placeholders(train_data_shuffler, validation_data_shuffler)
def bootstrap_placeholders(self, train_data_shuffler, validation_data_shuffler):
"""
Persist the placeholders
"""
# Persisting the placeholders
if self.prefetch:
batch, label = train_data_shuffler.get_placeholders_forprefetch()
else:
batch, label = train_data_shuffler.get_placeholders()
tf.add_to_collection("train_placeholder_data", batch)
tf.add_to_collection("train_placeholder_label", label)
# Creating validation graph
if validation_data_shuffler is not None:
batch, label = validation_data_shuffler.get_placeholders()
tf.add_to_collection("validation_placeholder_data", batch)
tf.add_to_collection("validation_placeholder_label", label)
def bootstrap_graphs_fromfile(self, session, train_data_shuffler, validation_data_shuffler):
"""
Bootstrap all the necessary data from file
"""
saver = self.architecture.load(session, self.model_from_file)
# Loading training graph
self.training_graph = tf.get_collection("training_graph")[0]
# Loding other elements
self.optimizer = tf.get_collection("optimizer")[0]
self.learning_rate = tf.get_collection("learning_rate")[0]
self.summaries_train = tf.get_collection("summaries_train")[0]
if validation_data_shuffler is not None:
self.validation_graph = tf.get_collection("validation_graph")[0]
self.bootstrap_placeholders_fromfile(train_data_shuffler, validation_data_shuffler)
return saver
def bootstrap_placeholders_fromfile(self, train_data_shuffler, validation_data_shuffler):
"""
Load placeholders from file
"""
train_data_shuffler.set_placeholders(tf.get_collection("train_placeholder_data")[0],
tf.get_collection("train_placeholder_label")[0])
if validation_data_shuffler is not None:
train_data_shuffler.set_placeholders(tf.get_collection("validation_placeholder_data")[0],
tf.get_collection("validation_placeholder_label")[0])
def train(self, train_data_shuffler, validation_data_shuffler=None):
"""
Train the network
......@@ -296,12 +347,10 @@ class Trainer(object):
# Loading a pretrained model
if self.model_from_file != "":
logger.info("Loading pretrained model from {0}".format(self.model_from_file))
saver = self.architecture.load(session, self.model_from_file)
self.training_graph = tf.get_collection("training_graph")[0]
self.optimizer = tf.get_collection("optimizer")[0]
self.learning_rate = tf.get_collection("learning_rate")[0]
saver = self.bootstrap_graphs_fromfile(session, train_data_shuffler, validation_data_shuffler)
else:
self.create_graphs(train_data_shuffler, validation_data_shuffler)
# Bootstraping all the graphs
self.bootstrap_graphs(train_data_shuffler, validation_data_shuffler)
# TODO: find an elegant way to provide this as a parameter of the trainer
self.global_step = tf.Variable(0, trainable=False)
......@@ -314,6 +363,7 @@ class Trainer(object):
# Train summary
self.summaries_train = self.create_general_summary()
tf.add_to_collection("summaries_train", self.summaries_train)
tf.initialize_all_variables().run()
......
......@@ -101,6 +101,42 @@ class TripletTrainer(Trainer):
self.between_class_graph_validation = None
self.within_class_graph_validation = None
def bootstrap_placeholders(self, train_data_shuffler, validation_data_shuffler):
"""
Persist the placeholders
"""
# Persisting the placeholders
if self.prefetch:
batch, batch2, batch3 = train_data_shuffler.get_placeholders_forprefetch()
else:
batch, batch2, batch3 = train_data_shuffler.get_placeholders()
tf.add_to_collection("train_placeholder_data", batch)
tf.add_to_collection("train_placeholder_data2", batch2)
tf.add_to_collection("train_placeholder_data3", batch3)
# Creating validation graph
if validation_data_shuffler is not None:
batch, batch2, label = validation_data_shuffler.get_placeholders()
tf.add_to_collection("validation_placeholder_data", batch)
tf.add_to_collection("validation_placeholder_data2", batch2)
tf.add_to_collection("validation_placeholder_data3", batch3)
def bootstrap_placeholders_fromfile(self, train_data_shuffler, validation_data_shuffler):
"""
Load placeholders from file
"""
train_data_shuffler.set_placeholders(tf.get_collection("train_placeholder_data")[0],
tf.get_collection("train_placeholder_data2")[0],
tf.get_collection("train_placeholder_data3")[0])
if validation_data_shuffler is not None:
train_data_shuffler.set_placeholders(tf.get_collection("validation_placeholder_data")[0],
tf.get_collection("validation_placeholder_data2")[0],
tf.get_collection("validation_placeholder_data3")[0])
def compute_graph(self, data_shuffler, prefetch=False, name="", train=True):
"""
Computes the graph for the trainer.
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment